00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #include <unistd.h>
00067 #include <fstream>
00068 #include <climits>
00069 #include <sys/times.h>
00070 #include "fol.h"
00071 #include "arguments.h"
00072 #include "util.h"
00073 #include "learnwts.h"
00074 #include "infer.h"
00075 #include "inferenceargs.h"
00076 #include "maxwalksat.h"
00077 #include "mcsat.h"
00078 #include "gibbssampler.h"
00079 #include "simulatedtempering.h"
00080
00081 extern const char* ZZ_TMP_FILE_POSTFIX;
00082
00083
00084
// Storage for the string-valued command-line options that are defined in this
// file.  Each pointer starts out NULL and is registered with an option name in
// the ARGS::Args table below: aevidenceFiles with "e", aresultsFile with "r",
// aqueryPredsStr with "q" and aqueryFile with "f".
// NOTE(review): presumably ARGS::parse() points these at the option values
// given on the command line -- confirm against the ARGS implementation.
00085 char* aevidenceFiles = NULL;
00086 char* aresultsFile = NULL;
00087 char* aqueryPredsStr = NULL;
00088 char* aqueryFile = NULL;
00089
00090
00091
00092 ARGS ARGS::Args[] =
00093 {
00094
00095 ARGS("i", ARGS::Req, ainMLNFiles,
00096 "Comma-separated input .mln files."),
00097
00098 ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00099 "Specified non-evidence atoms (comma-separated with no space) are "
00100 "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00101 "appearing here cannot be query atoms and cannot appear in the -o "
00102 "option."),
00103
00104 ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00105 "Specified evidence atoms (comma-separated with no space) are open "
00106 "world, while other evidence atoms are closed-world. "
00107 "Atoms appearing here cannot appear in the -c option."),
00108
00109
00110
00111 ARGS("m", ARGS::Tog, amapPos,
00112 "Run MAP inference and return only positive query atoms."),
00113
00114 ARGS("a", ARGS::Tog, amapAll,
00115 "Run MAP inference and show 0/1 results for all query atoms."),
00116
00117 ARGS("p", ARGS::Tog, agibbsInfer,
00118 "Run inference using MCMC (Gibbs sampling) and return probabilities "
00119 "for all query atoms."),
00120
00121 ARGS("ms", ARGS::Tog, amcsatInfer,
00122 "Run inference using MC-SAT and return probabilities "
00123 "for all query atoms"),
00124
00125 ARGS("simtp", ARGS::Tog, asimtpInfer,
00126 "Run inference using simulated tempering and return probabilities "
00127 "for all query atoms"),
00128
00129 ARGS("seed", ARGS::Opt, aSeed,
00130 "[random] Seed used to initialize the randomizer in the inference "
00131 "algorithm. If not set, seed is initialized from the current date and "
00132 "time."),
00133
00134 ARGS("lazy", ARGS::Opt, aLazy,
00135 "[false] Run lazy version of inference if this flag is set."),
00136
00137 ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
00138 "[false] Lazy version of inference will not approximate by deactivating "
00139 "atoms to save memory. This flag is ignored if -lazy is not set."),
00140
00141 ARGS("memLimit", ARGS::Opt, aMemLimit,
00142 "[-1] Maximum limit in kbytes which should be used for inference. "
00143 "-1 means main memory available on system is used."),
00144
00145
00146
00147 ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00148 "[1000000] (MaxWalkSat) The max number of steps taken."),
00149
00150 ARGS("mwsTries", ARGS::Opt, amwsTries,
00151 "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00152
00153 ARGS("mwsTargetWt", ARGS::Opt, amwsTargetWt,
00154 "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00155 "with weight <= specified weight."),
00156
00157 ARGS("mwsHard", ARGS::Opt, amwsHard,
00158 "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00159 "satisfy a soft one."),
00160
00161 ARGS("mwsHeuristic", ARGS::Opt, amwsHeuristic,
00162 "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00163 "2 = TABU, 3 = SAMPLESAT)."),
00164
00165 ARGS("mwsTabuLength", ARGS::Opt, amwsTabuLength,
00166 "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00167 "atom when using the tabu heuristic in MaxWalkSat." ),
00168
00169 ARGS("mwsLazyLowState", ARGS::Opt, amwsLazyLowState,
00170 "[false] (MaxWalkSat) If false, the naive way of saving low states "
00171 "(each time a low state is found, the whole state is saved) is used; "
00172 "otherwise, a list of variables flipped since the last low state is "
00173 "kept and the low state is reconstructed. This can be much faster for "
00174 "very large data sets."),
00175
00176
00177
00178 ARGS("mcmcBurnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00179 "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00180
00181 ARGS("mcmcBurnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00182 "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00183
00184 ARGS("mcmcMinSteps", ARGS::Opt, amcmcMinSteps,
00185 "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00186
00187 ARGS("mcmcMaxSteps", ARGS::Opt, amcmcMaxSteps,
00188 "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00189
00190 ARGS("mcmcMaxSeconds", ARGS::Opt, amcmcMaxSeconds,
00191 "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00192
00193
00194
00195 ARGS("simtpSubInterval", ARGS::Opt, asimtpSubInterval,
00196 "[2] (Simulated Tempering) Selection interval between swap attempts"),
00197
00198 ARGS("simtpNumRuns", ARGS::Opt, asimtpNumST,
00199 "[3] (Simulated Tempering) Number of simulated tempering runs"),
00200
00201 ARGS("simtpNumSwap", ARGS::Opt, asimtpNumSwap,
00202 "[10] (Simulated Tempering) Number of swapping chains"),
00203
00204
00205
00206 ARGS("mcsatNumStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00207 "[1] (MC-SAT) Number of total steps (mcsat & gibbs) for every mcsat "
00208 "step"),
00209
00210
00211
00212 ARGS("mcsatNumSolutions", ARGS::Opt, amwsNumSolutions,
00213 "[10] Return nth SAT solution in SampleSat"),
00214
00215 ARGS("mcsatSaRatio", ARGS::Opt, assSaRatio,
00216 "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00217 "MC-SAT"),
00218
00219 ARGS("mcsatSaTemperature", ARGS::Opt, assSaTemp,
00220 "[10] Temperature (/100) for sim. annealing step in SampleSat"),
00221
00222 ARGS("mcsatLateSa", ARGS::Tog, assLateSa,
00223 "[false] Run simulated annealing from the start in SampleSat"),
00224
00225
00226
00227 ARGS("gibbsNumChains", ARGS::Opt, amcmcNumChains,
00228 "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00229 "at least 2)."),
00230
00231 ARGS("gibbsDelta", ARGS::Opt, agibbsDelta,
00232 "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00233 "exceeded is less than this value."),
00234
00235 ARGS("gibbsEpsilonError", ARGS::Opt, agibbsEpsilonError,
00236 "[0.01] (Gibbs) Fractional error from true probability."),
00237
00238 ARGS("gibbsFracConverged", ARGS::Opt, agibbsFracConverged,
00239 "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00240 "have converged."),
00241
00242 ARGS("gibbsWalksatType", ARGS::Opt, agibbsWalksatType,
00243 "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00244 "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00245
00246 ARGS("gibbsSamplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
00247 "[100] Perform convergence test once after this many number of samples "
00248 "per chain."),
00249
00250
00251
00252 ARGS("e", ARGS::Req, aevidenceFiles,
00253 "Comma-separated .db files containing known ground atoms (evidence), "
00254 "including function definitions."),
00255
00256 ARGS("r", ARGS::Req, aresultsFile,
00257 "The probability estimates are written to this file."),
00258
00259 ARGS("q", ARGS::Opt, aqueryPredsStr,
00260 "Query atoms (comma-separated with no space) "
00261 ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00262 "open world."),
00263
00264 ARGS("f", ARGS::Opt, aqueryFile,
00265 "A .db file containing ground query atoms, "
00266 "which are are always open world."),
00267
00268
00269 ARGS()
00270 };
00271
00272
00288 void printResults(const string& queryFile, const string& queryPredsStr,
00289 Domain *domain, ostream& out,
00290 GroundPredicateHashArray* const &queries,
00291 Inference* const &inference, VariableState* const &state)
00292 {
00293
00294
00295 if (aLazy)
00296 {
00297 const GroundPredicateHashArray* gndPredHashArray = NULL;
00298 Array<double>* gndPredProbs = NULL;
00299
00300
00301
00302 if (!(amapPos || amapAll))
00303 {
00304 gndPredHashArray = state->getGndPredHashArrayPtr();
00305 gndPredProbs = new Array<double>;
00306 gndPredProbs->growToSize(gndPredHashArray->size());
00307 for (int i = 0; i < gndPredProbs->size(); i++)
00308 (*gndPredProbs)[i] =
00309 inference->getProbability((*gndPredHashArray)[i]);
00310 }
00311
00312 if (queryFile.length() > 0)
00313 {
00314 cout << "Writing query predicates that are specified in query file..."
00315 << endl;
00316 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00317 NULL, true, out, amapPos,
00318 gndPredHashArray, gndPredProbs);
00319 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00320 }
00321
00322 Array<int> allPredGndingsAreQueries;
00323 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00324 if (queryPredsStr.length() > 0)
00325 {
00326 cout << "Writing query predicates that are specified on command line..."
00327 << endl;
00328 bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(),
00329 NULL, NULL, &allPredGndingsAreQueries,
00330 true, out, amapPos, gndPredHashArray,
00331 gndPredProbs);
00332 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00333 }
00334
00335 if (!(amapPos || amapAll))
00336 delete gndPredProbs;
00337 }
00338
00339
00340 else
00341 {
00342 if (amapPos)
00343 inference->printTruePreds(out);
00344 else
00345 {
00346 for (int i = 0; i < queries->size(); i++)
00347 {
00348
00349 double prob = inference->getProbability((*queries)[i]);
00350 (*queries)[i]->print(out, domain); out << " " << prob << endl;
00351 }
00352 }
00353 }
00354 }
00355
00356
00365 int main(int argc, char* argv[])
00366 {
00368 ARGS::parse(argc, argv, &cout);
00369
00370 Timer timer;
00371 double begSec = timer.time();
00372
00373 string inMLNFile, wkMLNFile, evidenceFile, queryPredsStr, queryFile;
00374
00375 StringHashArray queryPredNames;
00376 StringHashArray owPredNames;
00377 StringHashArray cwPredNames;
00378 Domain* domain = NULL;
00379 MLN* mln = NULL;
00380 GroundPredicateHashArray queries;
00381 GroundPredicateHashArray knownQueries;
00382 Array<string> constFilesArr;
00383 Array<string> evidenceFilesArr;
00384
00385 Array<Predicate *> queryPreds;
00386 Array<TruthValue> queryPredValues;
00387
00388
00389
00390
00391
00392
00393 extractFileNames(ainMLNFiles, constFilesArr);
00394 assert(constFilesArr.size() >= 1);
00395 inMLNFile.append(constFilesArr[0]);
00396 constFilesArr.removeItem(0);
00397 extractFileNames(aevidenceFiles, evidenceFilesArr);
00398
00399 if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
00400 if (aqueryFile) queryFile.append(aqueryFile);
00401
00402 if (queryPredsStr.length() == 0 && queryFile.length() == 0)
00403 { cout << "No query predicates specified" << endl; return -1; }
00404
00405 ofstream resultsOut(aresultsFile);
00406 if (!resultsOut.good())
00407 { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00408
00409 if (agibbsInfer && amcmcNumChains < 2)
00410 {
00411 cout << "ERROR: there must be at least 2 MCMC chains in Gibbs sampling"
00412 << endl; return -1;
00413 }
00414
00415 if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
00416 {
00417 cout << "ERROR: must specify one of -ms/-simtp/-m/-a/-p flags." << endl;
00418 return-1;
00419 }
00420
00421
00422 if (queryPredsStr.length() > 0 || queryFile.length() > 0)
00423 {
00424 if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
00425 }
00426
00427 if (amwsMaxSteps <= 0)
00428 { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00429
00430 if (amwsTries <= 0)
00431 { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
00432
00433
00434 if (aOpenWorldPredsStr)
00435 {
00436 if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
00437 return -1;
00438 assert(owPredNames.size() > 0);
00439 }
00440
00441
00442 if (aClosedWorldPredsStr)
00443 {
00444 if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
00445 return -1;
00446 assert(cwPredNames.size() > 0);
00447 if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
00448 return -1;
00449 }
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459
00460
00461 SampleSatParams* ssparams = new SampleSatParams;
00462 ssparams->lateSa = assLateSa;
00463 ssparams->saRatio = assSaRatio;
00464 ssparams->saTemp = assSaTemp;
00465
00466
00467 MaxWalksatParams* mwsparams = NULL;
00468 mwsparams = new MaxWalksatParams;
00469 mwsparams->ssParams = ssparams;
00470 mwsparams->maxSteps = amwsMaxSteps;
00471 mwsparams->maxTries = amwsTries;
00472 mwsparams->targetCost = amwsTargetWt;
00473 mwsparams->hard = amwsHard;
00474
00475
00476 mwsparams->numSolutions = amwsNumSolutions;
00477 mwsparams->heuristic = amwsHeuristic;
00478 mwsparams->tabuLength = amwsTabuLength;
00479 mwsparams->lazyLowState = amwsLazyLowState;
00480
00481
00482 MCSatParams* msparams = new MCSatParams;
00483 msparams->mwsParams = mwsparams;
00484
00485 msparams->numChains = 1;
00486 msparams->burnMinSteps = amcmcBurnMinSteps;
00487 msparams->burnMaxSteps = amcmcBurnMaxSteps;
00488 msparams->minSteps = amcmcMinSteps;
00489 msparams->maxSteps = amcmcMaxSteps;
00490 msparams->maxSeconds = amcmcMaxSeconds;
00491 msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00492
00493
00494 GibbsParams* gibbsparams = new GibbsParams;
00495 gibbsparams->mwsParams = mwsparams;
00496 gibbsparams->numChains = amcmcNumChains;
00497 gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00498 gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00499 gibbsparams->minSteps = amcmcMinSteps;
00500 gibbsparams->maxSteps = amcmcMaxSteps;
00501 gibbsparams->maxSeconds = amcmcMaxSeconds;
00502
00503 gibbsparams->gamma = 1 - agibbsDelta;
00504 gibbsparams->epsilonError = agibbsEpsilonError;
00505 gibbsparams->fracConverged = agibbsFracConverged;
00506 gibbsparams->walksatType = agibbsWalksatType;
00507 gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00508
00509
00510 SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00511 stparams->mwsParams = mwsparams;
00512 stparams->numChains = amcmcNumChains;
00513 stparams->burnMinSteps = amcmcBurnMinSteps;
00514 stparams->burnMaxSteps = amcmcBurnMaxSteps;
00515 stparams->minSteps = amcmcMinSteps;
00516 stparams->maxSteps = amcmcMaxSteps;
00517 stparams->maxSeconds = amcmcMaxSeconds;
00518
00519 stparams->subInterval = asimtpSubInterval;
00520 stparams->numST = asimtpNumST;
00521 stparams->numSwap = asimtpNumSwap;
00522
00524
00525 cout << "Reading formulas and evidence predicates..." << endl;
00526
00527
00528 string::size_type bslash = inMLNFile.rfind("/");
00529 string tmp = (bslash == string::npos) ?
00530 inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00531 char buf[100];
00532 sprintf(buf, "%s%s", tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
00533 wkMLNFile = buf;
00534 copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
00535 evidenceFilesArr, constFilesArr);
00536
00537
00538 domain = new Domain;
00539 mln = new MLN();
00540 bool addUnitClauses = false;
00541 bool mustHaveWtOrFullStop = true;
00542 bool warnAboutDupGndPreds = true;
00543 bool flipWtsOfFlippedClause = true;
00544
00545 bool allPredsExceptQueriesAreCW = owPredNames.empty();
00546 Domain* forCheckingPlusTypes = NULL;
00547
00548
00549
00550 if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW,
00551 &owPredNames, &queryPredNames, addUnitClauses,
00552 warnAboutDupGndPreds, 0, mustHaveWtOrFullStop,
00553 forCheckingPlusTypes, true, flipWtsOfFlippedClause))
00554 {
00555 unlink(wkMLNFile.c_str());
00556 return -1;
00557 }
00558
00559 unlink(wkMLNFile.c_str());
00560
00562
00564 Array<int> allPredGndingsAreQueries;
00565
00566
00567
00568 if (!aLazy)
00569 {
00570 if (queryFile.length() > 0)
00571 {
00572 cout << "Reading query predicates that are specified in query file..."
00573 << endl;
00574 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
00575 &queries, &knownQueries);
00576 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
00577 }
00578
00579 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00580 if (queryPredsStr.length() > 0)
00581 {
00582 cout << "Creating query predicates that are specified on command line..."
00583 << endl;
00584 bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(),
00585 &queries, &knownQueries,
00586 &allPredGndingsAreQueries);
00587 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
00588 }
00589 }
00590
00591
00592 bool markHardGndClauses = false;
00593 bool trackParentClauseWts = false;
00594
00595
00596 VariableState* state = new VariableState(&queries, NULL, NULL,
00597 &allPredGndingsAreQueries,
00598 markHardGndClauses,
00599 trackParentClauseWts,
00600 mln, domain, aLazy);
00601 Inference* inference = NULL;
00602 bool trackClauseTrueCnts = false;
00603
00604 if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer)
00605 {
00606 if (amapPos || amapAll)
00607 {
00608
00609
00610 mwsparams->numSolutions = 1;
00611 inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
00612 }
00613 else if (amcsatInfer)
00614 {
00615 inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
00616 }
00617 else if (asimtpInfer)
00618 {
00619
00620
00621 mwsparams->numSolutions = 1;
00622 inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
00623 stparams);
00624 }
00625 else if (agibbsInfer)
00626 {
00627
00628
00629 mwsparams->numSolutions = 1;
00630 inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
00631 gibbsparams);
00632 }
00633
00634 inference->init();
00635 inference->infer();
00636
00637 printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00638 inference, state);
00639 }
00640
00641 resultsOut.close();
00642 if (mwsparams) delete mwsparams;
00643 if (ssparams) delete ssparams;
00644 if (msparams) delete msparams;
00645 if (gibbsparams) delete gibbsparams;
00646 if (stparams) delete stparams;
00647 delete domain;
00648 for (int i = 0; i < knownQueries.size(); i++) delete knownQueries[i];
00649 delete inference;
00650 delete state;
00651
00652 cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00653 cout << endl;
00654 }
00655