00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef ONLINEENGINE_H_
00067 #define ONLINEENGINE_H_
00068
00069 #include "inference.h"
00070 #include "infer.h"
00071 #include "variablestate.h"
00072 #include "arguments.h"
00073 #include "util.h"
00074
00075
// Compile-time debug switch for the OnlineEngine code in this header
// (off by default).
const bool oedebug(false);
00077
00078
00079
00080 ARGS ARGS::Args[] =
00081 {
00082
00083 ARGS("i", ARGS::Req, ainMLNFiles,
00084 "Comma-separated input .mln files."),
00085
00086 ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00087 "Specified non-evidence atoms (comma-separated with no space) are "
00088 "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00089 "appearing here cannot be query atoms and cannot appear in the -o "
00090 "option."),
00091
00092 ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00093 "Specified evidence atoms (comma-separated with no space) are open "
00094 "world, while other evidence atoms are closed-world. "
00095 "Atoms appearing here cannot appear in the -c option."),
00096
00097
00098
00099 ARGS("m", ARGS::Tog, amapPos,
00100 "Run MAP inference and return only positive query atoms."),
00101
00102 ARGS("a", ARGS::Tog, amapAll,
00103 "Run MAP inference and show 0/1 results for all query atoms."),
00104
00105 ARGS("p", ARGS::Tog, agibbsInfer,
00106 "Run inference using MCMC (Gibbs sampling) and return probabilities "
00107 "for all query atoms."),
00108
00109 ARGS("ms", ARGS::Tog, amcsatInfer,
00110 "Run inference using MC-SAT and return probabilities "
00111 "for all query atoms"),
00112
00113 ARGS("simtp", ARGS::Tog, asimtpInfer,
00114 "Run inference using simulated tempering and return probabilities "
00115 "for all query atoms"),
00116
00117 ARGS("seed", ARGS::Opt, aSeed,
00118 "[random] Seed used to initialize the randomizer in the inference "
00119 "algorithm. If not set, seed is initialized from the current date and "
00120 "time."),
00121
00122 ARGS("lazy", ARGS::Opt, aLazy,
00123 "[false] Run lazy version of inference if this flag is set."),
00124
00125 ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
00126 "[false] Lazy version of inference will not approximate by deactivating "
00127 "atoms to save memory. This flag is ignored if -lazy is not set."),
00128
00129 ARGS("memLimit", ARGS::Opt, aMemLimit,
00130 "[-1] Maximum limit in kbytes which should be used for inference. "
00131 "-1 means main memory available on system is used."),
00132
00133
00134
00135 ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00136 "[1000000] (MaxWalkSat) The max number of steps taken."),
00137
00138 ARGS("tries", ARGS::Opt, amwsTries,
00139 "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00140
00141 ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00142 "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00143 "with weight <= specified weight."),
00144
00145 ARGS("hard", ARGS::Opt, amwsHard,
00146 "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00147 "satisfy a soft one."),
00148
00149 ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00150 "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00151 "2 = TABU, 3 = SAMPLESAT)."),
00152
00153 ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00154 "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00155 "atom when using the tabu heuristic in MaxWalkSat." ),
00156
00157 ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState,
00158 "[false] (MaxWalkSat) If false, the naive way of saving low states "
00159 "(each time a low state is found, the whole state is saved) is used; "
00160 "otherwise, a list of variables flipped since the last low state is "
00161 "kept and the low state is reconstructed. This can be much faster for "
00162 "very large data sets."),
00163
00164
00165
00166 ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00167 "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00168
00169 ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00170 "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00171
00172 ARGS("minSteps", ARGS::Opt, amcmcMinSteps,
00173 "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00174
00175 ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps,
00176 "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00177
00178 ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds,
00179 "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00180
00181
00182
00183 ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00184 "[2] (Simulated Tempering) Selection interval between swap attempts"),
00185
00186 ARGS("numRuns", ARGS::Opt, asimtpNumST,
00187 "[3] (Simulated Tempering) Number of simulated tempering runs"),
00188
00189 ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00190 "[10] (Simulated Tempering) Number of swapping chains"),
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200 ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00201 "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00202
00203 ARGS("saRatio", ARGS::Opt, assSaRatio,
00204 "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00205 "MC-SAT"),
00206
00207 ARGS("saTemperature", ARGS::Opt, assSaTemp,
00208 "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00209 "SampleSat"),
00210
00211 ARGS("lateSa", ARGS::Tog, assLateSa,
00212 "[false] Run simulated annealing from the start in SampleSat"),
00213
00214
00215
00216 ARGS("numChains", ARGS::Opt, amcmcNumChains,
00217 "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00218 "at least 2)."),
00219
00220 ARGS("delta", ARGS::Opt, agibbsDelta,
00221 "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00222 "exceeded is less than this value."),
00223
00224 ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00225 "[0.01] (Gibbs) Fractional error from true probability."),
00226
00227 ARGS("fracConverged", ARGS::Opt, agibbsFracConverged,
00228 "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00229 "have converged."),
00230
00231 ARGS("walksatType", ARGS::Opt, agibbsWalksatType,
00232 "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00233 "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00234
00235 ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
00236 "[100] Perform convergence test once after this many number of samples "
00237 "per chain."),
00238
00239
00240
00241 ARGS("e", ARGS::Req, aevidenceFiles,
00242 "Comma-separated .db files containing known ground atoms (evidence), "
00243 "including function definitions."),
00244
00245 ARGS("q", ARGS::Opt, aqueryPredsStr,
00246 "Query atoms (comma-separated with no space) "
00247 ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00248 "open world."),
00249
00250 ARGS("f", ARGS::Opt, aqueryFile,
00251 "A .db file containing ground query atoms, "
00252 "which are are always open world."),
00253
00254
00255 ARGS()
00256 };
00257
00258
00268 class OnlineEngine
00269 {
00270 public:
00271
00278 OnlineEngine(const string& inferString)
00279 {
00280 Inference* inference = NULL;
00281 parseInferString(inferString, inference);
00282 setInference(inference);
00283 }
00284
00290 OnlineEngine(Inference* inference)
00291 {
00292 setInference(inference);
00293 }
00294
00298 ~OnlineEngine()
00299 {
00300 delete inference_;
00301 }
00302
00306 void init()
00307 {
00308 inference_->init();
00309 }
00310
00321 void infer(vector<string>& nonZeroAtoms, vector<float>& probs)
00322 {
00323 nonZeroAtoms.clear();
00324 probs.clear();
00325 inference_->infer();
00326
00327 inference_->getPredsWithNonZeroProb(nonZeroAtoms, probs);
00328 assert(nonZeroAtoms.size() == probs.size());
00329 }
00330
00337 void addTrueEvidence(const vector<string>& evidence)
00338 {
00339 addRemoveEvidenceHelper(evidence, true, true);
00340 }
00341
00348 void addFalseEvidence(const vector<string>& evidence)
00349 {
00350 addRemoveEvidenceHelper(evidence, true, false);
00351 }
00352
00359 void removeEvidence(const vector<string>& oldEvidence)
00360 {
00361 addRemoveEvidenceHelper(oldEvidence, false, false);
00362 }
00363
00369 void setInference(Inference* inference)
00370 {
00371 inference_ = inference;
00372 }
00373
00381 void setMaxInferenceSteps(const int& inferenceSteps)
00382 {
00383
00384 if (MaxWalkSat* mws = dynamic_cast<MaxWalkSat*>(inference_))
00385 {
00386 mws->setMaxSteps(inferenceSteps);
00387 }
00388 }
00389
00390 private:
00391
00403 void addRemoveEvidenceHelper(const vector<string>& evidence,
00404 const bool& addEvidence,
00405 const bool& trueEvidence)
00406 {
00407 vector<string>::const_iterator it = evidence.begin();
00408 for (; it != evidence.end(); it++)
00409 {
00410 GroundPredicate* p = NULL;
00411 parseGroundPredicate((*it), p);
00412 if (addEvidence)
00413 inference_->getState()->setAsEvidence(p, trueEvidence);
00414 else
00415 inference_->getState()->setAsQuery(p);
00416 }
00417 inference_->getState()->reinit();
00418 }
00419
00428 void parseGroundPredicate(const string& predicateAsString,
00429 GroundPredicate*& predicate)
00430 {
00431 const Domain* domain = inference_->getState()->getDomain();
00432
00433
00434 string rest = string(predicateAsString);
00435 string::size_type leftPar = rest.find("(", 0);
00436 if (leftPar == string::npos)
00437 {
00438 cout << predicateAsString
00439 << " was given as a predicate but it is not well-formed" << endl;
00440 exit(-1);
00441 }
00442 string name = Util::trim(rest.substr(0, leftPar));
00443
00444 rest = rest.substr(leftPar+1);
00445 const PredicateTemplate* pt = domain->getPredicateTemplate(name.c_str());
00446 if (pt)
00447 {
00448 Predicate* p = new Predicate(pt);
00449 string::size_type comma = rest.find(",", 0);
00450 while (comma != string::npos)
00451 {
00452 string constant = Util::trim(rest.substr(0, comma));
00453 appendConstantToPredicate(p, constant);
00454 rest = Util::trim(rest.substr(comma+1));
00455 comma = rest.find(",", 0);
00456 }
00457 string::size_type rightPar = rest.find(")", 0);
00458 if (rightPar == string::npos)
00459 {
00460 cout << predicateAsString
00461 << " was given as a predicate but it is not well-formed" << endl;
00462 exit(-1);
00463 }
00464 string constant = Util::trim(rest.substr(0, rightPar));
00465 appendConstantToPredicate(p, constant);
00466 assert(p->isGrounded());
00467 predicate = new GroundPredicate(p);
00468 }
00469 else
00470 {
00471 cout << "Predicate " << name << " is not known. Exiting..." << endl;
00472 exit(-1);
00473 }
00474 }
00475
00484 void appendConstantToPredicate(Predicate*& pred, const string& constant)
00485 {
00486 const char* name = constant.c_str();
00487 if (isupper(name[0]) || name[0] == '"')
00488 {
00489 const Domain* domain = inference_->getState()->getDomain();
00490 int constId = domain->getConstantId(name);
00491 if (constId < 0)
00492 {
00493 cout << "appendConstantToPredicate(): failed to find constant " << name;
00494 exit(-1);
00495 }
00496
00497
00498 int exp, unexp;
00499 if ((unexp = pred->getNumTerms()) ==
00500 (exp = pred->getTemplate()->getNumTerms()))
00501 {
00502 cout << "Wrong number of terms for predicate " << pred->getName()
00503 << ". Expected " << exp << " but given " << unexp;
00504 exit(-1);
00505 }
00506
00507
00508 int typeId = pred->getTermTypeAsInt(pred->getNumTerms());
00509 int unexpId;
00510 if (typeId != (unexpId = domain->getConstantTypeId(constId)))
00511 {
00512 const char* expName = domain->getTypeName(typeId);
00513 const char* unexpName = domain->getTypeName(unexpId);
00514 cout << "Constant " << name
00515 << " is of the wrong type. Expected " << expName
00516 << " but given " << unexpName;
00517 exit(-1);
00518 }
00519
00520
00521 if (pred != NULL) pred->appendTerm(new Term(constId, (void*)pred, true));
00522 }
00523 else
00524 {
00525 cout << constant
00526 << " does not appear to be a constant." << endl;
00527 exit(-1);
00528 }
00529 }
00530
00531
00539 void parseInferString(const string& inferString, Inference*& inference)
00540 {
00541 int inferArgc = 0;
00542 char **inferArgv = new char*[200];
00543 for (int i = 0; i < 200; i++)
00544 {
00545 inferArgv[i] = new char[500];
00546 }
00547
00548 extractArgs(inferString.c_str(), inferArgc, inferArgv);
00549 cout << "extractArgs " << inferArgc << endl;
00550 for (int i = 0; i < inferArgc; i++)
00551 {
00552 cout << i << ": " << inferArgv[i] << endl;
00553 }
00554
00556 ARGS::parse(inferArgc, inferArgv, &cout);
00557
00558
00559
00560 for (int i = 0; i < inferArgc; i++)
00561 {
00562 if (string(inferArgv[i]) == "-m") amapPos = true;
00563 else if (string(inferArgv[i]) == "-a") amapAll = true;
00564 else if (string(inferArgv[i]) == "-p") agibbsInfer = true;
00565 else if (string(inferArgv[i]) == "-ms") amcsatInfer = true;
00566 else if (string(inferArgv[i]) == "-simtp") asimtpInfer = true;
00567 }
00568
00569 for (int i = 0; i < 200; i++)
00570 {
00571 delete[] inferArgv[i];
00572 }
00573 delete[] inferArgv;
00574
00575 Domain* domain = NULL;
00576 if (!buildInference(inference, domain))
00577 exit(-1);
00578 }
00579
00580 private:
00581
00582 Inference* inference_;
00583
00584 };
00585
00586 #endif