infer.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #include <unistd.h>
00067 #include <fstream>
00068 #include <climits>
00069 #include <sys/times.h>
00070 #include "fol.h"
00071 #include "arguments.h"
00072 #include "util.h"
00073 #include "learnwts.h"
00074 #include "infer.h"
00075 #include "inferenceargs.h"
00076 #include "maxwalksat.h"
00077 #include "mcsat.h"
00078 #include "gibbssampler.h"
00079 #include "simulatedtempering.h"
00080 
  // Postfix appended to the base name of the input .mln file to form the name
  // of the temporary working .mln file that main() creates and later unlinks.
extern const char* ZZ_TMP_FILE_POSTFIX; //defined in fol.y

// Variables for holding inference command line args are in inferenceargs.h

  // Command line values that exist only for stand-alone inference. They are
  // bound to the -e, -r, -q and -f options respectively in ARGS::Args, and
  // stay NULL until the corresponding option is parsed.
char* aevidenceFiles  = NULL;
char* aresultsFile    = NULL;
char* aqueryPredsStr  = NULL;
char* aqueryFile      = NULL;
00089 
00090   // TODO: List the arguments common to learnwts and inference in
00091   // inferenceargs.h. This can't be done with a static array.
00092 ARGS ARGS::Args[] = 
00093 {
00094     // BEGIN: Common arguments
00095   ARGS("i", ARGS::Req, ainMLNFiles, 
00096        "Comma-separated input .mln files."),
00097 
00098   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00099        "Specified non-evidence atoms (comma-separated with no space) are "
00100        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00101        "appearing here cannot be query atoms and cannot appear in the -o "
00102        "option."),
00103 
00104   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00105        "Specified evidence atoms (comma-separated with no space) are open "
00106        "world, while other evidence atoms are closed-world. "
00107        "Atoms appearing here cannot appear in the -c option."),
00108     // END: Common arguments
00109 
00110     // BEGIN: Common inference arguments
00111   ARGS("m", ARGS::Tog, amapPos, 
00112        "Run MAP inference and return only positive query atoms."),
00113 
00114   ARGS("a", ARGS::Tog, amapAll, 
00115        "Run MAP inference and show 0/1 results for all query atoms."),
00116 
00117   ARGS("p", ARGS::Tog, agibbsInfer, 
00118        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00119        "for all query atoms."),
00120   
00121   ARGS("ms", ARGS::Tog, amcsatInfer,
00122        "Run inference using MC-SAT and return probabilities "
00123        "for all query atoms"),
00124 
00125   ARGS("simtp", ARGS::Tog, asimtpInfer,
00126        "Run inference using simulated tempering and return probabilities "
00127        "for all query atoms"),
00128 
00129   ARGS("seed", ARGS::Opt, aSeed,
00130        "[random] Seed used to initialize the randomizer in the inference "
00131        "algorithm. If not set, seed is initialized from the current date and "
00132        "time."),
00133 
00134   ARGS("lazy", ARGS::Opt, aLazy, 
00135        "[false] Run lazy version of inference if this flag is set."),
00136   
00137   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00138        "[false] Lazy version of inference will not approximate by deactivating "
00139        "atoms to save memory. This flag is ignored if -lazy is not set."),
00140   
00141   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00142        "[-1] Maximum limit in kbytes which should be used for inference. "
00143        "-1 means main memory available on system is used."),
00144     // END: Common inference arguments
00145 
00146     // BEGIN: MaxWalkSat args
00147   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00148        "[1000000] (MaxWalkSat) The max number of steps taken."),
00149 
00150   ARGS("mwsTries", ARGS::Opt, amwsTries, 
00151        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00152 
00153   ARGS("mwsTargetWt", ARGS::Opt, amwsTargetWt,
00154        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00155        "with weight <= specified weight."),
00156 
00157   ARGS("mwsHard", ARGS::Opt, amwsHard, 
00158        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00159        "satisfy a soft one."),
00160   
00161   ARGS("mwsHeuristic", ARGS::Opt, amwsHeuristic,
00162        "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00163        "2 = TABU, 3 = SAMPLESAT)."),
00164   
00165   ARGS("mwsTabuLength", ARGS::Opt, amwsTabuLength,
00166        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00167        "atom when using the tabu heuristic in MaxWalkSat." ),
00168 
00169   ARGS("mwsLazyLowState", ARGS::Opt, amwsLazyLowState, 
00170        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00171        "(each time a low state is found, the whole state is saved) is used; "
00172        "otherwise, a list of variables flipped since the last low state is "
00173        "kept and the low state is reconstructed. This can be much faster for "
00174        "very large data sets."),  
00175     // END: MaxWalkSat args
00176 
00177     // BEGIN: MCMC args
00178   ARGS("mcmcBurnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00179        "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00180 
00181   ARGS("mcmcBurnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00182        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00183 
00184   ARGS("mcmcMinSteps", ARGS::Opt, amcmcMinSteps, 
00185        "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00186 
00187   ARGS("mcmcMaxSteps", ARGS::Opt, amcmcMaxSteps, 
00188        "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00189 
00190   ARGS("mcmcMaxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00191        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00192     // END: MCMC args
00193   
00194     // BEGIN: Simulated tempering args
00195   ARGS("simtpSubInterval", ARGS::Opt, asimtpSubInterval,
00196         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00197 
00198   ARGS("simtpNumRuns", ARGS::Opt, asimtpNumST,
00199         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00200 
00201   ARGS("simtpNumSwap", ARGS::Opt, asimtpNumSwap,
00202         "[10] (Simulated Tempering) Number of swapping chains"),
00203     // END: Simulated tempering args
00204 
00205     // BEGIN: MC-SAT args
00206   ARGS("mcsatNumStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00207        "[1] (MC-SAT) Number of total steps (mcsat & gibbs) for every mcsat "
00208        "step"),
00209     // END: MC-SAT args
00210 
00211     // BEGIN: SampleSat args
00212   ARGS("mcsatNumSolutions", ARGS::Opt, amwsNumSolutions,
00213        "[10] Return nth SAT solution in SampleSat"),
00214 
00215   ARGS("mcsatSaRatio", ARGS::Opt, assSaRatio,
00216        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00217        "MC-SAT"),
00218 
00219   ARGS("mcsatSaTemperature", ARGS::Opt, assSaTemp,
00220         "[10] Temperature (/100) for sim. annealing step in SampleSat"),
00221 
00222   ARGS("mcsatLateSa", ARGS::Tog, assLateSa,
00223        "[false] Run simulated annealing from the start in SampleSat"),
00224     // END: SampleSat args
00225 
00226     // BEGIN: Gibbs sampling args
00227   ARGS("gibbsNumChains", ARGS::Opt, amcmcNumChains, 
00228        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00229        "at least 2)."),
00230 
00231   ARGS("gibbsDelta", ARGS::Opt, agibbsDelta,
00232        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00233        "exceeded is less than this value."),
00234 
00235   ARGS("gibbsEpsilonError", ARGS::Opt, agibbsEpsilonError,
00236        "[0.01] (Gibbs) Fractional error from true probability."),
00237 
00238   ARGS("gibbsFracConverged", ARGS::Opt, agibbsFracConverged, 
00239        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00240        "have converged."),
00241 
00242   ARGS("gibbsWalksatType", ARGS::Opt, agibbsWalksatType, 
00243        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00244        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00245 
00246   ARGS("gibbsSamplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00247        "[100] Perform convergence test once after this many number of samples "
00248        "per chain."),
00249     // END: Gibbs sampling args
00250 
00251     // BEGIN: Args specific to stand-alone inference
00252   ARGS("e", ARGS::Req, aevidenceFiles, 
00253        "Comma-separated .db files containing known ground atoms (evidence), "
00254        "including function definitions."),
00255 
00256   ARGS("r", ARGS::Req, aresultsFile,
00257        "The probability estimates are written to this file."),
00258     
00259   ARGS("q", ARGS::Opt, aqueryPredsStr, 
00260        "Query atoms (comma-separated with no space)  "
00261        ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00262        "open world."),
00263 
00264   ARGS("f", ARGS::Opt, aqueryFile,
00265        "A .db file containing ground query atoms, "
00266        "which are are always open world."),
00267     // END: Args specific to stand-alone inference
00268 
00269   ARGS()
00270 };
00271 
00272 
00288 void printResults(const string& queryFile, const string& queryPredsStr,
00289                   Domain *domain, ostream& out, 
00290                   GroundPredicateHashArray* const &queries,
00291                   Inference* const &inference, VariableState* const &state)
00292 {
00293     // Lazy version: Have to generate the queries from the file or query string.
00294     // This involves calling createQueryFilePreds / createComLineQueryPreds
00295   if (aLazy)
00296   {
00297     const GroundPredicateHashArray* gndPredHashArray = NULL;
00298     Array<double>* gndPredProbs = NULL;
00299       // Inference algorithms with probs: have to retrieve this info from state.
00300       // These are the ground preds which have been brought into memory. All
00301       // others have always been false throughout sampling.
00302     if (!(amapPos || amapAll))
00303     {
00304       gndPredHashArray = state->getGndPredHashArrayPtr();
00305       gndPredProbs = new Array<double>;
00306       gndPredProbs->growToSize(gndPredHashArray->size());
00307       for (int i = 0; i < gndPredProbs->size(); i++)
00308         (*gndPredProbs)[i] =
00309           inference->getProbability((*gndPredHashArray)[i]);
00310     }
00311     
00312     if (queryFile.length() > 0)
00313     {
00314       cout << "Writing query predicates that are specified in query file..."
00315            << endl;
00316       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00317                                      NULL, true, out, amapPos,
00318                                      gndPredHashArray, gndPredProbs);
00319       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00320     }
00321 
00322     Array<int> allPredGndingsAreQueries;
00323     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00324     if (queryPredsStr.length() > 0)
00325     {
00326       cout << "Writing query predicates that are specified on command line..." 
00327            << endl;
00328       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
00329                                         NULL, NULL, &allPredGndingsAreQueries,
00330                                         true, out, amapPos, gndPredHashArray,
00331                                         gndPredProbs);
00332       if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00333     }
00334     
00335     if (!(amapPos || amapAll))
00336       delete gndPredProbs;
00337   }
00338     // Eager version: Queries have already been generated and we can get the
00339     // information directly from the state
00340   else
00341   {
00342     if (amapPos)
00343       inference->printTruePreds(out);
00344     else
00345     {
00346       for (int i = 0; i < queries->size(); i++)
00347       {
00348           // Prob is smoothed in inference->getProbability
00349         double prob = inference->getProbability((*queries)[i]);
00350         (*queries)[i]->print(out, domain); out << " " << prob << endl;
00351       }
00352     }
00353   }
00354 }
00355 
00356 
00365 int main(int argc, char* argv[])
00366 {
00368   ARGS::parse(argc, argv, &cout);
00369 
00370   Timer timer;
00371   double begSec = timer.time(); 
00372 
00373   string inMLNFile, wkMLNFile, evidenceFile, queryPredsStr, queryFile;
00374 
00375   StringHashArray queryPredNames;
00376   StringHashArray owPredNames;
00377   StringHashArray cwPredNames;
00378   Domain* domain = NULL;
00379   MLN* mln = NULL;
00380   GroundPredicateHashArray queries;
00381   GroundPredicateHashArray knownQueries;
00382   Array<string> constFilesArr;
00383   Array<string> evidenceFilesArr;
00384 
00385   Array<Predicate *> queryPreds;
00386   Array<TruthValue> queryPredValues;
00387   
00388   //the second .mln file to the last one in ainMLNFiles _may_ be used 
00389   //to hold constants, so they are held in constFilesArr. They will be
00390   //included into the first .mln file.
00391 
00392     //extract .mln, .db file names
00393   extractFileNames(ainMLNFiles, constFilesArr);
00394   assert(constFilesArr.size() >= 1);
00395   inMLNFile.append(constFilesArr[0]);
00396   constFilesArr.removeItem(0);
00397   extractFileNames(aevidenceFiles, evidenceFilesArr);
00398   
00399   if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
00400   if (aqueryFile) queryFile.append(aqueryFile);
00401 
00402   if (queryPredsStr.length() == 0 && queryFile.length() == 0)
00403   { cout << "No query predicates specified" << endl; return -1; }
00404 
00405   ofstream resultsOut(aresultsFile);
00406   if (!resultsOut.good())
00407   { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00408 
00409   if (agibbsInfer && amcmcNumChains < 2) 
00410   {
00411     cout << "ERROR: there must be at least 2 MCMC chains in Gibbs sampling" 
00412          << endl; return -1;
00413   }
00414 
00415   if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
00416   {
00417     cout << "ERROR: must specify one of -ms/-simtp/-m/-a/-p flags." << endl;
00418     return-1;
00419   }
00420 
00421     //extract names of all query predicates
00422   if (queryPredsStr.length() > 0 || queryFile.length() > 0)
00423   {
00424     if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
00425   }
00426 
00427   if (amwsMaxSteps <= 0)
00428   { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00429 
00430   if (amwsTries <= 0)
00431   { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
00432 
00433     //extract names of open-world evidence predicates
00434   if (aOpenWorldPredsStr)
00435   {
00436     if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
00437       return -1;
00438     assert(owPredNames.size() > 0);
00439   }
00440 
00441     //extract names of closed-world non-evidence predicates
00442   if (aClosedWorldPredsStr)
00443   {
00444     if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
00445       return -1;
00446     assert(cwPredNames.size() > 0);
00447     if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
00448       return -1;
00449   }
00450 
00451   // TODO: Check if query atom in -o -> error
00452 
00453   // TODO: Check if atom in -c and -o -> error
00454 
00455 
00456   // TODO: Check if ev. atom in -c or
00457   // non-evidence in -o -> warning (this is default)
00458 
00459 
00460     // Set SampleSat parameters
00461   SampleSatParams* ssparams = new SampleSatParams;
00462   ssparams->lateSa = assLateSa;
00463   ssparams->saRatio = assSaRatio;
00464   ssparams->saTemp = assSaTemp;
00465 
00466     // Set MaxWalksat parameters
00467   MaxWalksatParams* mwsparams = NULL;
00468   mwsparams = new MaxWalksatParams;
00469   mwsparams->ssParams = ssparams;
00470   mwsparams->maxSteps = amwsMaxSteps;
00471   mwsparams->maxTries = amwsTries;
00472   mwsparams->targetCost = amwsTargetWt;
00473   mwsparams->hard = amwsHard;
00474     // numSolutions only applies when used in SampleSat.
00475     // When just MWS, this is set to 1
00476   mwsparams->numSolutions = amwsNumSolutions;
00477   mwsparams->heuristic = amwsHeuristic;
00478   mwsparams->tabuLength = amwsTabuLength;
00479   mwsparams->lazyLowState = amwsLazyLowState;
00480 
00481     // Set MC-SAT parameters
00482   MCSatParams* msparams = new MCSatParams;
00483   msparams->mwsParams = mwsparams;
00484     // MC-SAT needs only one chain
00485   msparams->numChains          = 1;
00486   msparams->burnMinSteps       = amcmcBurnMinSteps;
00487   msparams->burnMaxSteps       = amcmcBurnMaxSteps;
00488   msparams->minSteps           = amcmcMinSteps;
00489   msparams->maxSteps           = amcmcMaxSteps;
00490   msparams->maxSeconds         = amcmcMaxSeconds;
00491   msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00492 
00493     // Set Gibbs parameters
00494   GibbsParams* gibbsparams = new GibbsParams;
00495   gibbsparams->mwsParams    = mwsparams;
00496   gibbsparams->numChains    = amcmcNumChains;
00497   gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00498   gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00499   gibbsparams->minSteps     = amcmcMinSteps;
00500   gibbsparams->maxSteps     = amcmcMaxSteps;
00501   gibbsparams->maxSeconds   = amcmcMaxSeconds;
00502 
00503   gibbsparams->gamma          = 1 - agibbsDelta;
00504   gibbsparams->epsilonError   = agibbsEpsilonError;
00505   gibbsparams->fracConverged  = agibbsFracConverged;
00506   gibbsparams->walksatType    = agibbsWalksatType;
00507   gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00508   
00509     // Set Sim. Tempering parameters
00510   SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00511   stparams->mwsParams    = mwsparams;
00512   stparams->numChains    = amcmcNumChains;
00513   stparams->burnMinSteps = amcmcBurnMinSteps;
00514   stparams->burnMaxSteps = amcmcBurnMaxSteps;
00515   stparams->minSteps     = amcmcMinSteps;
00516   stparams->maxSteps     = amcmcMaxSteps;
00517   stparams->maxSeconds   = amcmcMaxSeconds;
00518 
00519   stparams->subInterval = asimtpSubInterval;
00520   stparams->numST       = asimtpNumST;
00521   stparams->numSwap     = asimtpNumSwap;
00522 
00524 
00525   cout << "Reading formulas and evidence predicates..." << endl;
00526 
00527     // Copy inMLNFile to workingMLNFile & app '#include "evid.db"'
00528   string::size_type bslash = inMLNFile.rfind("/");
00529   string tmp = (bslash == string::npos) ? 
00530                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00531   char buf[100];
00532   sprintf(buf, "%s%s", tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
00533   wkMLNFile = buf;
00534   copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
00535                           evidenceFilesArr, constFilesArr);
00536 
00537     // Parse wkMLNFile, and create the domain, MLN, database
00538   domain = new Domain;
00539   mln = new MLN();
00540   bool addUnitClauses = false;
00541   bool mustHaveWtOrFullStop = true;
00542   bool warnAboutDupGndPreds = true;
00543   bool flipWtsOfFlippedClause = true;
00544   //bool allPredsExceptQueriesAreCW = true;
00545   bool allPredsExceptQueriesAreCW = owPredNames.empty();
00546   Domain* forCheckingPlusTypes = NULL;
00547 
00548         // Parse as if lazy inference is set to true to set evidence atoms in DB
00549     // If lazy is not used, this is removed from DB
00550   if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW, 
00551                    &owPredNames, &queryPredNames, addUnitClauses, 
00552                    warnAboutDupGndPreds, 0, mustHaveWtOrFullStop, 
00553                    forCheckingPlusTypes, true, flipWtsOfFlippedClause))
00554   {
00555     unlink(wkMLNFile.c_str());
00556     return -1;
00557   }
00558 
00559   unlink(wkMLNFile.c_str());
00560 
00562 
00564   Array<int> allPredGndingsAreQueries;
00565 
00566     // Eager inference: Build the queries for the mrf
00567     // Lazy version evaluates the query string / file when printing out
00568   if (!aLazy)
00569   {
00570     if (queryFile.length() > 0)
00571     {
00572       cout << "Reading query predicates that are specified in query file..."
00573            << endl;
00574       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
00575                                      &queries, &knownQueries);
00576       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
00577     }
00578 
00579     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00580     if (queryPredsStr.length() > 0)
00581     {
00582       cout << "Creating query predicates that are specified on command line..." 
00583            << endl;
00584       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
00585                                         &queries, &knownQueries, 
00586                                         &allPredGndingsAreQueries);
00587       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
00588     }
00589   }
00590 
00591     // Create inference algorithm and state based on queries and mln / domain
00592   bool markHardGndClauses = false;
00593   bool trackParentClauseWts = false;
00594     // Lazy version: queries and allPredGndingsAreQueries are empty,
00595     // markHardGndClause and trackParentClauseWts are not used
00596   VariableState* state = new VariableState(&queries, NULL, NULL,
00597                                            &allPredGndingsAreQueries,
00598                                            markHardGndClauses,
00599                                            trackParentClauseWts,
00600                                            mln, domain, aLazy);
00601   Inference* inference = NULL;
00602   bool trackClauseTrueCnts = false;
00603     // MAP inference, MC-SAT, Gibbs or Sim. Temp.
00604   if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer)
00605   {
00606     if (amapPos || amapAll)
00607     { // MaxWalkSat
00608         // When standalone MWS, numSolutions is always 1
00609         // (maybe there is a better way to do this?)
00610       mwsparams->numSolutions = 1;
00611       inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
00612     }
00613     else if (amcsatInfer)
00614     { // MC-SAT
00615       inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
00616     }
00617     else if (asimtpInfer)
00618     { // Simulated Tempering
00619         // When MWS is used in Sim. Temp., numSolutions is always 1
00620         // (maybe there is a better way to do this?)
00621       mwsparams->numSolutions = 1;
00622       inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
00623                                          stparams);
00624     }
00625     else if (agibbsInfer)
00626     { // Gibbs sampling
00627         // When MWS is used in Gibbs, numSolutions is always 1
00628         // (maybe there is a better way to do this?)
00629       mwsparams->numSolutions = 1;
00630       inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
00631                                    gibbsparams);
00632     }
00633 
00634     inference->init();
00635     inference->infer();
00636     
00637     printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00638                  inference, state);
00639   }
00640 
00641   resultsOut.close();
00642   if (mwsparams) delete mwsparams;
00643   if (ssparams) delete ssparams;
00644   if (msparams) delete msparams;
00645   if (gibbsparams) delete gibbsparams;
00646   if (stparams) delete stparams;
00647   delete domain;
00648   for (int i = 0; i < knownQueries.size(); i++)  delete knownQueries[i];
00649   delete inference;
00650   delete state;
00651   
00652   cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00653   cout << endl;
00654 }
00655 

Generated on Tue Jan 16 05:30:02 2007 for Alchemy by  doxygen 1.5.1