infer.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef _INFER_H_OCT_30_2005
00068 #define _INFER_H_OCT_30_2005
00069 
00074 #include "util.h"
00075 #include "fol.h"
00076 #include "mrf.h"
00077 #include "learnwts.h"
00078 #include "inferenceargs.h"
00079 #include "maxwalksat.h"
00080 #include "mcsat.h"
00081 #include "gibbssampler.h"
00082 #include "simulatedtempering.h"
00083 #include "bp.h"
00084 #include "variablestate.h"
00085 #include "hvariablestate.h"
00086 #include "hmcsat.h"
00087 #include "lbfgsp.h"
00088 
00089 // Variables for holding inference command line args are in inferenceargs.h
00090 
00091 char* aevidenceFiles  = NULL;
00092 char* aresultsFile    = NULL;
00093 char* aqueryPredsStr  = NULL;
00094 char* aqueryFile      = NULL;
00095 
00096 char* atestcont = NULL;
00097 char* atestdis = NULL;
00098 char* aHMWSDis = NULL;
00099 char* aMWSrst = NULL;
00100 
00101 int anumerator = 50;
00102 int adenominator = 100;
00103 
00104 char* aMaxSeconds = NULL;
00105 bool aGenRandom = false;
00106 bool aStartPt = false;
00107 
00108 int aLineNum = -1;
00109 char* aLinePara = NULL;
00110 char* aLineName = NULL;
00111 
00112 int saInterval = 100;
00113 
00114 char* aGndPredIdxMapFile = NULL;
00115 bool aPrintSamplePerIteration = false;
00116 
00117 string queryPredsStr, queryFile;
00118 GroundPredicateHashArray queries;
00119 GroundPredicateHashArray knownQueries;
00120 
00149 bool createComLineQueryPreds(const string& queryPredsStr,
00150                              const Domain* const & domain,
00151                              Database* const & db,
00152                              GroundPredicateHashArray* const & queries,
00153                              GroundPredicateHashArray* const & knownQueries,
00154                              Array<int>* const & allPredGndingsAreQueries,
00155                              bool printToFile, ostream& out, bool amapPos,
00156                             const GroundPredicateHashArray* const & trueQueries,
00157                              const Array<double>* const & trueProbs,
00158                              Array<Array<Predicate* >* >* queryConjs)
00159 {
00160   if (queryPredsStr.length() == 0) return true;
00161   string predConjs = Util::trim(queryPredsStr);
00162 
00163     //replace the comma or semi-colon between query predicates with '\n'
00164   int balparen = 0;
00165   for (unsigned int i = 0; i < predConjs.length(); i++)
00166   {
00167     if (predConjs.at(i)=='(')                     balparen++;
00168     else if (predConjs.at(i)==')')                balparen--;
00169     else if ((predConjs.at(i)==';' || predConjs.at(i)==',') &&
00170              balparen==0) predConjs.at(i) = '\n';
00171   }
00172   
00173   bool ret = true;
00174   string predConj;
00175   istringstream iss(predConjs);
00176   char delimit[2]; delimit[1] = '\0';
00177 
00178     // for each query formula
00179   while (getline(iss, predConj))
00180   {
00181       // replace the ^ with '\n'
00182     int balparen = 0;
00183     for (unsigned int i = 0; i < predConj.length(); i++)
00184     {
00185       if (predConj.at(i)=='(')                     balparen++;
00186       else if (predConj.at(i)==')')                balparen--;
00187       else if (predConj.at(i)=='^' && balparen==0) predConj.at(i) = '\n';
00188     }
00189     bool onlyPredName;
00190     unsigned int cur;
00191     int termId, varIdCnt = 0;
00192     hash_map<string, int, HashString, EqualString> varToId;
00193     hash_map<string, int, HashString, EqualString>::iterator it;
00194     Array<VarsTypeId*>* vtiArr;
00195     string pred, predName, term;
00196     istringstream iss2(predConj);
00197     const PredicateTemplate* ptemplate;
00198     int predicate = 0;
00199 
00200     Array<Predicate* >* predArray = new Array<Predicate*>;
00201     if (queryConjs) queryConjs->append(predArray);
00202       // for each query pred on command line
00203     while (getline(iss2, pred))
00204     {
00205       pred = Util::trim(pred);
00206       onlyPredName = false;
00207       varToId.clear();
00208       varIdCnt = 0;
00209       cur = 0;
00210       bool negated = false;
00211 
00212         // find if pred is negated
00213       if (pred.at(0) == '!')
00214       {
00215         negated = true;
00216         pred.at(0) = ' ';
00217         pred = Util::trim(pred);
00218       }
00219         // get predicate name
00220       if (!Util::substr(pred, cur, predName, "("))
00221       {
00222         predName = pred;
00223         onlyPredName = true;
00224       }
00225     
00226         // Predicate must be in the domain
00227       ptemplate = domain->getPredicateTemplate(predName.c_str());
00228       if (ptemplate == NULL)
00229       {
00230         cout << "ERROR: Cannot find command line query predicate" << predName 
00231              << " in domain." << endl;
00232         ret = false;
00233         continue;
00234       }
00235       Predicate ppred(ptemplate);
00236 
00237         // if the terms of the query predicate are also specified
00238       if (!onlyPredName)
00239       {
00240           // get term name
00241         for (int i = 0; i < 2; i++)
00242         {
00243           if (i == 0) delimit[0] = ',';
00244           else        delimit[0] = ')';
00245           while (Util::substr(pred, cur, term, delimit))
00246           {
00247               // this is a constant
00248             if (isupper(term.at(0)) || term.at(0) == '"' || isdigit(term.at(0)))
00249             {
00250               termId = domain->getConstantId(term.c_str());
00251               if (termId < 0) 
00252               {
00253                 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00254                 ret = false;
00255               }
00256             }
00257             else
00258             {   // it is a variable        
00259               if ((it=varToId.find(term)) == varToId.end()) 
00260               {
00261                 termId = --varIdCnt;
00262                 varToId[term] = varIdCnt; 
00263               }
00264               else
00265                 termId = (*it).second;
00266             }
00267             ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00268           }
00269         }
00270       }
00271       else
00272       {   // if only the predicate name is specified
00273           // HACK DEBUG
00274         //(*allPredGndingsAreQueries)[ptemplate->getId()] = true;
00275         for (int i = 0; i < ptemplate->getNumTerms(); i++)
00276           ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00277       }  
00278 
00279         // Check if number of terms is correct
00280       if (ppred.getNumTerms() != ptemplate->getNumTerms())
00281       {
00282         cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00283              << " terms but given " << ppred.getNumTerms() << endl;
00284         ret = false;
00285       }
00286       if (!ret) continue;
00287     
00289       vtiArr = NULL;
00290       ppred.createVarsTypeIdArr(vtiArr);
00291 
00292         // If a ground predicate was specified on command line
00293       if (vtiArr->size() <= 1)
00294       {
00295         assert(ppred.isGrounded());
00296         assert(!db->isClosedWorld(ppred.getId()));
00297         TruthValue tv = db->getValue(&ppred);
00298 
00299         if (negated) ppred.setSense(false);
00300         predArray->append(new Predicate(ppred));
00301         
00302         GroundPredicate* gndPred = new GroundPredicate(&ppred);
00303           // If just printing to file, then all values must be known
00304         if (printToFile) assert(tv != UNKNOWN);
00305         if (tv == UNKNOWN)
00306         {
00307           if (queries->append(gndPred) < 0) delete gndPred;
00308         }
00309         else
00310         {
00311             // If just printing to file
00312           if (printToFile)
00313           {
00314               // If trueQueries is given as argument, then get prob. from there
00315             if (trueQueries)
00316             {
00317               double prob = 0.0;
00318               if (domain->getDB()->getEvidenceStatus(&ppred))
00319               {
00320                   // Don't print out evidence atoms
00321                 continue;
00322                 //prob = (tv == TRUE) ? 1.0 : 0.0;
00323               }
00324               else
00325               {
00326                 int found = trueQueries->find(gndPred);
00327                 if (found >= 0) prob = (*trueProbs)[found];
00328                 else
00329                     // Uniform smoothing
00330                   prob = (prob*10000+1/2.0)/(10000+1.0);
00331               
00332               }
00333               gndPred->print(out, domain); out << " " << prob << endl;
00334             }
00335             else
00336             {
00337               if (amapPos) //if show postive ground query predicates only
00338               {
00339                     if (tv == TRUE)
00340                 {
00341                       ppred.printWithStrVar(out, domain);
00342                       out << endl;
00343                 }
00344               }
00345               else //print all ground query predicates
00346               {
00347                 ppred.printWithStrVar(out, domain);
00348                 out << " " << tv << endl;
00349               }
00350             }
00351             delete gndPred;
00352           }
00353           else // Building queries for HashArray
00354           {
00355             //if (tv == TRUE) gndPred->setProbTrue(1);
00356             //else            gndPred->setProbTrue(0);
00357             if (knownQueries->append(gndPred) < 0) delete gndPred;  
00358           }
00359         }
00360       }
00361       else // Variables need to be grounded
00362       {
00363         ArraysAccessor<int> acc;
00364         for (int i = 1; i < vtiArr->size(); i++)
00365         {
00366           const Array<int>* cons =
00367             domain->getConstantsByType((*vtiArr)[i]->typeId);
00368           acc.appendArray(cons);
00369         } 
00370 
00371           // form all groundings of the predicate
00372         Array<int> constIds;
00373         while (acc.getNextCombination(constIds))
00374         {
00375           assert(constIds.size() == vtiArr->size()-1);
00376           for (int j = 0; j < constIds.size(); j++)
00377           {
00378             Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00379             for (int k = 0; k < terms.size(); k++)
00380               terms[k]->setId(constIds[j]);
00381           }
00382 
00383           // at this point the predicate is grounded
00384           assert(!db->isClosedWorld(ppred.getId()));
00385  
00386           TruthValue tv = db->getValue(&ppred);        
00387           GroundPredicate* gndPred = new GroundPredicate(&ppred);
00388 
00389             // If just printing to file, then all values must be known
00390           if (printToFile) assert(tv != UNKNOWN);
00391           if (tv == UNKNOWN)
00392           {
00393             if (queries->append(gndPred) < 0) delete gndPred;
00394           }
00395           else
00396           {
00397               // If just printing to file
00398             if (printToFile)
00399             {
00400                 // If trueQueries is given as argument, then get prob. from there
00401               if (trueQueries)
00402               {
00403                 double prob = 0.0;
00404                 if (domain->getDB()->getEvidenceStatus(&ppred))
00405                 {
00406                     // Don't print out evidence atoms
00407                   continue;
00408                   //prob = (tv == TRUE) ? 1.0 : 0.0;
00409                 }
00410                 else
00411                 {
00412                   int found = trueQueries->find(gndPred);
00413                   if (found >= 0) prob = (*trueProbs)[found];
00414                   else
00415                       // Uniform smoothing
00416                     prob = (prob*10000+1/2.0)/(10000+1.0);
00417                 }
00418                   // Uniform smoothing
00419                 //prob = (prob*10000+1/2.0)/(10000+1.0);
00420                 gndPred->print(out, domain); out << " " << prob << endl;
00421               }
00422               else
00423               {
00424                 if (amapPos) //if show postive ground query predicates only
00425                 {
00426                   if (tv == TRUE)
00427                   {
00428                     ppred.printWithStrVar(out, domain);
00429                     out << endl;
00430                   }
00431                 }
00432                 else //print all ground query predicates
00433                 {
00434                   ppred.printWithStrVar(out, domain);
00435                   out << " " << tv << endl;
00436                 }
00437               }
00438               delete gndPred;          
00439             }
00440             else // Building queries
00441             {
00442                 //if (tv == TRUE) gndPred->setProbTrue(1);
00443                     //else            gndPred->setProbTrue(0);
00444               if (knownQueries->append(gndPred) < 0) delete gndPred;  
00445             }
00446           }
00447         }
00448       }
00449       
00450       ppred.deleteVarsTypeIdArr(vtiArr);
00451       predicate++;
00452     } // for each query pred on command line
00453     if (predArray->size() == 0)
00454     {
00455       if (queryConjs) queryConjs->removeLastItem();
00456       delete predArray;
00457     }    
00458   } // for each query formula
00459 
00460   if (!printToFile)
00461   {
00462         queries->compress();
00463         knownQueries->compress();
00464   }
00465   
00466   return ret;
00467 }
00468 
00473 bool createComLineQueryPreds(const string& queryPredsStr,
00474                              const Domain* const & domain,
00475                              Database* const & db,
00476                              GroundPredicateHashArray* const & queries,
00477                              GroundPredicateHashArray* const & knownQueries,
00478                              Array<int>* const & allPredGndingsAreQueries,
00479                              Array<Array<Predicate* >* >* queryConjs)
00480 {
00481   return createComLineQueryPreds(queryPredsStr, domain, db,
00482                                  queries, knownQueries,
00483                                  allPredGndingsAreQueries,
00484                                  false, cout, false, NULL, NULL, queryConjs);
00485 }
00486 
00497 bool extractPredNames(string preds, const string* queryFile, 
00498                       StringHashArray& predNames)
00499 { 
00500   predNames.clear();
00501 
00502     // first extract the query pred names specified on command line
00503   string::size_type cur = 0, ws, ltparen;
00504   string qpred, predName;
00505   
00506   if (preds.length() > 0)
00507   {
00508     preds.append(" "); //terminate preds with a whitespace
00509     
00510       //replace the comma or semi-colon between query predicates with ' '
00511     int balparen = 0;
00512     for (unsigned int i = 0; i < preds.length(); i++)
00513     {
00514       if (preds.at(i) == '(')      balparen++;
00515       else if (preds.at(i) == ')') balparen--;
00516       else if ((preds.at(i) == ',' || preds.at(i) == ';') &&
00517                balparen == 0) preds.at(i) = ' ';
00518     }
00519     
00520     while (preds.at(cur) == ' ') cur++;
00521     while (true)
00522     {
00523       ws = preds.find(" ", cur);
00524       if (ws == string::npos) break;
00525       qpred = preds.substr(cur,ws-cur+1);
00526       cur = ws+1;
00527       while (cur < preds.length() &&
00528              (preds.at(cur) == ' ' || preds.at(cur) == '^' ||
00529               preds.at(cur) == '!')) cur++;
00530       ltparen = qpred.find("(",0);
00531       
00532       if (ltparen == string::npos) 
00533       { 
00534         ws = qpred.find(" ");
00535         if (ws != string::npos) qpred = qpred.substr(0,ws);
00536         predName = qpred; 
00537       }
00538       else
00539         predName = qpred.substr(0,ltparen);
00540       
00541       predNames.append(predName);
00542     }
00543   }
00544 
00545   if (queryFile == NULL || queryFile->length() == 0) return true;
00546 
00547     // next extract query predicates specified in query file  
00548   ifstream in((*queryFile).c_str());
00549   if (!in.good())
00550   {
00551     cout << "ERROR: unable to open " << *queryFile << endl;
00552     return false;
00553   }
00554   string buffer;
00555   while (getline(in, buffer))
00556   {
00557     cur = 0;
00558     while (cur < buffer.length() &&
00559            (buffer.at(cur) == ' ' || buffer.at(cur) == '^' ||
00560             buffer.at(cur) == '!')) cur++;
00561     ltparen = buffer.find("(", cur);
00562     if (ltparen == string::npos) continue;
00563     predName = buffer.substr(cur, ltparen-cur);
00564     predNames.append(predName);
00565   }
00566 
00567   in.close();
00568   return true;
00569 }
00570 
00575 char getTruthValueFirstChar(const TruthValue& tv)
00576 {
00577   if (tv == TRUE)    return 'T';
00578   if (tv == FALSE)   return 'F';
00579   if (tv == UNKNOWN) return 'U';
00580   assert(false);
00581   exit(-1);
00582   return '#'; //avoid compilation warning
00583 }
00584 
/**
 * Seeds the C library random number generator (srand) from the current
 * wall-clock time. The seed is the low 7 bits of the seconds (mask 0177
 * octal) scaled to microseconds plus the current microseconds, so it varies
 * with the time of the call.
 */
void setsrand()
{
  struct timeval tv;
    // POSIX leaves the behaviour unspecified when the timezone argument is
    // not NULL (struct timezone is obsolete), so no timezone is requested.
  gettimeofday(&tv, NULL);
  unsigned int seed = (( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec;
  srand(seed);
}
00596 
00597 
00598   //copy srcFile to dstFile, & append '#include "dbFiles"' to latter
00599 void copyFileAndAppendDbFile(const string& srcFile, string& dstFile, 
00600                              const Array<string>& dbFilesArr,
00601                              const Array<string>& constFilesArr)
00602 {
00603   ofstream out(dstFile.c_str());
00604   ifstream in(srcFile.c_str());
00605   if (!out.good()) { cout<<"ERROR: failed to open "<<dstFile<<endl;exit(-1);}
00606   if (!in.good()) { cout<<"ERROR: failed to open "<<srcFile<<endl;exit(-1);}
00607   
00608   string buffer;
00609   while(getline(in, buffer)) out << buffer << endl;
00610   in.close();
00611 
00612   out << endl;
00613   for (int i = 0; i < constFilesArr.size(); i++) 
00614     out << "#include \"" << constFilesArr[i] << "\"" << endl;
00615   out << endl;
00616   for (int i = 0; i < dbFilesArr.size(); i++)    
00617     out << "#include \"" << dbFilesArr[i] << "\"" << endl;
00618   out.close();
00619 }
00620 
00621 
00622 bool checkQueryPredsNotInClosedWorldPreds(const StringHashArray& qpredNames,
00623                                           const StringHashArray& cwPredNames)
00624 {
00625   bool ok = true;
00626   for (int i = 0; i < qpredNames.size(); i++)
00627     if (cwPredNames.contains(qpredNames[i]))
00628     {
00629       cout << "ERROR: query predicate " << qpredNames[i] 
00630            << " cannot be specified as closed world" << endl; 
00631       ok = false;
00632     }
00633   return ok;
00634 }
00635 
00664 bool createQueryFilePreds(const string& queryFile,
00665                           const Domain* const & domain,
00666                           Database* const & db,
00667                           GroundPredicateHashArray* const &queries,
00668                           GroundPredicateHashArray* const &knownQueries,
00669                           bool printToFile, ostream& out, bool amapPos,
00670                           const GroundPredicateHashArray* const &trueQueries,
00671                           const Array<double>* const & trueProbs,
00672                           Array<Array<Predicate* >* >* queryConjs)
00673 {
00674   if (queryFile.length() == 0) return true;
00675 
00676   bool ret = true;
00677   string predConj;
00678   ifstream in(queryFile.c_str());
00679   char delimit[2]; delimit[1] = '\0';
00680 
00681     // for each query formula
00682   while (getline(in, predConj))
00683   {
00684       // replace the ^ with '\n'
00685     int balparen = 0;
00686     for (unsigned int i = 0; i < predConj.length(); i++)
00687     {
00688       if (predConj.at(i)=='(')                     balparen++;
00689       else if (predConj.at(i)==')')                balparen--;
00690       else if (predConj.at(i)=='^' && balparen==0) predConj.at(i) = '\n';
00691     }
00692 
00693     bool onlyPredName;
00694     unsigned int cur;
00695     int termId, varIdCnt = 0;
00696     hash_map<string, int, HashString, EqualString> varToId;
00697     hash_map<string, int, HashString, EqualString>::iterator it;
00698     Array<VarsTypeId*>* vtiArr;
00699     string pred, predName, term;
00700     istringstream iss2(predConj);
00701     const PredicateTemplate* ptemplate;
00702     int predicate = 0;
00703 
00704     Array<Predicate* >* predArray = new Array<Predicate*>;
00705     queryConjs->append(predArray);
00706       // for each query pred on command line
00707     while (getline(iss2, pred))
00708     {
00709       pred = Util::trim(pred);
00710       onlyPredName = false;
00711       varToId.clear();
00712       varIdCnt = 0;
00713       cur = 0;
00714       bool negated = false;
00715 
00716         // find if pred is negated
00717       if (pred.at(0) == '!')
00718       {
00719         negated = true;
00720         pred.at(0) = ' ';
00721         pred = Util::trim(pred);
00722       }
00723         // get predicate name
00724       if (!Util::substr(pred, cur, predName, "("))
00725       {
00726         predName = pred;
00727         onlyPredName = true;
00728       }
00729     
00730         // Predicate must be in the domain
00731       ptemplate = domain->getPredicateTemplate(predName.c_str());
00732       if (ptemplate == NULL)
00733       {
00734         cout << "ERROR: Cannot find command line query predicate" << predName 
00735              << " in domain." << endl;
00736         ret = false;
00737         continue;
00738       }
00739       Predicate ppred(ptemplate);
00740 
00741         // if the terms of the query predicate are also specified
00742       if (!onlyPredName)
00743       {
00744           // get term name
00745         for (int i = 0; i < 2; i++)
00746         {
00747           if (i == 0) delimit[0] = ',';
00748           else        delimit[0] = ')';
00749           while (Util::substr(pred, cur, term, delimit))
00750           {
00751               // this is a constant
00752             if (isupper(term.at(0)) || term.at(0) == '"' || isdigit(term.at(0)))
00753             {
00754               termId = domain->getConstantId(term.c_str());
00755               if (termId < 0) 
00756               {
00757                 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00758                 ret = false;
00759               }
00760             }
00761             else
00762             {   // it is a variable        
00763               if ((it=varToId.find(term)) == varToId.end()) 
00764               {
00765                 termId = --varIdCnt;
00766                 varToId[term] = varIdCnt; 
00767               }
00768               else
00769                 termId = (*it).second;
00770             }
00771             ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00772           }
00773         }
00774       }
00775       else
00776       {   // if only the predicate name is specified
00777           // HACK DEBUG
00778         //(*allPredGndingsAreQueries)[ptemplate->getId()] = true;
00779         for (int i = 0; i < ptemplate->getNumTerms(); i++)
00780           ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00781       }  
00782 
00783         // Check if number of terms is correct
00784       if (ppred.getNumTerms() != ptemplate->getNumTerms())
00785       {
00786         cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00787              << " terms but given " << ppred.getNumTerms() << endl;
00788         ret = false;
00789       }
00790       if (!ret) continue;
00791     
00793       vtiArr = NULL;
00794       ppred.createVarsTypeIdArr(vtiArr);
00795 
00796         // If a ground predicate was specified on command line
00797       if (vtiArr->size() <= 1)
00798       {
00799         assert(ppred.isGrounded());
00800         assert(!db->isClosedWorld(ppred.getId()));
00801         TruthValue tv = db->getValue(&ppred);
00802 
00803         if (negated) ppred.setSense(false);
00804         predArray->append(new Predicate(ppred));
00805         
00806         GroundPredicate* gndPred = new GroundPredicate(&ppred);
00807           // If just printing to file, then all values must be known
00808         if (printToFile) assert(tv != UNKNOWN);
00809         if (tv == UNKNOWN)
00810         {
00811           if (queries->append(gndPred) < 0) delete gndPred;
00812         }
00813         else
00814         {
00815             // If just printing to file
00816           if (printToFile)
00817           {
00818               // If trueQueries is given as argument, then get prob. from there
00819             if (trueQueries)
00820             {
00821               double prob = 0.0;
00822               if (domain->getDB()->getEvidenceStatus(&ppred))
00823               {
00824                   // Don't print out evidence atoms
00825                 continue;
00826                 //prob = (tv == TRUE) ? 1.0 : 0.0;
00827               }
00828               else
00829               {
00830                 int found = trueQueries->find(gndPred);
00831                 if (found >= 0) prob = (*trueProbs)[found];
00832                 else
00833                     // Uniform smoothing
00834                   prob = (prob*10000+1/2.0)/(10000+1.0);
00835               
00836               }
00837               gndPred->print(out, domain); out << " " << prob << endl;
00838             }
00839             else
00840             {
00841               if (amapPos) //if show postive ground query predicates only
00842               {
00843                 if (tv == TRUE)
00844                 {
00845                   ppred.printWithStrVar(out, domain);
00846                   out << endl;
00847                 }
00848               }
00849               else //print all ground query predicates
00850               {
00851                 ppred.printWithStrVar(out, domain);
00852                 out << " " << tv << endl;
00853               }
00854             }
00855             delete gndPred;
00856           }
00857           else // Building queries for HashArray
00858           {
00859             //if (tv == TRUE) gndPred->setProbTrue(1);
00860             //else            gndPred->setProbTrue(0);
00861             if (knownQueries->append(gndPred) < 0) delete gndPred;  
00862           }
00863         }      
00864       }
00865       else // Variables need to be grounded
00866       {
00867         ArraysAccessor<int> acc;
00868         for (int i = 1; i < vtiArr->size(); i++)
00869         {
00870           const Array<int>* cons =
00871             domain->getConstantsByType((*vtiArr)[i]->typeId);
00872           acc.appendArray(cons);
00873         } 
00874 
00875           // form all groundings of the predicate
00876         Array<int> constIds;
00877         while (acc.getNextCombination(constIds))
00878         {
00879           assert(constIds.size() == vtiArr->size()-1);
00880           for (int j = 0; j < constIds.size(); j++)
00881           {
00882             Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00883             for (int k = 0; k < terms.size(); k++)
00884               terms[k]->setId(constIds[j]);
00885           }
00886 
00887           // at this point the predicate is grounded
00888           assert(!db->isClosedWorld(ppred.getId()));
00889  
00890           TruthValue tv = db->getValue(&ppred);        
00891           GroundPredicate* gndPred = new GroundPredicate(&ppred);
00892 
00893             // If just printing to file, then all values must be known
00894           if (printToFile) assert(tv != UNKNOWN);
00895           if (tv == UNKNOWN)
00896           {
00897             if (queries->append(gndPred) < 0) delete gndPred;
00898           }
00899           else
00900           {
00901               // If just printing to file
00902             if (printToFile)
00903             {
00904                 // If trueQueries is given as argument, then get prob. from there
00905               if (trueQueries)
00906               {
00907                 double prob = 0.0;
00908                 if (domain->getDB()->getEvidenceStatus(&ppred))
00909                 {
00910                     // Don't print out evidence atoms
00911                   continue;
00912                   //prob = (tv == TRUE) ? 1.0 : 0.0;
00913                 }
00914                 else
00915                 {
00916                   int found = trueQueries->find(gndPred);
00917                   if (found >= 0) prob = (*trueProbs)[found];
00918                   else
00919                       // Uniform smoothing
00920                     prob = (prob*10000+1/2.0)/(10000+1.0);
00921                 }
00922                   // Uniform smoothing
00923                 //prob = (prob*10000+1/2.0)/(10000+1.0);
00924                 gndPred->print(out, domain); out << " " << prob << endl;
00925               }
00926               else
00927               {
00928                 if (amapPos) //if show postive ground query predicates only
00929                 {
00930                   if (tv == TRUE)
00931                   {
00932                     ppred.printWithStrVar(out, domain);
00933                     out << endl;
00934                   }
00935                 }
00936                 else //print all ground query predicates
00937                 {
00938                   ppred.printWithStrVar(out, domain);
00939                   out << " " << tv << endl;
00940                 }
00941               }
00942               delete gndPred;          
00943             }
00944             else // Building queries
00945             {
00946                 //if (tv == TRUE) gndPred->setProbTrue(1);
00947                 //else            gndPred->setProbTrue(0);
00948               if (knownQueries->append(gndPred) < 0) delete gndPred;  
00949             }
00950           }
00951         }
00952       }
00953       
00954       ppred.deleteVarsTypeIdArr(vtiArr);
00955       predicate++;
00956     } // for each query pred on command line
00957     if (predArray->size() == 0)
00958     {
00959       if (queryConjs) queryConjs->removeLastItem();
00960       delete predArray;
00961     }    
00962   } // fore each query formula
00963 
00964   if (!printToFile)
00965   {
00966     queries->compress();
00967     knownQueries->compress();
00968   }
00969   
00970   in.close();
00971   return ret;
00972 }
00973 
00978 bool createQueryFilePreds(const string& queryFile, const Domain* const & domain,
00979                           Database* const & db,
00980                           GroundPredicateHashArray* const &queries,
00981                           GroundPredicateHashArray* const &knownQueries,
00982                           Array<Array<Predicate* >* >* queryConjs)
00983 {
00984   return createQueryFilePreds(queryFile, domain, db, queries, knownQueries,
00985                               false, cout, false, NULL, NULL, queryConjs);
00986 }
00987 
00988 
00995 int buildInference(Inference*& inference, Domain*& domain,
00996                    bool const &aisQueryEvidence, Array<Predicate *> &queryPreds,
00997                    Array<TruthValue> &queryPredValues)
00998 {
00999   string inMLNFile, wkMLNFile, evidenceFile;
01000 
01001   StringHashArray queryPredNames;
01002   StringHashArray owPredNames;
01003   StringHashArray cwPredNames;
01004   MLN* mln = NULL;
01005   Array<string> constFilesArr;
01006   Array<string> evidenceFilesArr;
01007 
01008   
01009   //the second .mln file to the last one in ainMLNFiles _may_ be used 
01010   //to hold constants, so they are held in constFilesArr. They will be
01011   //included into the first .mln file.
01012 
01013     //extract .mln, .db file names
01014   extractFileNames(ainMLNFiles, constFilesArr);
01015   assert(constFilesArr.size() >= 1);
01016   inMLNFile.append(constFilesArr[0]);
01017   constFilesArr.removeItem(0);
01018   extractFileNames(aevidenceFiles, evidenceFilesArr);
01019   
01020   if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
01021   if (aqueryFile) queryFile.append(aqueryFile);
01022 
01023   if (queryPredsStr.length() == 0 && queryFile.length() == 0)
01024   { cout << "No query predicates specified" << endl; return -1; }
01025 
01026   if (agibbsInfer && agibbsTestConvergence && amcmcNumChains < 2) 
01027   {
01028     cout << "ERROR: If testing for convergence, there must be at least 2 "
01029          << "MCMC chains in Gibbs sampling" 
01030          << endl; return -1;
01031   }
01032 
01033   if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer &&
01034       !aHybrid && !aSA && !abpInfer && !aoutputNetwork)
01035   {
01036       // If nothing specified, use MC-SAT
01037     amcsatInfer = true;
01038   }
01039 
01040     //extract names of all query predicates
01041   if (queryPredsStr.length() > 0 || queryFile.length() > 0)
01042   {
01043     if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
01044   }
01045 
01046   if (amwsMaxSteps <= 0)
01047   { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
01048 
01049   if (amwsTries <= 0)
01050   { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
01051 
01052     //extract names of open-world evidence predicates
01053   if (aOpenWorldPredsStr)
01054   {
01055     if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
01056       return -1;
01057     assert(owPredNames.size() > 0);
01058   }
01059 
01060     //extract names of closed-world non-evidence predicates
01061   if (aClosedWorldPredsStr)
01062   {
01063     if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
01064       return -1;
01065     assert(cwPredNames.size() > 0);
01066     if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
01067       return -1;
01068   }
01069 
01070   // TODO: Check if query atom in -o -> error
01071 
01072   // TODO: Check if atom in -c and -o -> error
01073 
01074 
01075   // TODO: Check if ev. atom in -c or
01076   // non-evidence in -o -> warning (this is default)
01077 
01078 
01079     // Set SampleSat parameters
01080   SampleSatParams* ssparams = new SampleSatParams;
01081   ssparams->lateSa = assLateSa;
01082   ssparams->saRatio = assSaRatio;
01083   ssparams->saTemp = assSaTemp;
01084 
01085     // Set MaxWalksat parameters
01086   MaxWalksatParams* mwsparams = new MaxWalksatParams;
01087   mwsparams->ssParams = ssparams;
01088   mwsparams->maxSteps = amwsMaxSteps;
01089   mwsparams->maxTries = amwsTries;
01090   mwsparams->targetCost = amwsTargetWt;
01091   mwsparams->hard = amwsHard;
01092     // numSolutions only applies when used in SampleSat.
01093     // When just MWS, this is set to 1
01094   mwsparams->numSolutions = amwsNumSolutions;
01095   mwsparams->heuristic = amwsHeuristic;
01096   mwsparams->tabuLength = amwsTabuLength;
01097   mwsparams->lazyLowState = amwsLazyLowState;
01098 
01099     // Set MC-SAT parameters
01100   MCSatParams* msparams = new MCSatParams;
01101   msparams->mwsParams = mwsparams;
01102     // MC-SAT needs only one chain
01103   msparams->numChains          = 1;
01104   msparams->burnMinSteps       = amcmcBurnMinSteps;
01105   msparams->burnMaxSteps       = amcmcBurnMaxSteps;
01106   msparams->minSteps           = amcmcMinSteps;
01107   msparams->maxSteps           = amcmcMaxSteps;
01108   msparams->maxSeconds         = amcmcMaxSeconds;
01109   //msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
01110 
01111     // Set Gibbs parameters
01112   GibbsParams* gibbsparams = new GibbsParams;
01113   gibbsparams->mwsParams    = mwsparams;
01114   gibbsparams->numChains    = amcmcNumChains;
01115   gibbsparams->burnMinSteps = amcmcBurnMinSteps;
01116   gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
01117   gibbsparams->minSteps     = amcmcMinSteps;
01118   gibbsparams->maxSteps     = amcmcMaxSteps;
01119   gibbsparams->maxSeconds   = amcmcMaxSeconds;
01120 
01121   gibbsparams->gamma           = 1 - agibbsDelta;
01122   gibbsparams->epsilonError    = agibbsEpsilonError;
01123   gibbsparams->fracConverged   = agibbsFracConverged;
01124   gibbsparams->walksatType     = agibbsWalksatType;
01125   gibbsparams->testConvergence = agibbsTestConvergence;
01126   gibbsparams->samplesPerTest  = agibbsSamplesPerTest;
01127   
01128     // Set Sim. Tempering parameters
01129   SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
01130   stparams->mwsParams    = mwsparams;
01131   stparams->numChains    = amcmcNumChains;
01132   stparams->burnMinSteps = amcmcBurnMinSteps;
01133   stparams->burnMaxSteps = amcmcBurnMaxSteps;
01134   stparams->minSteps     = amcmcMinSteps;
01135   stparams->maxSteps     = amcmcMaxSteps;
01136   stparams->maxSeconds   = amcmcMaxSeconds;
01137 
01138   stparams->subInterval = asimtpSubInterval;
01139   stparams->numST       = asimtpNumST;
01140   stparams->numSwap     = asimtpNumSwap;
01141 
01142     // Set BP parameters
01143   BPParams* bpparams = new BPParams;
01144   bpparams->maxSteps               = amcmcMaxSteps;
01145   bpparams->maxSeconds             = amcmcMaxSeconds;
01146   bpparams->lifted                 = aliftedInfer;
01147   bpparams->convergenceThresh      = abpConvergenceThresh;
01148   bpparams->convergeRequiredItrCnt = abpConvergeRequiredItrCnt;
01149   bpparams->implicitRep            = !aexplicitRep;
01150   bpparams->outputNetwork          = aoutputNetwork;
01151 
01153 
01154   cout << "Reading formulas and evidence predicates..." << endl;
01155 
01156     // Copy inMLNFile to workingMLNFile & app '#include "evid.db"'
01157   string::size_type bslash = inMLNFile.rfind("/");
01158   string tmp = (bslash == string::npos) ? 
01159                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
01160   char buf[100];
01161   sprintf(buf, "%s%d%s", tmp.c_str(), getpid(), ZZ_TMP_FILE_POSTFIX);
01162   wkMLNFile = buf;
01163   copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
01164                           evidenceFilesArr, constFilesArr);
01165 
01166     // Parse wkMLNFile, and create the domain, MLN, database
01167   domain = new Domain;
01168   mln = new MLN();
01169   bool addUnitClauses = false;
01170   bool mustHaveWtOrFullStop = true;
01171   bool warnAboutDupGndPreds = true;
01172   bool flipWtsOfFlippedClause = true;
01173   bool allPredsExceptQueriesAreCW = false;
01174   //bool allPredsExceptQueriesAreCW = owPredNames.empty();
01175   Domain* forCheckingPlusTypes = NULL;
01176 
01177     // Parse as if lazy inference is set to true to set evidence atoms in DB
01178     // If lazy is not used, this is removed from DB
01179   if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW, 
01180                    &owPredNames, &cwPredNames, &queryPredNames, addUnitClauses, 
01181                    warnAboutDupGndPreds, 0, mustHaveWtOrFullStop, 
01182                    forCheckingPlusTypes, true, flipWtsOfFlippedClause))
01183   {
01184     unlink(wkMLNFile.c_str());
01185     return -1;
01186   }
01187 
01188   unlink(wkMLNFile.c_str());
01189 
01191 
01193   Array<int> allPredGndingsAreQueries;
01194   Array<Array<Predicate* >* >* queryFormulas =  new Array<Array<Predicate*> *>;
01195 
01196     // Eager inference: Build the queries for the mrf
01197     // Lazy version evaluates the query string / file when printing out
01198   if (!aLazy)
01199   {
01200     if (queryFile.length() > 0)
01201     {
01202       cout << "Reading query predicates that are specified in query file..."
01203            << endl;
01204       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
01205                                      &queries, &knownQueries, queryFormulas);
01206       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01207     }
01208 
01209     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
01210     if (queryPredsStr.length() > 0)
01211     {
01212       // unePreds = unknown non-evidence predicates
01213       // nePreds  = known non-evidence predicates
01214       GroundPredicateHashArray unePreds;
01215       GroundPredicateHashArray knePreds;
01216       bool ok = createComLineQueryPreds(queryPredsStr, domain, 
01217                                   domain->getDB(), &unePreds, &knePreds,
01218                                   &allPredGndingsAreQueries, queryFormulas);
01219       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01220 
01221       if (aisQueryEvidence)
01222       {
01223         // If the isQueryEvidence flag is set, then all query predicates
01224         // that are specified in the database are assumed to be the
01225         // set of queries we're actually interested in, while all
01226         // unspecified query predicates are assumed to be false evidence.
01227         // This is useful for doing inference with canopies,
01228         // e.g., (Singla & Domingos, 2005).
01229 
01230         // All unknown queries are actually false evidence 
01231         queryPredValues.clear();
01232         for (int predno = 0; predno < unePreds.size(); predno++) 
01233           domain->getDB()->setValue(unePreds[predno], FALSE);
01234         knownQueries = unePreds;
01235 
01236         // All known queries are actually unknown 
01237         for (int predno = 0; predno < knePreds.size(); predno++) 
01238         {
01239           TruthValue origValue 
01240               = domain->getDB()->setValue(knePreds[predno], UNKNOWN);
01241           queryPredValues.append(origValue);
01242         }
01243         queries = knePreds;
01244       }
01245       else
01246       {
01247         queries = unePreds;
01248         knownQueries = knePreds;
01249       }
01250     }
01251   }
01252 
01253   bool trackClauseTrueCnts = false;
01254   VariableState* state = NULL;
01255   HVariableState* hstate = NULL;
01256   FactorGraph* factorGraph = NULL;
01257   
01258   if (abpInfer || aoutputNetwork)
01259   {
01260     factorGraph = new FactorGraph(bpparams->lifted, mln, domain, queryFormulas);
01261     inference = new BP(factorGraph, bpparams, queryFormulas);
01262   }
01263   else
01264   {
01265       // Create inference algorithm and state based on queries and mln / domain
01266     bool markHardGndClauses = true;
01267     bool trackParentClauseWts = false;
01268     if (aHybrid)
01269     {
01270         // Create inference algorithm and state based on queries and mln / domain
01271       bool markHardGndClauses = true;
01272       bool trackParentClauseWts = false;
01273           hstate = new HVariableState(&queries, NULL, NULL,
01274                                           &allPredGndingsAreQueries,
01275                                           markHardGndClauses,
01276                                       trackParentClauseWts,
01277                                           mln, domain, aLazy);
01278 
01279           hstate->LoadContGroundedMLN(aContPartFile);
01280           hstate->WriteGndPredIdxMap(aGndPredIdxMapFile);
01281           hstate->bMaxOnly_ = amapPos;
01282     }
01283     else
01284     {
01285         // Lazy version: queries and allPredGndingsAreQueries are empty,
01286         // markHardGndClause and trackParentClauseWts are not used
01287       state = new VariableState(&queries, NULL, NULL,
01288                                 &allPredGndingsAreQueries,
01289                                 markHardGndClauses,
01290                                 trackParentClauseWts,
01291                                 mln, domain, aLazy);
01292     }
01293     
01294     if (aGenRandom && aHybrid)
01295     {
01296           hstate->setProposalStdev(aProposalStdev);
01297           hstate->initRandom();
01298           //hstate->printContAtoms()
01299           hstate->saveLowStateAll();
01300           ofstream oscont(atestcont), osdis(atestdis);
01301           hstate->printLowStateCont(oscont);
01302           hstate->printLowState(osdis);
01303           return 0;
01304     }
01305 
01306     if (aStartPt && aHybrid)
01307     {
01308           hstate->LoadDisEviValuesFromRst(atestdis);
01309           hstate->LoadContEviValues(atestcont);
01310     }
01311 
01312     if (aStartPt && !aHybrid)
01313     {
01314           state->LoadDisEviValuesFromRst(atestdis);
01315     }
01316     
01317       // MAP inference, MC-SAT, Gibbs or Sim. Temp.
01318     if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer ||
01319         aHybrid || aSA)
01320     {
01321       if ((amapPos || amapAll) && !aHybrid)
01322       { // MaxWalkSat
01323           // When standalone MWS, numSolutions is always 1
01324           // (maybe there is a better way to do this?)
01325         mwsparams->numSolutions = 1;
01326         inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
01327       }
01328       else if (amapPos && aHybrid && !amcsatInfer && !aSA)
01329           {
01330             hstate->setProposalStdev(aProposalStdev);
01331                   // maximizing all the numeric terms individually by l-bfgs, 
01332                   // and cache the optimal solution for each term
01333                 hstate->optimizeIndividualNumTerm();
01334                 mwsparams->numSolutions = 1;
01335                 mwsparams->heuristic = HMWS;
01336                 inference = new HMaxWalkSat(hstate, aSeed, trackClauseTrueCnts,
01337                                     mwsparams);         
01338                 HMaxWalkSat* p = (HMaxWalkSat*)inference;
01339                 p->SetNoisePra(anumerator, adenominator);
01340                 if (aMaxSeconds)
01341                 {
01342                   p->SetMaxSeconds(atof(aMaxSeconds));
01343                 }
01344                 
01345                 if (aStartPt)
01346                 {
01347                   hstate->setInitFromEvi(true);
01348                 }       
01349           }
01350           else if (amapPos && aHybrid && !amcsatInfer && aSA)
01351           {
01352                 hstate->setProposalStdev(aProposalStdev);
01353                 mwsparams->numSolutions = 1;
01354                 mwsparams->heuristic = HSA;
01355                 inference = new HMaxWalkSat(hstate, aSeed, trackClauseTrueCnts,
01356                                     mwsparams);
01357 
01358                 HMaxWalkSat* p = (HMaxWalkSat*)inference;
01359                 //p->SetNoisePra(anumerator, adenominator);
01360                 p->setHeuristic(HSA);
01361                 p->setSATempDownRatio(aSATempDownRatio);
01362                 p->SetMaxSeconds(atof(aMaxSeconds));
01363                 p->SetSAInterval(saInterval);
01364                 if (aStartPt)
01365                 {
01366           hstate->setInitFromEvi(true);
01367                 }
01368           }
01369           else if (aHybrid && amcsatInfer)
01370           {
01371                 cout << "Creating HMCSAT instance." << endl;
01372                 hstate->setProposalStdev(aProposalStdev);
01373                 if (aContSamples == NULL)
01374                 {
01375                         cout << "Numeric sample file is error." <<endl;
01376                 }
01377                 inference = new HMCSAT(hstate, aSeed, trackClauseTrueCnts, msparams);
01378                 HMCSAT* p = (HMCSAT*) inference;
01379                 p->SetContSampleFile(aContSamples);
01380         p->SetPrintVarsPerSample(aPrintSamplePerIteration);
01381           }
01382       else if (amcsatInfer && !aHybrid)
01383       { // MC-SAT
01384         inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams,
01385                               queryFormulas);
01386       }
01387       else if (asimtpInfer && !aHybrid)
01388       { // Simulated Tempering
01389           // When MWS is used in Sim. Temp., numSolutions is always 1
01390           // (maybe there is a better way to do this?)
01391         mwsparams->numSolutions = 1;
01392         inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
01393                                            stparams);
01394       }
01395       else if (agibbsInfer && !aHybrid)
01396       { // Gibbs sampling
01397           // When MWS is used in Gibbs, numSolutions is always 1
01398           // (maybe there is a better way to do this?)
01399         mwsparams->numSolutions = 1;
01400         inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
01401                                      gibbsparams);
01402       }
01403     }
01404   }
01405   return 1;
01406 }
01407  
01408   // Typedefs
01409 typedef hash_map<string, const Array<const char*>*, HashString, EqualString> 
01410 StringToStrArrayMap;
01411 
01412 #endif

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1