00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include <unistd.h>
00068 #include <fstream>
00069 #include <climits>
00070 #include <sys/times.h>
00071 #include "fol.h"
00072 #include "arguments.h"
00073 #include "util.h"
00074 #include "infer.h"
00075
00076 extern const char* ZZ_TMP_FILE_POSTFIX;
00077
00078
00079
00080 ARGS ARGS::Args[] =
00081 {
00082
00083 ARGS("i", ARGS::Req, ainMLNFiles,
00084 "Comma-separated input .mln files."),
00085
00086 ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00087 "Specified non-evidence atoms (comma-separated with no space) are "
00088 "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00089 "appearing here cannot be query atoms and cannot appear in the -o "
00090 "option."),
00091
00092 ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00093 "Specified evidence atoms (comma-separated with no space) are open "
00094 "world, while other evidence atoms are closed-world. "
00095 "Atoms appearing here cannot appear in the -c option."),
00096
00097
00098 ARGS("queryEvidence", ARGS::Tog, aisQueryEvidence,
00099 "If this flag is set, then all the groundings of query preds not in db "
00100 "are assumed false evidence."),
00101
00102
00103 ARGS("m", ARGS::Tog, amapPos,
00104 "Run MAP inference and return only positive query atoms."),
00105
00106 ARGS("a", ARGS::Tog, amapAll,
00107 "Run MAP inference and show 0/1 results for all query atoms."),
00108
00109 ARGS("p", ARGS::Tog, agibbsInfer,
00110 "Run inference using MCMC (Gibbs sampling) and return probabilities "
00111 "for all query atoms."),
00112
00113 ARGS("ms", ARGS::Tog, amcsatInfer,
00114 "Run inference using MC-SAT and return probabilities "
00115 "for all query atoms"),
00116
00117 ARGS("bp", ARGS::Tog, abpInfer,
00118 "Run inference using belief propagation and return probabilities "
00119 "for all query atoms"),
00120
00121 ARGS("simtp", ARGS::Tog, asimtpInfer,
00122 "Run inference using simulated tempering and return probabilities "
00123 "for all query atoms"),
00124
00125 ARGS("outputNetwork", ARGS::Tog, aoutputNetwork,
00126 "Build the network and output to results file, instead of "
00127 "running inference"),
00128
00129 ARGS("counts", ARGS::Tog, aclauseCounts,
00130 "Write clause counts, not atom values or probabilities"),
00131
00132 ARGS("seed", ARGS::Opt, aSeed,
00133 "[2350877] Seed used to initialize the randomizer in the inference "
00134 "algorithm. If not set, seed is initialized from a fixed random number."),
00135
00136 ARGS("lazy", ARGS::Opt, aLazy,
00137 "[false] Run lazy version of inference if this flag is set."),
00138
00139 ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
00140 "[false] Lazy version of inference will not approximate by deactivating "
00141 "atoms to save memory. This flag is ignored if -lazy is not set."),
00142
00143 ARGS("memLimit", ARGS::Opt, aMemLimit,
00144 "[-1] Maximum limit in kbytes which should be used for inference. "
00145 "-1 means main memory available on system is used."),
00146
00147 ARGS("GndPredIdxMapFile", ARGS::Opt, aGndPredIdxMapFile,
00148 "File containing the mapping from ground predicates (dis & num) to their idx"),
00149
00150 ARGS("PrintSamplePerIteration", ARGS::Opt, aPrintSamplePerIteration,
00151 "Whether to print out variable values at each HMCS sample round."),
00152
00153 ARGS("SAInterval", ARGS::Opt, saInterval, "SA interval"),
00154
00155 ARGS("LineNum", ARGS::Opt, aLineNum, "LineNum"),
00156
00157 ARGS("LinePara", ARGS::Opt, aLinePara, "LinePara"),
00158
00159 ARGS("LineName", ARGS::Opt, aLineName, "LineName"),
00160
00161 ARGS("MaxSeconds", ARGS::Opt, aMaxSeconds, "Max seconds for HMWS and SA."),
00162
00163 ARGS("StartPt", ARGS::Opt, aStartPt, "Starting from a fixed point, load from testcont and testdis."),
00164
00165 ARGS("GenRandom", ARGS::Opt, aGenRandom, "generate random assignment as starting point, saving in testcont and testdis."),
00166
00167 ARGS("SATempDownRation", ARGS::Opt, aSATempDownRatio, "simulated annealing temperature degrade ratio."),
00168
00169 ARGS("SA", ARGS::Opt, aSA, "simulated annealing inference."),
00170
00171 ARGS("MWSRST", ARGS::Opt, aMWSrst, "Result file for MaxwalkSAT inference."),
00172
00173 ARGS("noisynum", ARGS::Opt, anumerator,
00174 "numerator value for noisy pick in HMWS."),
00175
00176 ARGS("noisyden", ARGS::Opt, adenominator,
00177 "denominator value for noisy pick in HMWS."),
00178
00179 ARGS("hmwsDis", ARGS::Opt, aHMWSDis,
00180 "dis inference result file for hybrid maxwalksat."),
00181
00182 ARGS("testcont", ARGS::Opt, atestcont,
00183 "file path containing assignment to cont variables."),
00184
00185 ARGS("testdis", ARGS::Opt, atestdis,
00186 "file path containing assignment to dis variables."),
00187
00188 ARGS("mwsMax", ARGS::Opt, aMaxOrNot,
00189 "[false] (MaxWalkSat) If false,MC-SAT uses WalkSAT, if true, MC-SAT use max-WalkSAT. "),
00190
00191 ARGS("hybrid", ARGS::Opt, aHybrid,
00192 "Flag for HMLN inferences."),
00193
00194 ARGS("propstdev", ARGS::Opt, aProposalStdev,
00195 "[1.0]Proposal stdev for SA step in HybridSAT."),
00196
00197 ARGS("conti", ARGS::Opt, aContPartFile,
00198 "input file containing the description of continuous part of hybrid MLN."),
00199
00200 ARGS("segl", ARGS::Opt, aSegListFile,
00201 "input file containing the list of grounded segments."),
00202
00203 ARGS("contSamples", ARGS::Opt, aContSamples,
00204 "output file for continuous variable samples."),
00205
00206
00207
00208
00209 ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00210 "[100000] (MaxWalkSat) The max number of steps taken."),
00211
00212 ARGS("tries", ARGS::Opt, amwsTries,
00213 "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00214
00215 ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00216 "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00217 "with weight <= specified weight."),
00218
00219 ARGS("breakHardClauses", ARGS::Tog, amwsHard,
00220 "[false] (MaxWalkSat) If true, MaxWalkSat can break a hard clause in "
00221 "order to satisfy a soft one."),
00222
00223 ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00224 "[2] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00225 "2 = TABU, 3 = SAMPLESAT)."),
00226
00227 ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00228 "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00229 "atom when using the tabu heuristic in MaxWalkSat." ),
00230
00231 ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState,
00232 "[false] (MaxWalkSat) If false, the naive way of saving low states "
00233 "(each time a low state is found, the whole state is saved) is used; "
00234 "otherwise, a list of variables flipped since the last low state is "
00235 "kept and the low state is reconstructed. This can be much faster for "
00236 "very large data sets."),
00237
00238
00239
00240 ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00241 "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00242
00243 ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00244 "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00245
00246 ARGS("minSteps", ARGS::Opt, amcmcMinSteps,
00247 "[-1] (MCMC) Minimum number of MCMC sampling steps."),
00248
00249 ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps,
00250 "[1000] (MCMC) Maximum number of MCMC sampling steps."),
00251
00252 ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds,
00253 "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00254
00255
00256
00257 ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00258 "[2] (Simulated Tempering) Selection interval between swap attempts"),
00259
00260 ARGS("numRuns", ARGS::Opt, asimtpNumST,
00261 "[3] (Simulated Tempering) Number of simulated tempering runs"),
00262
00263 ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00264 "[10] (Simulated Tempering) Number of swapping chains"),
00265
00266
00267
00268 ARGS("lifted", ARGS::Tog, aliftedInfer,
00269 "[false] If true, lifted inference is run"),
00270
00271 ARGS("convThresh", ARGS::Opt, abpConvergenceThresh,
00272 "[1e-4] (BP) Max difference in probabilities to determine convergence"),
00273
00274 ARGS("convIterations", ARGS::Opt, abpConvergeRequiredItrCnt,
00275 "[20] (BP) Number of converging iterations to determine convergence"),
00276
00277 ARGS("explicitRep", ARGS::Tog, aexplicitRep,
00278 "[false] If true, explicit representation type is used in lifted "
00279 "inference; otherwise, implicit representation type is used"),
00280
00281
00282
00283 ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00284 "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00285
00286 ARGS("saRatio", ARGS::Opt, assSaRatio,
00287 "[0] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00288 "MC-SAT"),
00289
00290 ARGS("saTemperature", ARGS::Opt, assSaTemp,
00291 "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00292 "SampleSat"),
00293
00294 ARGS("lateSa", ARGS::Tog, assLateSa,
00295 "[true] Run simulated annealing from the start in SampleSat"),
00296
00297
00298
00299 ARGS("numChains", ARGS::Opt, amcmcNumChains,
00300 "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00301 "at least 2)."),
00302
00303 ARGS("delta", ARGS::Opt, agibbsDelta,
00304 "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00305 "exceeded is less than this value."),
00306
00307 ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00308 "[0.01] (Gibbs) Fractional error from true probability."),
00309
00310 ARGS("fracConverged", ARGS::Opt, agibbsFracConverged,
00311 "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00312 "have converged."),
00313
00314 ARGS("walksatType", ARGS::Opt, agibbsWalksatType,
00315 "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00316 "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00317
00318 ARGS("testConvergence", ARGS::Opt, agibbsTestConvergence,
00319 "[false] Perform convergence test for Gibbs sampling."),
00320
00321 ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
00322 "[100] Perform convergence test once after this many number of samples "
00323 "per chain."),
00324
00325
00326
00327 ARGS("e", ARGS::Req, aevidenceFiles,
00328 "Comma-separated .db files containing known ground atoms (evidence), "
00329 "including function definitions."),
00330
00331 ARGS("r", ARGS::Req, aresultsFile,
00332 "The probability estimates are written to this file."),
00333
00334 ARGS("q", ARGS::Opt, aqueryPredsStr,
00335 "Query atoms (comma-separated with no space) "
00336 ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00337 "open world."),
00338
00339 ARGS("f", ARGS::Opt, aqueryFile,
00340 "A .db file containing ground query atoms, "
00341 "which are are always open world."),
00342
00343
00344 ARGS()
00345 };
00346
00347
00348 void printResults(const string& queryFile, const string& queryPredsStr,
00349 Domain *domain, ostream& out,
00350 GroundPredicateHashArray* const &queries,
00351 Inference* const &inference, HVariableState* const &state)
00352 {
00353
00354
00355 if (aLazy)
00356 {
00357 const GroundPredicateHashArray* gndPredHashArray = NULL;
00358 Array<double>* gndPredProbs = NULL;
00359
00360
00361
00362 if (!(amapPos || amapAll))
00363 {
00364 gndPredHashArray = state->getGndPredHashArrayPtr();
00365 gndPredProbs = new Array<double>;
00366 gndPredProbs->growToSize(gndPredHashArray->size());
00367 for (int i = 0; i < gndPredProbs->size(); i++)
00368 (*gndPredProbs)[i] = inference->getProbabilityH((*gndPredHashArray)[i]);
00369 }
00370
00371 if (queryFile.length() > 0)
00372 {
00373 cout << "Writing query predicates that are specified in query file..."
00374 << endl;
00375 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
00376 NULL, NULL, true, out, amapPos,
00377 gndPredHashArray, gndPredProbs, NULL);
00378 if (!ok)
00379 {
00380 cout <<"Failed to create query predicates."<< endl;
00381 exit(-1);
00382 }
00383 }
00384
00385 Array<int> allPredGndingsAreQueries;
00386 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00387 if (queryPredsStr.length() > 0)
00388 {
00389 cout << "Writing query predicates that are specified on command line..."
00390 << endl;
00391 bool ok = createComLineQueryPreds(queryPredsStr, domain,
00392 domain->getDB(), NULL, NULL,
00393 &allPredGndingsAreQueries, true,
00394 out, amapPos, gndPredHashArray,
00395 gndPredProbs, NULL);
00396 if (!ok)
00397 {
00398 cout <<"Failed to create query predicates."<< endl; exit(-1);
00399 }
00400 }
00401
00402 if (!(amapPos || amapAll))
00403 delete gndPredProbs;
00404 }
00405
00406
00407 else
00408 {
00409 if (amapPos)
00410 inference->printTruePredsH(out);
00411 else
00412 {
00413 for (int i = 0; i < queries->size(); i++)
00414 {
00415
00416 double prob = inference->getProbabilityH((*queries)[i]);
00417 (*queries)[i]->print(out, domain); out << " " << prob << endl;
00418 }
00419 }
00420 }
00421 }
00422
00423
00439 void printResults(const string& queryFile, const string& queryPredsStr,
00440 Domain *domain, ostream& out,
00441 GroundPredicateHashArray* const &queries,
00442 Inference* const &inference, VariableState* const &state,
00443 Array<Predicate*> const &queryPreds,
00444 Array<TruthValue> const &queryPredValues)
00445 {
00446
00447
00448 if (aLazy)
00449 {
00450
00451
00452
00453
00454 for (int i = 0; i < queryPreds.size(); i++)
00455 {
00456 int val = (queryPredValues[i] == TRUE) ? 1 : 0;
00457
00458
00459 double prob = inference->getProbability((GroundPredicate*)queryPreds[i]);
00460
00461 queryPreds[i]->print(out, domain);
00462 out << " " << prob << " " << val << endl;
00463 }
00464
00465
00466
00467
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485
00486
00487
00488
00489
00490
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502 }
00503
00504
00505 else
00506 {
00507 if (amapPos)
00508 inference->printTruePreds(out);
00509 else
00510 {
00511 for (int i = 0; i < queryPreds.size(); i++)
00512 {
00513 int val = (queryPredValues[i] == TRUE) ? 1 : 0;
00514
00515
00516 double prob = inference->getProbability((GroundPredicate*)queryPreds[i]);
00517
00518 queryPreds[i]->print(out, domain);
00519 out << " " << prob << " " << val << endl;
00520 }
00521 }
00522 }
00523 }
00524
00525 void printResults(const string& queryFile, const string& queryPredsStr,
00526 Domain *domain, ostream& out,
00527 GroundPredicateHashArray* const &queries,
00528 Inference* const &inference, VariableState* const &state)
00529 {
00530
00531
00532 if (aLazy)
00533 {
00534 const GroundPredicateHashArray* gndPredHashArray = NULL;
00535 Array<double>* gndPredProbs = NULL;
00536
00537
00538
00539 if (!(amapPos || amapAll))
00540 {
00541 gndPredHashArray = state->getGndPredHashArrayPtr();
00542 gndPredProbs = new Array<double>;
00543 gndPredProbs->growToSize(gndPredHashArray->size());
00544 for (int i = 0; i < gndPredProbs->size(); i++)
00545 (*gndPredProbs)[i] =
00546 inference->getProbability((*gndPredHashArray)[i]);
00547 }
00548
00549 if (queryFile.length() > 0)
00550 {
00551 cout << "Writing query predicates that are specified in query file..."
00552 << endl;
00553 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(), NULL,
00554 NULL, true, out, amapPos,
00555 gndPredHashArray, gndPredProbs, NULL);
00556 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00557 }
00558
00559 Array<int> allPredGndingsAreQueries;
00560 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
00561 if (queryPredsStr.length() > 0)
00562 {
00563 cout << "Writing query predicates that are specified on command line..."
00564 << endl;
00565 bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(),
00566 NULL, NULL, &allPredGndingsAreQueries,
00567 true, out, amapPos, gndPredHashArray,
00568 gndPredProbs, NULL);
00569 if (!ok) { cout <<"Failed to create query predicates."<< endl; exit(-1); }
00570 }
00571
00572 if (!(amapPos || amapAll))
00573 delete gndPredProbs;
00574 }
00575
00576
00577 else
00578 {
00579 if (amapPos)
00580 inference->printTruePreds(out);
00581 else
00582 {
00583 inference->printQFProbs(out, domain);
00584 for (int i = 0; i < queries->size(); i++)
00585 {
00586
00587 double prob = inference->getProbability((*queries)[i]);
00588 (*queries)[i]->print(out, domain); out << " " << prob << endl;
00589 }
00590 }
00591 }
00592 }
00593
00594
00603 int main(int argc, char* argv[])
00604 {
00606 ARGS::parse(argc, argv, &cout);
00607 Timer timer;
00608 double begSec = timer.time();
00609
00610 Array<Predicate *> queryPreds;
00611 Array<TruthValue> queryPredValues;
00612
00613 ofstream resultsOut(aresultsFile);
00614 if (!resultsOut.good())
00615 { cout << "ERROR: unable to open " << aresultsFile << endl; return -1; }
00616
00617 Domain* domain = NULL;
00618 Inference* inference = NULL;
00619 if (buildInference(inference, domain, aisQueryEvidence, queryPreds,
00620 queryPredValues) > -1)
00621 {
00622 inference->init();
00623
00624 if (aoutputNetwork)
00625 {
00626 cout << "Writing network to file ..." << endl;
00627 inference->printNetwork(resultsOut);
00628 }
00629
00630 else
00631 {
00632 inference->infer();
00633 if (aHybrid && !amapPos)
00634 {
00635 printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00636 inference, inference->getHState());
00637 }
00638 else
00639 {
00640 printResults(queryFile, queryPredsStr, domain, resultsOut, &queries,
00641 inference, inference->getState());
00642 }
00643 }
00644 }
00645
00646 resultsOut.close();
00647 if (domain) delete domain;
00648 for (int i = 0; i < knownQueries.size(); i++)
00649 if (knownQueries[i]) delete knownQueries[i];
00650 if (inference) delete inference;
00651
00652 cout << "total time taken = "; Timer::printTime(cout, timer.time()-begSec);
00653 cout << endl;
00654 }
00655