infer.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #include <unistd.h>
00068 #include <fstream>
00069 #include <climits>
00070 #include <sys/times.h>
00071 #include "fol.h"
00072 #include "arguments.h"
00073 #include "util.h"
00074 #include "infer.h"
00075 
00076 extern const char* ZZ_TMP_FILE_POSTFIX; //defined in fol.y
00077 
00078   // TODO: List the arguments common to learnwts and inference in
00079   // inferenceargs.h. This can't be done with a static array.
00080 ARGS ARGS::Args[] = 
00081 {
00082     // BEGIN: Common arguments
00083   ARGS("i", ARGS::Req, ainMLNFiles, 
00084        "Comma-separated input .mln files."),
00085 
00086   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00087        "Specified non-evidence atoms (comma-separated with no space) are "
00088        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00089        "appearing here cannot be query atoms and cannot appear in the -o "
00090        "option."),
00091 
00092   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00093        "Specified evidence atoms (comma-separated with no space) are open "
00094        "world, while other evidence atoms are closed-world. "
00095        "Atoms appearing here cannot appear in the -c option."),
00096     // END: Common arguments
00097 
00098   ARGS("queryEvidence", ARGS::Tog, aisQueryEvidence, 
00099        "If this flag is set, then all the groundings of query preds not in db "
00100        "are assumed false evidence."),
00101 
00102     // BEGIN: Common inference arguments
00103   ARGS("m", ARGS::Tog, amapPos, 
00104        "Run MAP inference and return only positive query atoms."),
00105 
00106   ARGS("a", ARGS::Tog, amapAll, 
00107        "Run MAP inference and show 0/1 results for all query atoms."),
00108 
00109   ARGS("p", ARGS::Tog, agibbsInfer, 
00110        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00111        "for all query atoms."),
00112   
00113   ARGS("ms", ARGS::Tog, amcsatInfer,
00114        "Run inference using MC-SAT and return probabilities "
00115        "for all query atoms"),
00116 
00117   ARGS("bp", ARGS::Tog, abpInfer,
00118        "Run inference using belief propagation and return probabilities "
00119        "for all query atoms"),
00120 
00121   ARGS("simtp", ARGS::Tog, asimtpInfer,
00122        "Run inference using simulated tempering and return probabilities "
00123        "for all query atoms"),
00124 
00125   ARGS("outputNetwork", ARGS::Tog, aoutputNetwork,
00126        "Build the network and output to results file, instead of "
00127        "running inference"),
00128 
00129   ARGS("counts", ARGS::Tog, aclauseCounts,
00130        "Write clause counts, not atom values or probabilities"),
00131 
00132   ARGS("seed", ARGS::Opt, aSeed,
00133        "[2350877] Seed used to initialize the randomizer in the inference "
00134        "algorithm. If not set, seed is initialized from a fixed random number."),
00135 
00136   ARGS("lazy", ARGS::Opt, aLazy, 
00137        "[false] Run lazy version of inference if this flag is set."),
00138   
00139   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00140        "[false] Lazy version of inference will not approximate by deactivating "
00141        "atoms to save memory. This flag is ignored if -lazy is not set."),
00142   
00143   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00144        "[-1] Maximum limit in kbytes which should be used for inference. "
00145        "-1 means main memory available on system is used."),
00146 
00147   ARGS("GndPredIdxMapFile", ARGS::Opt, aGndPredIdxMapFile,
00148        "File containing the mapping from ground predicates (dis & num) to their idx"),
00149   
00150   ARGS("PrintSamplePerIteration", ARGS::Opt, aPrintSamplePerIteration,
00151        "Whether to print out variable values at each HMCS sample round."),
00152         
00153   ARGS("SAInterval", ARGS::Opt, saInterval, "SA interval"),
00154 
00155   ARGS("LineNum", ARGS::Opt, aLineNum, "LineNum"),
00156 
00157   ARGS("LinePara", ARGS::Opt, aLinePara, "LinePara"),   
00158 
00159   ARGS("LineName", ARGS::Opt, aLineName, "LineName"),
00160 
00161   ARGS("MaxSeconds", ARGS::Opt, aMaxSeconds, "Max seconds for HMWS and SA."),
00162 
00163   ARGS("StartPt", ARGS::Opt, aStartPt, "Starting from a fixed point, load from testcont and testdis."),    
00164 
00165   ARGS("GenRandom", ARGS::Opt, aGenRandom, "generate random assignment as starting point, saving in testcont and testdis."),
00166 
00167   ARGS("SATempDownRation", ARGS::Opt, aSATempDownRatio, "simulated annealing temperature degrade ratio."),
00168 
00169   ARGS("SA", ARGS::Opt, aSA, "simulated annealing inference."),
00170 
00171   ARGS("MWSRST", ARGS::Opt, aMWSrst, "Result file for MaxwalkSAT inference."),
00172 
00173   ARGS("noisynum", ARGS::Opt, anumerator, 
00174        "numerator value for noisy pick in HMWS."),    
00175 
00176   ARGS("noisyden", ARGS::Opt, adenominator, 
00177        "denominator value for noisy pick in HMWS."),    
00178 
00179   ARGS("hmwsDis", ARGS::Opt, aHMWSDis, 
00180        "dis inference result file for hybrid maxwalksat."),    
00181 
00182   ARGS("testcont", ARGS::Opt, atestcont, 
00183        "file path containing assignment to cont variables."),    
00184 
00185   ARGS("testdis", ARGS::Opt, atestdis, 
00186        "file path containing assignment to dis variables."),    
00187 
00188   ARGS("mwsMax", ARGS::Opt, aMaxOrNot, 
00189        "[false] (MaxWalkSat) If false,MC-SAT uses WalkSAT, if true, MC-SAT use max-WalkSAT. "),  
00190 
00191   ARGS("hybrid", ARGS::Opt, aHybrid, 
00192            "Flag for HMLN inferences."),
00193 
00194   ARGS("propstdev", ARGS::Opt, aProposalStdev, 
00195            "[1.0]Proposal stdev for SA step in HybridSAT."),
00196 
00197   ARGS("conti", ARGS::Opt, aContPartFile, 
00198            "input file containing the description of continuous part of hybrid MLN."),
00199 
00200   ARGS("segl", ARGS::Opt, aSegListFile, 
00201            "input file containing the list of grounded segments."),
00202 
00203   ARGS("contSamples", ARGS::Opt, aContSamples, 
00204            "output file for continuous variable samples."),
00205 
00206     // END: Common inference arguments
00207 
00208     // BEGIN: MaxWalkSat args
00209   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00210        "[100000] (MaxWalkSat) The max number of steps taken."),
00211 
00212   ARGS("tries", ARGS::Opt, amwsTries, 
00213        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00214 
00215   ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00216        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00217        "with weight <= specified weight."),
00218 
00219   ARGS("breakHardClauses", ARGS::Tog, amwsHard, 
00220        "[false] (MaxWalkSat) If true, MaxWalkSat can break a hard clause in "
00221        "order to satisfy a soft one."),
00222   
00223   ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00224        "[2] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00225        "2 = TABU, 3 = SAMPLESAT)."),
00226   
00227   ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00228        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00229        "atom when using the tabu heuristic in MaxWalkSat." ),
00230 
00231   ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState, 
00232        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00233        "(each time a low state is found, the whole state is saved) is used; "
00234        "otherwise, a list of variables flipped since the last low state is "
00235        "kept and the low state is reconstructed. This can be much faster for "
00236        "very large data sets."),  
00237     // END: MaxWalkSat args
00238 
00239     // BEGIN: MCMC args
00240   ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00241        "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00242 
00243   ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00244        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00245 
00246   ARGS("minSteps", ARGS::Opt, amcmcMinSteps, 
00247        "[-1] (MCMC) Minimum number of MCMC sampling steps."),
00248 
00249   ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps, 
00250        "[1000] (MCMC) Maximum number of MCMC sampling steps."),
00251 
00252   ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00253        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00254     // END: MCMC args
00255   
00256     // BEGIN: Simulated tempering args
00257   ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00258         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00259 
00260   ARGS("numRuns", ARGS::Opt, asimtpNumST,
00261         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00262 
00263   ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00264         "[10] (Simulated Tempering) Number of swapping chains"),
00265     // END: Simulated tempering args
00266 
00267     // BEGIN: BP args
00268   ARGS("lifted", ARGS::Tog, aliftedInfer, 
00269        "[false] If true, lifted inference is run"),
00270 
00271   ARGS("convThresh", ARGS::Opt, abpConvergenceThresh,
00272         "[1e-4] (BP) Max difference in probabilities to determine convergence"),
00273 
00274   ARGS("convIterations", ARGS::Opt, abpConvergeRequiredItrCnt,
00275         "[20] (BP) Number of converging iterations to determine convergence"),
00276 
00277   ARGS("explicitRep", ARGS::Tog, aexplicitRep, 
00278        "[false] If true, explicit representation type is used in lifted "
00279        "inference; otherwise, implicit representation type is used"), 
00280     // END: BP args
00281 
00282     // BEGIN: SampleSat args
00283   ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00284        "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00285 
00286   ARGS("saRatio", ARGS::Opt, assSaRatio,
00287        "[0] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00288        "MC-SAT"),
00289 
00290   ARGS("saTemperature", ARGS::Opt, assSaTemp,
00291         "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00292         "SampleSat"),
00293 
00294   ARGS("lateSa", ARGS::Tog, assLateSa,
00295        "[true] Run simulated annealing from the start in SampleSat"),
00296     // END: SampleSat args
00297 
00298     // BEGIN: Gibbs sampling args
00299   ARGS("numChains", ARGS::Opt, amcmcNumChains, 
00300        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00301        "at least 2)."),
00302 
00303   ARGS("delta", ARGS::Opt, agibbsDelta,
00304        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00305        "exceeded is less than this value."),
00306 
00307   ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00308        "[0.01] (Gibbs) Fractional error from true probability."),
00309 
00310   ARGS("fracConverged", ARGS::Opt, agibbsFracConverged, 
00311        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00312        "have converged."),
00313 
00314   ARGS("walksatType", ARGS::Opt, agibbsWalksatType, 
00315        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00316        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00317 
00318   ARGS("testConvergence", ARGS::Opt, agibbsTestConvergence, 
00319        "[false] Perform convergence test for Gibbs sampling."),
00320 
00321   ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00322        "[100] Perform convergence test once after this many number of samples "
00323        "per chain."),
00324     // END: Gibbs sampling args
00325 
00326     // BEGIN: Args specific to stand-alone inference
00327   ARGS("e", ARGS::Req, aevidenceFiles, 
00328        "Comma-separated .db files containing known ground atoms (evidence), "
00329        "including function definitions."),
00330 
00331   ARGS("r", ARGS::Req, aresultsFile,
00332        "The probability estimates are written to this file."),
00333     
00334   ARGS("q", ARGS::Opt, aqueryPredsStr, 
00335        "Query atoms (comma-separated with no space)  "
00336        ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00337        "open world."),
00338 
00339   ARGS("f", ARGS::Opt, aqueryFile,
00340        "A .db file containing ground query atoms, "
00341        "which are are always open world."),
00342     // END: Args specific to stand-alone inference
00343 
00344   ARGS()
00345 };
00346 
00347 
00348 void printResults(const string& queryFile, const string& queryPredsStr,
00349                                   Domain *domain, ostream& out, 
00350                                   GroundPredicateHashArray* const &queries,
00351                                   Inference* const &inference, HVariableState* const &state)
00352 {
00353     // Lazy version: Have to generate the queries from the file or query string.
00354     // This involves calling createQueryFilePreds / createComLineQueryPreds
00355   if (aLazy)
00356   {
00357     const GroundPredicateHashArray* gndPredHashArray = NULL;
00358     Array<double>* gndPredProbs = NULL;
00359       // Inference algorithms with probs: have to retrieve this info from state.
00360       // These are the ground preds which have been brought into memory. All
00361       // others have always been false throughout sampling.
00362     if (!(amapPos || amapAll))
00363     {
00364       gndPredHashArray = state->getGndPredHashArrayPtr();
00365       gndPredProbs = new Array<double>;
00366       gndPredProbs->growToSize(gndPredHashArray->size());
00367       for (int i = 0; i < gndPredProbs->size(); i++)
00368         (*gndPredProbs)[i] = inference->getProbabilityH((*gndPredHashArray)[i]);
00369     }
00370 
00371     if (queryFile.length() > 0)
00372     {
00373       cout << "Writing query predicates that are specified in query file..."
00374            << endl;
00375       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
00376                                      NULL, NULL, true, out, amapPos,
00377                                      gndPredHashArray, gndPredProbs, NULL);
00378       if (!ok)
00379       {
00380         cout <<"Failed to create query predicates."<< endl;
00381         exit(-1);
00382       }
00383     }
00384 
00385     Array<int> allPredGndingsAreQueries;
00386     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00387     if (queryPredsStr.length() > 0)
00388     {
00389       cout << "Writing query predicates that are specified on command line..." 
00390            << endl;
00391       bool ok = createComLineQueryPreds(queryPredsStr, domain,
00392                                         domain->getDB(), NULL, NULL,
00393                                         &allPredGndingsAreQueries, true,
00394                                         out, amapPos, gndPredHashArray,
00395                                         gndPredProbs, NULL);
00396       if (!ok)
00397       {
00398         cout <<"Failed to create query predicates."<< endl; exit(-1);
00399       }
00400     }
00401 
00402     if (!(amapPos || amapAll))
00403       delete gndPredProbs;
00404   }
00405     // Eager version: Queries have already been generated and we can get the
00406     // information directly from the state
00407   else
00408   {
00409     if (amapPos)
00410       inference->printTruePredsH(out);
00411     else
00412     {
00413       for (int i = 0; i < queries->size(); i++)
00414       {
00415           // Prob is smoothed in inference->getProbability
00416         double prob = inference->getProbabilityH((*queries)[i]);
00417         (*queries)[i]->print(out, domain); out << " " << prob << endl;
00418       }
00419     }
00420   }
00421 }
00422 
00423 
00439 void printResults(const string& queryFile, const string& queryPredsStr,
00440                   Domain *domain, ostream& out, 
00441                   GroundPredicateHashArray* const &queries,
00442                   Inference* const &inference, VariableState* const &state,
00443                   Array<Predicate*> const &queryPreds,
00444                   Array<TruthValue> const &queryPredValues)
00445 {
00446     // Lazy version: Have to generate the queries from the file or query string.
00447     // This involves calling createQueryFilePreds / createComLineQueryPreds
00448   if (aLazy)
00449   {
00450       // Inference algorithms with probs: have to retrieve this info from state.
00451       // These are the ground preds which have been brought into memory. All
00452       // others have always been false throughout sampling.
00453 
00454     for (int i = 0; i < queryPreds.size(); i++)
00455     {
00456       int val = (queryPredValues[i] == TRUE) ? 1 : 0;
00457 
00458         // Prob is smoothed in inference->getProbability
00459       double prob = inference->getProbability((GroundPredicate*)queryPreds[i]);
00460 
00461       queryPreds[i]->print(out, domain);
00462       out << " " << prob << " " << val << endl;
00463     }
00464 
00465         /*
00466         if (!(amapPos || amapAll))
00467     {
00468       gndPredHashArray = state->getGndPredHashArrayPtr();
00469       gndPredProbs = new Array<double>;
00470       gndPredProbs->growToSize(gndPredHashArray->size());
00471       for (int i = 0; i < gndPredProbs->size(); i++)
00472         (*gndPredProbs)[i] =
00473           inference->getProbability((*gndPredHashArray)[i]);
00474     }
00475     
00476     if (queryFile.length() > 0)
00477     {
00478       cout << "Writing query predicates that are specified in query file..."
00479            << endl;
00480       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00481                                      NULL, true, out, amapPos,
00482                                      gndPredHashArray, gndPredProbs);
00483       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00484     }
00485 
00486     Array<int> allPredGndingsAreQueries;
00487     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00488     if (queryPredsStr.length() > 0)
00489     {
00490       cout << "Writing query predicates that are specified on command line..." 
00491            << endl;
00492       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
00493                                         NULL, NULL, &allPredGndingsAreQueries,
00494                                         true, out, amapPos, gndPredHashArray,
00495                                         gndPredProbs);
00496       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00497     }
00498     
00499     if (!(amapPos || amapAll))
00500       delete gndPredProbs;
00501           */
00502   }
00503     // Eager version: Queries have already been generated and we can get the
00504     // information directly from the state
00505   else
00506   {
00507     if (amapPos)
00508       inference->printTruePreds(out);
00509     else
00510     {
00511       for (int i = 0; i < queryPreds.size(); i++)
00512       {
00513         int val = (queryPredValues[i] == TRUE) ? 1 : 0;
00514 
00515           // Prob is smoothed in inference->getProbability
00516         double prob = inference->getProbability((GroundPredicate*)queryPreds[i]);
00517 
00518         queryPreds[i]->print(out, domain);
00519         out << " " << prob << " " << val << endl;
00520       }
00521     }
00522   }
00523 }
00524 
00525 void printResults(const string& queryFile, const string& queryPredsStr,
00526                   Domain *domain, ostream& out,
00527                   GroundPredicateHashArray* const &queries,
00528                   Inference* const &inference, VariableState* const &state)
00529 {
00530         // Lazy version: Have to generate the queries from the file or query string.
00531         // This involves calling createQueryFilePreds / createComLineQueryPreds
00532   if (aLazy)
00533   {
00534     const GroundPredicateHashArray* gndPredHashArray = NULL;
00535     Array<double>* gndPredProbs = NULL;
00536       // Inference algorithms with probs: have to retrieve this info from state.
00537       // These are the ground preds which have been brought into memory. All
00538       // others have always been false throughout sampling.
00539     if (!(amapPos || amapAll))
00540     {
00541       gndPredHashArray = state->getGndPredHashArrayPtr();
00542       gndPredProbs = new Array<double>;
00543       gndPredProbs->growToSize(gndPredHashArray->size());
00544       for (int i = 0; i < gndPredProbs->size(); i++)
00545         (*gndPredProbs)[i] =
00546           inference->getProbability((*gndPredHashArray)[i]);
00547     }
00548 
00549     if (queryFile.length() > 0)
00550     {
00551       cout << "Writing query predicates that are specified in query file..."
00552            << endl;
00553       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00554                                      NULL, true, out, amapPos,
00555                                      gndPredHashArray, gndPredProbs, NULL);
00556       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00557     }
00558 
00559     Array<int> allPredGndingsAreQueries;
00560     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00561     if (queryPredsStr.length() > 0)
00562     {
00563       cout << "Writing query predicates that are specified on command line..." 
00564            << endl;
00565       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
00566                                         NULL, NULL, &allPredGndingsAreQueries,
00567                                         true, out, amapPos, gndPredHashArray,
00568                                         gndPredProbs, NULL);
00569       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00570     }
00571 
00572     if (!(amapPos || amapAll))
00573       delete gndPredProbs;
00574   }
00575     // Eager version: Queries have already been generated and we can get the
00576     // information directly from the state
00577   else
00578   {
00579     if (amapPos)
00580       inference->printTruePreds(out);
00581     else
00582     {
00583       inference->printQFProbs(out, domain);
00584       for (int i = 0; i < queries->size(); i++)
00585       {
00586           // Prob is smoothed in inference->getProbability
00587         double prob = inference->getProbability((*queries)[i]);
00588         (*queries)[i]->print(out, domain); out << " " << prob << endl;
00589       }
00590     }
00591   }
00592 }
00593 
00594 
00603 int main(int argc, char* argv[])
00604 {
00606   ARGS::parse(argc, argv, &cout);
00607   Timer timer;
00608   double begSec = timer.time(); 
00609 
00610   Array<Predicate *> queryPreds;
00611   Array<TruthValue> queryPredValues;
00612 
00613   ofstream resultsOut(aresultsFile);
00614   if (!resultsOut.good())
00615   { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00616 
00617   Domain* domain = NULL;
00618   Inference* inference = NULL;
00619   if (buildInference(inference, domain, aisQueryEvidence, queryPreds,
00620                      queryPredValues) > -1)
00621   {
00622     inference->init();
00623       // No inference, just output network
00624     if (aoutputNetwork)
00625     {
00626       cout << "Writing network to file ..." << endl;
00627       inference->printNetwork(resultsOut);
00628     }
00629       // Perform inference
00630     else
00631     {
00632       inference->infer();
00633       if (aHybrid && !amapPos)
00634           { 
00635             printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00636                      inference, inference->getHState());        
00637           }
00638       else
00639       {
00640             printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00641                      inference, inference->getState());
00642       }
00643     }
00644   }
00645 
00646   resultsOut.close();
00647   if (domain) delete domain;
00648   for (int i = 0; i < knownQueries.size(); i++)
00649     if (knownQueries[i]) delete knownQueries[i];
00650   if (inference) delete inference;
00651   
00652   cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00653   cout << endl;
00654 }
00655 

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1