onlineengine.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef ONLINEENGINE_H_
00067 #define ONLINEENGINE_H_
00068 
00069 #include "inference.h"
00070 #include "infer.h"
00071 #include "variablestate.h"
00072 #include "arguments.h"
00073 #include "util.h"
00074 
// Debug switch for the online engine; change to true for verbose output.
const bool oedebug(false);
00077 
00078   // TODO: List the arguments common to learnwts and inference in
00079   // inferenceargs.h. This can't be done with a static array.
00080 ARGS ARGS::Args[] = 
00081 {
00082     // BEGIN: Common arguments
00083   ARGS("i", ARGS::Req, ainMLNFiles, 
00084        "Comma-separated input .mln files."),
00085 
00086   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00087        "Specified non-evidence atoms (comma-separated with no space) are "
00088        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00089        "appearing here cannot be query atoms and cannot appear in the -o "
00090        "option."),
00091 
00092   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00093        "Specified evidence atoms (comma-separated with no space) are open "
00094        "world, while other evidence atoms are closed-world. "
00095        "Atoms appearing here cannot appear in the -c option."),
00096     // END: Common arguments
00097 
00098     // BEGIN: Common inference arguments
00099   ARGS("m", ARGS::Tog, amapPos, 
00100        "Run MAP inference and return only positive query atoms."),
00101 
00102   ARGS("a", ARGS::Tog, amapAll, 
00103        "Run MAP inference and show 0/1 results for all query atoms."),
00104 
00105   ARGS("p", ARGS::Tog, agibbsInfer, 
00106        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00107        "for all query atoms."),
00108   
00109   ARGS("ms", ARGS::Tog, amcsatInfer,
00110        "Run inference using MC-SAT and return probabilities "
00111        "for all query atoms"),
00112 
00113   ARGS("simtp", ARGS::Tog, asimtpInfer,
00114        "Run inference using simulated tempering and return probabilities "
00115        "for all query atoms"),
00116 
00117   ARGS("seed", ARGS::Opt, aSeed,
00118        "[random] Seed used to initialize the randomizer in the inference "
00119        "algorithm. If not set, seed is initialized from the current date and "
00120        "time."),
00121 
00122   ARGS("lazy", ARGS::Opt, aLazy, 
00123        "[false] Run lazy version of inference if this flag is set."),
00124   
00125   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00126        "[false] Lazy version of inference will not approximate by deactivating "
00127        "atoms to save memory. This flag is ignored if -lazy is not set."),
00128   
00129   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00130        "[-1] Maximum limit in kbytes which should be used for inference. "
00131        "-1 means main memory available on system is used."),
00132     // END: Common inference arguments
00133 
00134     // BEGIN: MaxWalkSat args
00135   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00136        "[1000000] (MaxWalkSat) The max number of steps taken."),
00137 
00138   ARGS("tries", ARGS::Opt, amwsTries, 
00139        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00140 
00141   ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00142        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00143        "with weight <= specified weight."),
00144 
00145   ARGS("hard", ARGS::Opt, amwsHard, 
00146        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00147        "satisfy a soft one."),
00148   
00149   ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00150        "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00151        "2 = TABU, 3 = SAMPLESAT)."),
00152   
00153   ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00154        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00155        "atom when using the tabu heuristic in MaxWalkSat." ),
00156 
00157   ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState, 
00158        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00159        "(each time a low state is found, the whole state is saved) is used; "
00160        "otherwise, a list of variables flipped since the last low state is "
00161        "kept and the low state is reconstructed. This can be much faster for "
00162        "very large data sets."),  
00163     // END: MaxWalkSat args
00164 
00165     // BEGIN: MCMC args
00166   ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00167        "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00168 
00169   ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00170        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00171 
00172   ARGS("minSteps", ARGS::Opt, amcmcMinSteps, 
00173        "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00174 
00175   ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps, 
00176        "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00177 
00178   ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00179        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00180     // END: MCMC args
00181   
00182     // BEGIN: Simulated tempering args
00183   ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00184         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00185 
00186   ARGS("numRuns", ARGS::Opt, asimtpNumST,
00187         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00188 
00189   ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00190         "[10] (Simulated Tempering) Number of swapping chains"),
00191     // END: Simulated tempering args
00192 
00193     // BEGIN: MC-SAT args
00194   //ARGS("numStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00195   //     "[1] (MC-SAT) Number of total steps (mcsat + gibbs) for every mcsat "
00196   //     "step"),
00197     // END: MC-SAT args
00198 
00199     // BEGIN: SampleSat args
00200   ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00201        "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00202 
00203   ARGS("saRatio", ARGS::Opt, assSaRatio,
00204        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00205        "MC-SAT"),
00206 
00207   ARGS("saTemperature", ARGS::Opt, assSaTemp,
00208         "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00209         "SampleSat"),
00210 
00211   ARGS("lateSa", ARGS::Tog, assLateSa,
00212        "[false] Run simulated annealing from the start in SampleSat"),
00213     // END: SampleSat args
00214 
00215     // BEGIN: Gibbs sampling args
00216   ARGS("numChains", ARGS::Opt, amcmcNumChains, 
00217        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00218        "at least 2)."),
00219 
00220   ARGS("delta", ARGS::Opt, agibbsDelta,
00221        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00222        "exceeded is less than this value."),
00223 
00224   ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00225        "[0.01] (Gibbs) Fractional error from true probability."),
00226 
00227   ARGS("fracConverged", ARGS::Opt, agibbsFracConverged, 
00228        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00229        "have converged."),
00230 
00231   ARGS("walksatType", ARGS::Opt, agibbsWalksatType, 
00232        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00233        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00234 
00235   ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00236        "[100] Perform convergence test once after this many number of samples "
00237        "per chain."),
00238     // END: Gibbs sampling args
00239 
00240     // BEGIN: Args specific to stand-alone inference
00241   ARGS("e", ARGS::Req, aevidenceFiles, 
00242        "Comma-separated .db files containing known ground atoms (evidence), "
00243        "including function definitions."),
00244 
00245   ARGS("q", ARGS::Opt, aqueryPredsStr, 
00246        "Query atoms (comma-separated with no space)  "
00247        ",e.g., cancer,smokes(x),friends(Stan,x). Query atoms are always "
00248        "open world."),
00249 
00250   ARGS("f", ARGS::Opt, aqueryFile,
00251        "A .db file containing ground query atoms, "
00252        "which are are always open world."),
00253     // END: Args specific to stand-alone inference
00254 
00255   ARGS()
00256 };
00257 
00258 
00268 class OnlineEngine
00269 {
00270  public:
00271  
00278   OnlineEngine(const string& inferString)
00279   {
00280     Inference* inference = NULL;
00281     parseInferString(inferString, inference);
00282     setInference(inference);
00283   }
00284   
00290   OnlineEngine(Inference* inference)
00291   {
00292     setInference(inference);
00293   }
00294 
00298   ~OnlineEngine()
00299   {
00300     delete inference_;
00301   }
00302      
00306   void init()
00307   {
00308     inference_->init();
00309   }
00310   
00321   void infer(vector<string>& nonZeroAtoms, vector<float>& probs)
00322   {
00323     nonZeroAtoms.clear();
00324     probs.clear();
00325     inference_->infer();
00326       // Fill in vectors
00327     inference_->getPredsWithNonZeroProb(nonZeroAtoms, probs);
00328     assert(nonZeroAtoms.size() == probs.size());
00329   }
00330 
00337   void addTrueEvidence(const vector<string>& evidence)
00338   {
00339     addRemoveEvidenceHelper(evidence, true, true);
00340   }
00341   
00348   void addFalseEvidence(const vector<string>& evidence)
00349   {
00350     addRemoveEvidenceHelper(evidence, true, false);
00351   }
00352   
00359   void removeEvidence(const vector<string>& oldEvidence)
00360   {
00361     addRemoveEvidenceHelper(oldEvidence, false, false);
00362   }
00363 
00369   void setInference(Inference* inference)
00370   {
00371     inference_ = inference;
00372   }
00373   
00381   void setMaxInferenceSteps(const int& inferenceSteps)
00382   {
00383       // Check if using MWS
00384     if (MaxWalkSat* mws = dynamic_cast<MaxWalkSat*>(inference_))
00385     {
00386       mws->setMaxSteps(inferenceSteps);
00387     }    
00388   }
00389   
00390  private:
00391 
00403   void addRemoveEvidenceHelper(const vector<string>& evidence,
00404                                const bool& addEvidence,
00405                                const bool& trueEvidence)
00406   {
00407     vector<string>::const_iterator it = evidence.begin();
00408     for (; it != evidence.end(); it++)
00409     {
00410       GroundPredicate* p = NULL;
00411       parseGroundPredicate((*it), p);
00412       if (addEvidence)
00413         inference_->getState()->setAsEvidence(p, trueEvidence);
00414       else
00415         inference_->getState()->setAsQuery(p);
00416     }
00417     inference_->getState()->reinit();
00418   }
00419  
00428   void parseGroundPredicate(const string& predicateAsString,
00429                             GroundPredicate*& predicate)
00430   {
00431     const Domain* domain = inference_->getState()->getDomain();
00432       
00433       // Parse left to right predname ( constant1 , constant2 , ... )
00434     string rest = string(predicateAsString);
00435     string::size_type leftPar = rest.find("(", 0);
00436     if (leftPar == string::npos)
00437     {
00438       cout << predicateAsString
00439            << " was given as a predicate but it is not well-formed" << endl;
00440       exit(-1);
00441     }
00442     string name = Util::trim(rest.substr(0, leftPar));
00443       // rest is constant1 , constant2 , ... )
00444     rest = rest.substr(leftPar+1);
00445     const PredicateTemplate* pt = domain->getPredicateTemplate(name.c_str());
00446     if (pt)
00447     {
00448       Predicate* p = new Predicate(pt);
00449       string::size_type comma = rest.find(",", 0);
00450       while (comma != string::npos)
00451       {
00452         string constant = Util::trim(rest.substr(0, comma));
00453         appendConstantToPredicate(p, constant);
00454         rest = Util::trim(rest.substr(comma+1));
00455         comma = rest.find(",", 0);
00456       }
00457       string::size_type rightPar = rest.find(")", 0);
00458       if (rightPar == string::npos)
00459       {
00460         cout << predicateAsString
00461              << " was given as a predicate but it is not well-formed" << endl;
00462         exit(-1);
00463       }
00464       string constant = Util::trim(rest.substr(0, rightPar));
00465       appendConstantToPredicate(p, constant);
00466       assert(p->isGrounded());
00467       predicate = new GroundPredicate(p);
00468     }
00469     else
00470     {
00471       cout << "Predicate " << name << " is not known. Exiting..." << endl;
00472       exit(-1);
00473     }
00474   }
00475 
00484   void appendConstantToPredicate(Predicate*& pred, const string& constant)
00485   {
00486     const char* name = constant.c_str();
00487     if (isupper(name[0]) || name[0] == '"')  // if is a constant
00488     {
00489       const Domain* domain = inference_->getState()->getDomain();
00490       int constId = domain->getConstantId(name);
00491       if (constId < 0)
00492       {
00493         cout << "appendConstantToPredicate(): failed to find constant " << name;
00494         exit(-1);
00495       }
00496       
00497         // if exceeded the number of terms
00498       int exp, unexp;
00499       if ((unexp = pred->getNumTerms()) ==
00500           (exp = pred->getTemplate()->getNumTerms()))
00501       {
00502         cout << "Wrong number of terms for predicate " << pred->getName()
00503              << ". Expected " << exp << " but given " << unexp; 
00504         exit(-1);
00505       }
00506 
00507         // Check that constant has same type as that of predicate term
00508       int typeId = pred->getTermTypeAsInt(pred->getNumTerms());
00509       int unexpId; 
00510       if (typeId != (unexpId = domain->getConstantTypeId(constId)))
00511       {
00512         const char* expName = domain->getTypeName(typeId);
00513         const char* unexpName = domain->getTypeName(unexpId);
00514         cout << "Constant " << name
00515              << " is of the wrong type. Expected " << expName
00516              << " but given " << unexpName;
00517         exit(-1);
00518       }
00519   
00520         // At this point, we have the right num of terms and right types
00521       if (pred != NULL) pred->appendTerm(new Term(constId, (void*)pred, true));
00522     }
00523     else  // is a variable
00524     {
00525       cout << constant
00526            << " does not appear to be a constant." << endl;
00527       exit(-1);  
00528     }
00529   }
00530   
00531   
00539   void parseInferString(const string& inferString, Inference*& inference)
00540   {
00541     int inferArgc = 0;
00542     char **inferArgv = new char*[200];
00543     for (int i = 0; i < 200; i++)
00544     {
00545       inferArgv[i] = new char[500];
00546     }
00547 
00548     extractArgs(inferString.c_str(), inferArgc, inferArgv);
00549     cout << "extractArgs " << inferArgc << endl;
00550     for (int i = 0; i < inferArgc; i++)
00551     {
00552       cout << i << ": " << inferArgv[i] << endl;
00553     }
00554 
00556     ARGS::parse(inferArgc, inferArgv, &cout);
00557 
00558       // HACK: Argument parser doesn't parse the ARGS::Tog right, so do
00559       // it here by hand
00560     for (int i = 0; i < inferArgc; i++)
00561     {
00562       if (string(inferArgv[i]) == "-m") amapPos = true;
00563       else if (string(inferArgv[i]) == "-a") amapAll = true;
00564       else if (string(inferArgv[i]) == "-p") agibbsInfer = true;
00565       else if (string(inferArgv[i]) == "-ms") amcsatInfer = true;
00566       else if (string(inferArgv[i]) == "-simtp") asimtpInfer = true;
00567     }
00568       // Delete memory allocated for args
00569     for (int i = 0; i < 200; i++)
00570     {
00571       delete[] inferArgv[i];
00572     }
00573     delete[] inferArgv;
00574     
00575     Domain* domain = NULL;
00576     if (!buildInference(inference, domain))
00577       exit(-1);
00578  }
00579 
00580  private:
00581  
00582   Inference* inference_;
00583 
00584 };
00585 
00586 #endif /*ONLINEENGINE_H_*/

Generated on Sun Jun 7 11:55:17 2009 for Alchemy by  doxygen 1.5.1