infer.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef _INFER_H_OCT_30_2005
00067 #define _INFER_H_OCT_30_2005
00068 
00073 #include "util.h"
00074 #include "fol.h"
00075 #include "mrf.h"
00076 #include "learnwts.h"
00077 #include "inferenceargs.h"
00078 #include "maxwalksat.h"
00079 #include "mcsat.h"
00080 #include "gibbssampler.h"
00081 #include "simulatedtempering.h"
00082 
00083 // Variables for holding inference command line args are in inferenceargs.h
00084 
00085 char* aevidenceFiles  = NULL;
00086 char* aresultsFile    = NULL;
00087 char* aqueryPredsStr  = NULL;
00088 char* aqueryFile      = NULL;
00089 
00090 string queryPredsStr, queryFile;
00091 GroundPredicateHashArray queries;
00092 GroundPredicateHashArray knownQueries;
00093 
00125 bool createComLineQueryPreds(const string& queryPredsStr,
00126                              const Domain* const & domain,
00127                              Database* const & db,
00128                              GroundPredicateHashArray* const & queries,
00129                              GroundPredicateHashArray* const & knownQueries,
00130                              Array<int>* const & allPredGndingsAreQueries,
00131                              bool printToFile, ostream& out, bool amapPos,
00132                             const GroundPredicateHashArray* const & trueQueries,
00133                              const Array<double>* const & trueProbs)
00134 {
00135   if (queryPredsStr.length() == 0) return true;
00136   string preds = Util::trim(queryPredsStr);
00137 
00138     //replace the comma between query predicates with '\n'
00139   int balparen = 0;
00140   for (unsigned int i = 0; i < preds.length(); i++)
00141   {
00142     if (preds.at(i)=='(')                     balparen++;
00143     else if (preds.at(i)==')')                balparen--;
00144     else if (preds.at(i)==',' && balparen==0) preds.at(i)='\n';
00145   }
00146 
00147   bool onlyPredName;
00148   bool ret = true;
00149   unsigned int cur;
00150   int termId, varIdCnt = 0;
00151   hash_map<string, int, HashString, EqualString> varToId;
00152   hash_map<string, int, HashString, EqualString>::iterator it;
00153   Array<VarsTypeId*>* vtiArr;
00154   string pred, predName, term;
00155   const PredicateTemplate* ptemplate;
00156   istringstream iss(preds);
00157   char delimit[2]; delimit[1] = '\0';
00158 
00159     // for each query pred on command line
00160   while (getline(iss, pred))
00161   {
00162     onlyPredName = false;
00163     varToId.clear();
00164     varIdCnt = 0;
00165     cur = 0;
00166 
00167       // get predicate name
00168     if (!Util::substr(pred,cur,predName,"("))
00169     {
00170       predName = pred;
00171       onlyPredName = true;
00172     }
00173     
00174       // Predicate must be in the domain
00175     ptemplate = domain->getPredicateTemplate(predName.c_str());
00176     if (ptemplate == NULL)
00177     {
00178       cout << "ERROR: Cannot find command line query predicate" << predName 
00179            << " in domain." << endl;
00180       ret = false;
00181       continue;
00182     }
00183     Predicate ppred(ptemplate);
00184 
00185       // if the terms of the query predicate are also specified
00186     if (!onlyPredName)
00187     {
00188         // get term name
00189       for (int i = 0; i < 2; i++)
00190       {       
00191         if (i == 0) delimit[0] = ',';
00192         else        delimit[0] = ')';
00193         while(Util::substr(pred, cur, term, delimit))
00194         {
00195             // this is a constant
00196           if (isupper(term.at(0)) || term.at(0) == '"')
00197           {
00198             termId = domain->getConstantId(term.c_str());
00199             if (termId < 0) 
00200             {
00201               cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00202               ret = false;
00203             }        
00204           }
00205           else
00206           {   // it is a variable        
00207             if ((it=varToId.find(term)) == varToId.end()) 
00208             {
00209               termId = --varIdCnt;
00210               varToId[term] = varIdCnt; 
00211             }
00212             else
00213               termId = (*it).second;
00214           }
00215           ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00216         }
00217       }
00218     }
00219     else
00220     {   // if only the predicate name is specified
00221       (*allPredGndingsAreQueries)[ptemplate->getId()] = true;
00222       for (int i = 0; i < ptemplate->getNumTerms(); i++)
00223         ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00224     }  
00225 
00226       // Check if number of terms is correct
00227     if (ppred.getNumTerms() != ptemplate->getNumTerms())
00228     {
00229       cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00230            << " terms but given " << ppred.getNumTerms() << endl;
00231       ret = false;
00232     }
00233     if (!ret) continue;
00234 
00235     
00237     vtiArr = NULL;
00238     ppred.createVarsTypeIdArr(vtiArr);
00239 
00240       // If a ground predicate was specified on command line
00241     if (vtiArr->size() <= 1)
00242     {
00243       assert(ppred.isGrounded());
00244       assert(!db->isClosedWorld(ppred.getId()));
00245       TruthValue tv = db->getValue(&ppred);
00246       GroundPredicate* gndPred = new GroundPredicate(&ppred);
00247       
00248         // If just printing to file, then all values must be known
00249       if (printToFile) assert(tv != UNKNOWN);
00250       if (tv == UNKNOWN)
00251       {
00252         if (queries->append(gndPred) < 0) delete gndPred;
00253       }
00254       else
00255       {
00256           // If just printing to file
00257         if (printToFile)
00258         {
00259             // If trueQueries is given as argument, then get prob. from there
00260           if (trueQueries)
00261           {
00262             double prob = 0.0;
00263             if (domain->getDB()->getEvidenceStatus(&ppred))
00264             {
00265                 // Don't print out evidence atoms
00266               continue;
00267               //prob = (tv == TRUE) ? 1.0 : 0.0;
00268             }
00269             else
00270             {
00271               int found = trueQueries->find(gndPred);
00272               if (found >= 0) prob = (*trueProbs)[found];
00273               else
00274                   // Uniform smoothing
00275                 prob = (prob*10000+1/2.0)/(10000+1.0);
00276               
00277             }
00278             gndPred->print(out, domain); out << " " << prob << endl;
00279           }
00280           else
00281           {
00282             if (amapPos) //if show postive ground query predicates only
00283             {
00284                   if (tv == TRUE)
00285                   {
00286                     ppred.printWithStrVar(out, domain);
00287                     out << endl;
00288                   }
00289             }
00290             else //print all ground query predicates
00291             {
00292               ppred.printWithStrVar(out, domain);
00293               out << " " << tv << endl;
00294             }
00295           }
00296           delete gndPred;
00297         }
00298         else // Building queries for HashArray
00299         {
00300           //if (tv == TRUE) gndPred->setProbTrue(1);
00301           //else            gndPred->setProbTrue(0);
00302 
00303           if (knownQueries->append(gndPred) < 0) delete gndPred;  
00304         }
00305       }      
00306     }
00307     else // Variables need to be grounded
00308     {
00309       ArraysAccessor<int> acc;
00310       for (int i = 1; i < vtiArr->size(); i++)
00311       {
00312         const Array<int>* cons=domain->getConstantsByType((*vtiArr)[i]->typeId);
00313         acc.appendArray(cons);
00314       } 
00315 
00316         // form all groundings of the predicate
00317       Array<int> constIds;
00318       while (acc.getNextCombination(constIds))
00319       {
00320         assert(constIds.size() == vtiArr->size()-1);
00321         for (int j = 0; j < constIds.size(); j++)
00322         {
00323           Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00324           for (int k = 0; k < terms.size(); k++)
00325             terms[k]->setId(constIds[j]);
00326         }
00327 
00328         // at this point the predicate is grounded
00329         assert(!db->isClosedWorld(ppred.getId()));
00330 
00331         TruthValue tv = db->getValue(&ppred);        
00332         GroundPredicate* gndPred = new GroundPredicate(&ppred);
00333 
00334           // If just printing to file, then all values must be known
00335             if (printToFile) assert(tv != UNKNOWN);
00336         if (tv == UNKNOWN)
00337         {
00338           if (queries->append(gndPred) < 0) delete gndPred;
00339         }
00340         else
00341         {
00342                 // If just printing to file
00343           if (printToFile)
00344           {
00345               // If trueQueries is given as argument, then get prob. from there
00346             if (trueQueries)
00347             {
00348               double prob = 0.0;
00349               if (domain->getDB()->getEvidenceStatus(&ppred))
00350               {
00351                   // Don't print out evidence atoms
00352                 continue;
00353                 //prob = (tv == TRUE) ? 1.0 : 0.0;
00354               }
00355               else
00356               {
00357                 int found = trueQueries->find(gndPred);
00358                 if (found >= 0) prob = (*trueProbs)[found];
00359                 else
00360                     // Uniform smoothing
00361                   prob = (prob*10000+1/2.0)/(10000+1.0);
00362               }
00363                 // Uniform smoothing
00364               //prob = (prob*10000+1/2.0)/(10000+1.0);
00365               gndPred->print(out, domain); out << " " << prob << endl;
00366             }
00367             else
00368             {
00369               if (amapPos) //if show postive ground query predicates only
00370               {
00371                 if (tv == TRUE)
00372                 {
00373                   ppred.printWithStrVar(out, domain);
00374                   out << endl;
00375                 }
00376               }
00377               else //print all ground query predicates
00378               {
00379                 ppred.printWithStrVar(out, domain);
00380                 out << " " << tv << endl;
00381               }
00382             }
00383             delete gndPred;          
00384           }
00385           else // Building queries
00386           {
00387                 //if (tv == TRUE) gndPred->setProbTrue(1);
00388                 //else            gndPred->setProbTrue(0);
00389                 if (knownQueries->append(gndPred) < 0) delete gndPred;  
00390           }
00391         }        
00392       }
00393     }
00394       
00395     ppred.deleteVarsTypeIdArr(vtiArr);
00396   } // for each query pred on command line
00397 
00398   if (!printToFile)
00399   {
00400         queries->compress();
00401         knownQueries->compress();
00402   }
00403   
00404   return ret;
00405 }
00406 
00411 bool createComLineQueryPreds(const string& queryPredsStr,
00412                              const Domain* const & domain,
00413                              Database* const & db,
00414                              GroundPredicateHashArray* const & queries,
00415                              GroundPredicateHashArray* const & knownQueries,
00416                              Array<int>* const & allPredGndingsAreQueries)
00417 {
00418   return createComLineQueryPreds(queryPredsStr, domain, db,
00419                                  queries, knownQueries,
00420                                  allPredGndingsAreQueries,
00421                                  false, cout, false, NULL, NULL);
00422 }
00423 
00434 bool extractPredNames(string preds, const string* queryFile, 
00435                       StringHashArray& predNames)
00436 { 
00437   predNames.clear();
00438 
00439     // first extract the query pred names specified on command line
00440   string::size_type cur = 0, ws, ltparen;
00441   string qpred, predName;
00442   
00443   if (preds.length() > 0)
00444   {
00445     preds.append(" "); //terminate preds with a whitespace
00446     
00447       //replace the comma between query predicates with ' '
00448     int balparen = 0;
00449     for (unsigned int i = 0; i < preds.length(); i++)
00450     {
00451       if (preds.at(i) == '(')      balparen++;
00452       else if (preds.at(i) == ')') balparen--;
00453       else if (preds.at(i) == ',' && balparen == 0) preds.at(i) = ' ';
00454     }
00455     
00456     while (preds.at(cur) == ' ') cur++;
00457     while (true)
00458     {
00459       ws = preds.find(" ", cur);
00460       if (ws == string::npos) break;
00461       qpred = preds.substr(cur,ws-cur+1);
00462       cur = ws+1;
00463       while (cur < preds.length() && preds.at(cur) == ' ') cur++;
00464       ltparen = qpred.find("(",0);
00465       
00466       if (ltparen == string::npos) 
00467       { 
00468         ws = qpred.find(" ");
00469         if (ws != string::npos) qpred = qpred.substr(0,ws);
00470         predName = qpred; 
00471       }
00472       else
00473         predName = qpred.substr(0,ltparen);
00474       
00475       predNames.append(predName);
00476     }
00477   }
00478 
00479   if (queryFile == NULL || queryFile->length() == 0) return true;
00480 
00481     // next extract query predicates specified in query file  
00482   ifstream in((*queryFile).c_str());
00483   if (!in.good())
00484   {
00485     cout << "ERROR: unable to open " << *queryFile << endl;
00486     return false;
00487   }
00488   string buffer;
00489   while (getline(in, buffer))
00490   {
00491     cur = 0;
00492     while (cur < buffer.length() && buffer.at(cur) == ' ') cur++;
00493     ltparen = buffer.find("(", cur);
00494     if (ltparen == string::npos) continue;
00495     predName = buffer.substr(cur, ltparen-cur);
00496     predNames.append(predName);
00497   }
00498 
00499   in.close();
00500   return true;
00501 }
00502 
00507 char getTruthValueFirstChar(const TruthValue& tv)
00508 {
00509   if (tv == TRUE)    return 'T';
00510   if (tv == FALSE)   return 'F';
00511   if (tv == UNKNOWN) return 'U';
00512   assert(false);
00513   exit(-1);
00514   return '#'; //avoid compilation warning
00515 }
00516 
00520 void setsrand()
00521 {
00522   struct timeval tv;
00523   struct timezone tzp;
00524   gettimeofday(&tv,&tzp);
00525   unsigned int seed = (( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec;
00526   srand(seed);
00527 } 
00528 
00529 
00530   //copy srcFile to dstFile, & append '#include "dbFiles"' to latter
00531 void copyFileAndAppendDbFile(const string& srcFile, string& dstFile, 
00532                              const Array<string>& dbFilesArr,
00533                              const Array<string>& constFilesArr)
00534 {
00535   ofstream out(dstFile.c_str());
00536   ifstream in(srcFile.c_str());
00537   if (!out.good()) { cout<<"ERROR: failed to open "<<dstFile<<endl;exit(-1);}
00538   if (!in.good()) { cout<<"ERROR: failed to open "<<srcFile<<endl;exit(-1);}
00539   
00540   string buffer;
00541   while(getline(in, buffer)) out << buffer << endl;
00542   in.close();
00543 
00544   out << endl;
00545   for (int i = 0; i < constFilesArr.size(); i++) 
00546     out << "#include \"" << constFilesArr[i] << "\"" << endl;
00547   out << endl;
00548   for (int i = 0; i < dbFilesArr.size(); i++)    
00549     out << "#include \"" << dbFilesArr[i] << "\"" << endl;
00550   out.close();
00551 }
00552 
00553 
00554 bool checkQueryPredsNotInClosedWorldPreds(const StringHashArray& qpredNames,
00555                                           const StringHashArray& cwPredNames)
00556 {
00557   bool ok = true;
00558   for (int i = 0; i < qpredNames.size(); i++)
00559     if (cwPredNames.contains(qpredNames[i]))
00560     {
00561       cout << "ERROR: query predicate " << qpredNames[i] 
00562            << " cannot be specified as closed world" << endl; 
00563       ok = false;
00564     }
00565   return ok;
00566 }
00567 
00596 bool createQueryFilePreds(const string& queryFile,
00597                           const Domain* const & domain,
00598                           Database* const & db,
00599                           GroundPredicateHashArray* const &queries,
00600                           GroundPredicateHashArray* const &knownQueries,
00601                           bool printToFile, ostream& out, bool amapPos,
00602                           const GroundPredicateHashArray* const &trueQueries,
00603                           const Array<double>* const & trueProbs)
00604 {
00605   if (queryFile.length() == 0) return true;
00606 
00607   bool ret = true;
00608   ifstream in(queryFile.c_str());
00609   unsigned int line = 0;
00610   unsigned int cur;
00611   int constId, predId;
00612   bool ok;
00613   string predStr, predName, constant;
00614   Array<int> constIds;
00615   const PredicateTemplate* ptemplate;
00616 
00617   while (getline(in, predStr))
00618   {
00619     line++;
00620     cur = 0;
00621 
00622       // get predicate name
00623     ok = Util::substr(predStr, cur, predName, "(");
00624     if (!ok) continue;
00625 
00626     predId = domain->getPredicateId(predName.c_str());
00627     ptemplate = domain->getPredicateTemplate(predId);
00628 
00629     if (predId < 0 || ptemplate == NULL)
00630     {
00631       cout << "ERROR: Cannot find " << predName << " in domain on line " 
00632            << line << " of query file " << queryFile << endl;
00633       ret = false;
00634       continue;
00635     }
00636 
00637       // get constant name
00638     constIds.clear();
00639     while (Util::substr(predStr, cur, constant, ","))
00640     {
00641       constId = domain->getConstantId(constant.c_str());
00642       constIds.append(constId);
00643       if (constId < 0)
00644       {
00645         cout << "ERROR: Cannot find " << constant << " in database on line " 
00646              << line << " of query file " << queryFile << endl;
00647         ret = false;
00648       }
00649     }
00650 
00651       // get constant name
00652     ok = Util::substr(predStr, cur, constant, ")"); 
00653     if (!ok)
00654     {
00655       cout << "ERROR: Failed to parse ground predicate on line " << line
00656            << " of query file " << queryFile << endl;
00657       ret = false;
00658       continue;
00659     }
00660 
00661     constId = domain->getConstantId(constant.c_str());
00662     constIds.append(constId);
00663     if (constId < 0)
00664     {
00665       cout << "ERROR: Cannot find " << constant << " in database on line " 
00666            << line << " of query file " << queryFile << endl;
00667       ret = false;
00668     }
00669 
00670     if (!ret) continue;
00672     
00673       // create Predicate and set its truth value to UNKNOWN
00674     if (constIds.size() != ptemplate->getNumTerms())
00675     {
00676       cout << "ERROR: incorrect number of terms for " << predName 
00677            << ". Expected " << ptemplate->getNumTerms() << ", given " 
00678            << constIds.size() << endl;
00679       ret = false;
00680       continue;
00681     }
00682     
00683     Predicate pred(ptemplate);
00684     for (int i = 0; i < constIds.size(); i++)
00685     {
00686       if (pred.getTermTypeAsInt(i) != domain->getConstantTypeId(constIds[i]))
00687       {
00688         cout << "ERROR: wrong type for term " 
00689              << domain->getConstantName(constIds[i]) << " for predicate " 
00690              << predName  << " on line " << line << " of query file " 
00691              << queryFile << endl;
00692         ret = false;
00693         continue;
00694       }
00695       pred.appendTerm(new Term(constIds[i], (void*)&pred, true));
00696     }
00697     if (!ret) continue;
00698 
00699     assert(!db->isClosedWorld(predId));
00700 
00701     TruthValue tv = db->getValue(&pred);
00702     GroundPredicate* gndPred = new GroundPredicate(&pred);
00703     
00704     // If just printing to file, then all values must be known
00705     if (printToFile) assert(tv != UNKNOWN);
00706     if (tv == UNKNOWN)
00707     {
00708       if (queries->append(gndPred) < 0) delete gndPred;
00709     }
00710     else
00711     {
00712         // If just printing to file
00713       if (printToFile)
00714       {
00715 
00716           // If trueQueries is given as argument, then get prob. from there
00717         if (trueQueries)
00718         {
00719           double prob = 0.0;
00720           if (domain->getDB()->getEvidenceStatus(&pred))
00721           {
00722               // Don't print out evidence atoms
00723             continue;
00724             //prob = (tv == TRUE) ? 1.0 : 0.0;
00725           }
00726           else
00727           {
00728             int found = trueQueries->find(gndPred);
00729             if (found >= 0) prob = (*trueProbs)[found];
00730             else
00731                 // Uniform smoothing
00732               prob = (prob*10000+1/2.0)/(10000+1.0);            
00733           }
00734           gndPred->print(out, domain); out << " " << prob << endl;
00735         }
00736         else
00737         {
00738           if (amapPos) //if show postive ground query predicates only
00739           {
00740             if (tv == TRUE)
00741             {
00742               pred.printWithStrVar(out, domain);
00743               out << endl;
00744             }
00745           }
00746           else //print all ground query predicates
00747           {
00748             pred.printWithStrVar(out, domain);
00749             out << " " << tv << endl;
00750           }
00751         }
00752         delete gndPred;
00753       }
00754       else // Building queries
00755       {
00756         //if (tv == TRUE) gndPred->setProbTrue(1);
00757         //else            gndPred->setProbTrue(0);
00758         if (knownQueries->append(gndPred) < 0) delete gndPred;
00759       }
00760     }
00761   } // while (getline(in, predStr))
00762 
00763   in.close();
00764   return ret;
00765 }
00766 
00771 bool createQueryFilePreds(const string& queryFile, const Domain* const & domain,
00772                           Database* const & db,
00773                           GroundPredicateHashArray* const &queries,
00774                           GroundPredicateHashArray* const &knownQueries)
00775 {
00776   return createQueryFilePreds(queryFile, domain, db, queries, knownQueries,
00777                               false, cout, false, NULL, NULL);
00778 }
00779 
00780 void readPredValuesAndSetToUnknown(const StringHashArray& predNames,
00781                                    Domain *domain,
00782                                    Array<Predicate *> &queryPreds,
00783                                    Array<TruthValue> &queryPredValues,
00784                                    bool isQueryEvidence)
00785 {
00786   Array<Predicate*> ppreds;
00787 
00788     //cout << endl << "Getting the counts for the domain " << i << endl;
00789   queryPreds.clear();
00790   queryPredValues.clear();
00791   for (int predno = 0; predno < predNames.size(); predno++) 
00792   {
00793     ppreds.clear();
00794     int predid = domain->getPredicateId(predNames[predno].c_str());
00795     Predicate::createAllGroundings(predid, domain, ppreds);
00796     for (int gpredno = 0; gpredno < ppreds.size(); gpredno++)
00797     {
00798       Predicate *pred = ppreds[gpredno];
00799       TruthValue tv = domain->getDB()->getValue(pred);
00800       if (tv == UNKNOWN)
00801         domain->getDB()->setValue(pred,FALSE);
00802           
00803         // if first order query pred groundings are allowed to be evidence
00804         // - we assume all the predicates not in db to be false
00805         // evidence - need a better way code this.
00806       if (isQueryEvidence && tv == UNKNOWN)
00807         delete pred;
00808       else
00809         queryPreds.append(pred);
00810     }
00811   }
00812     //set all the query preds to unknown, reading in the TRUE/FALSE status
00813     //for verification at a later time
00814   domain->getDB()->setValuesToUnknown(&queryPreds, &queryPredValues);
00815 }
00816 
00827 void setPredsToGivenValues(const StringHashArray& predNames, Domain *domain,
00828                            Array<TruthValue> &gpredValues)
00829 {
00830   Array<Predicate*> gpreds;
00831   Array<Predicate*> ppreds;
00832   Array<TruthValue> tmpValues;
00833     
00834     //cout << endl << "Getting the counts for the domain " << i << endl;
00835   gpreds.clear();
00836   tmpValues.clear();
00837   for (int predno = 0; predno < predNames.size(); predno++)
00838   {
00839     ppreds.clear();
00840     int predid = domain->getPredicateId(predNames[predno].c_str());
00841     Predicate::createAllGroundings(predid, domain, ppreds);
00842       //cout<<"size of gnd for pred " << predid << " = "<<ppreds.size()<<endl;
00843     gpreds.append(ppreds);
00844   }
00845   
00846   domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
00847   for (int gpredno = 0; gpredno < gpreds.size(); gpredno++)
00848     delete gpreds[gpredno];
00849 }
00850 
00851 
00858 int buildInference(Inference*& inference, Domain*& domain)
00859 {
00860   string inMLNFile, wkMLNFile, evidenceFile;
00861 
00862   StringHashArray queryPredNames;
00863   StringHashArray owPredNames;
00864   StringHashArray cwPredNames;
00865   MLN* mln = NULL;
00866   Array<string> constFilesArr;
00867   Array<string> evidenceFilesArr;
00868 
00869   Array<Predicate *> queryPreds;
00870   Array<TruthValue> queryPredValues;
00871   
00872   //the second .mln file to the last one in ainMLNFiles _may_ be used 
00873   //to hold constants, so they are held in constFilesArr. They will be
00874   //included into the first .mln file.
00875 
00876     //extract .mln, .db file names
00877   extractFileNames(ainMLNFiles, constFilesArr);
00878   assert(constFilesArr.size() >= 1);
00879   inMLNFile.append(constFilesArr[0]);
00880   constFilesArr.removeItem(0);
00881   extractFileNames(aevidenceFiles, evidenceFilesArr);
00882   
00883   if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
00884   if (aqueryFile) queryFile.append(aqueryFile);
00885 
00886   if (queryPredsStr.length() == 0 && queryFile.length() == 0)
00887   { cout << "No query predicates specified" << endl; return -1; }
00888 
00889   if (agibbsInfer && amcmcNumChains < 2) 
00890   {
00891     cout << "ERROR: there must be at least 2 MCMC chains in Gibbs sampling" 
00892          << endl; return -1;
00893   }
00894 
00895   if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
00896   {
00897     cout << "ERROR: must specify one of -ms/-simtp/-m/-a/-p flags." << endl;
00898     return -1;
00899   }
00900 
00901     //extract names of all query predicates
00902   if (queryPredsStr.length() > 0 || queryFile.length() > 0)
00903   {
00904     if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
00905   }
00906 
00907   if (amwsMaxSteps <= 0)
00908   { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00909 
00910   if (amwsTries <= 0)
00911   { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
00912 
00913     //extract names of open-world evidence predicates
00914   if (aOpenWorldPredsStr)
00915   {
00916     if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
00917       return -1;
00918     assert(owPredNames.size() > 0);
00919   }
00920 
00921     //extract names of closed-world non-evidence predicates
00922   if (aClosedWorldPredsStr)
00923   {
00924     if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
00925       return -1;
00926     assert(cwPredNames.size() > 0);
00927     if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
00928       return -1;
00929   }
00930 
00931   // TODO: Check if query atom in -o -> error
00932 
00933   // TODO: Check if atom in -c and -o -> error
00934 
00935 
00936   // TODO: Check if ev. atom in -c or
00937   // non-evidence in -o -> warning (this is default)
00938 
00939 
00940     // Set SampleSat parameters
00941   SampleSatParams* ssparams = new SampleSatParams;
00942   ssparams->lateSa = assLateSa;
00943   ssparams->saRatio = assSaRatio;
00944   ssparams->saTemp = assSaTemp;
00945 
00946     // Set MaxWalksat parameters
00947   MaxWalksatParams* mwsparams = new MaxWalksatParams;
00948   mwsparams->ssParams = ssparams;
00949   mwsparams->maxSteps = amwsMaxSteps;
00950   mwsparams->maxTries = amwsTries;
00951   mwsparams->targetCost = amwsTargetWt;
00952   mwsparams->hard = amwsHard;
00953     // numSolutions only applies when used in SampleSat.
00954     // When just MWS, this is set to 1
00955   mwsparams->numSolutions = amwsNumSolutions;
00956   mwsparams->heuristic = amwsHeuristic;
00957   mwsparams->tabuLength = amwsTabuLength;
00958   mwsparams->lazyLowState = amwsLazyLowState;
00959 
00960     // Set MC-SAT parameters
00961   MCSatParams* msparams = new MCSatParams;
00962   msparams->mwsParams = mwsparams;
00963     // MC-SAT needs only one chain
00964   msparams->numChains          = 1;
00965   msparams->burnMinSteps       = amcmcBurnMinSteps;
00966   msparams->burnMaxSteps       = amcmcBurnMaxSteps;
00967   msparams->minSteps           = amcmcMinSteps;
00968   msparams->maxSteps           = amcmcMaxSteps;
00969   msparams->maxSeconds         = amcmcMaxSeconds;
00970   msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00971 
00972     // Set Gibbs parameters
00973   GibbsParams* gibbsparams = new GibbsParams;
00974   gibbsparams->mwsParams    = mwsparams;
00975   gibbsparams->numChains    = amcmcNumChains;
00976   gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00977   gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00978   gibbsparams->minSteps     = amcmcMinSteps;
00979   gibbsparams->maxSteps     = amcmcMaxSteps;
00980   gibbsparams->maxSeconds   = amcmcMaxSeconds;
00981 
00982   gibbsparams->gamma          = 1 - agibbsDelta;
00983   gibbsparams->epsilonError   = agibbsEpsilonError;
00984   gibbsparams->fracConverged  = agibbsFracConverged;
00985   gibbsparams->walksatType    = agibbsWalksatType;
00986   gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00987   
00988     // Set Sim. Tempering parameters
00989   SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00990   stparams->mwsParams    = mwsparams;
00991   stparams->numChains    = amcmcNumChains;
00992   stparams->burnMinSteps = amcmcBurnMinSteps;
00993   stparams->burnMaxSteps = amcmcBurnMaxSteps;
00994   stparams->minSteps     = amcmcMinSteps;
00995   stparams->maxSteps     = amcmcMaxSteps;
00996   stparams->maxSeconds   = amcmcMaxSeconds;
00997 
00998   stparams->subInterval = asimtpSubInterval;
00999   stparams->numST       = asimtpNumST;
01000   stparams->numSwap     = asimtpNumSwap;
01001 
01003 
01004   cout << "Reading formulas and evidence predicates..." << endl;
01005 
01006     // Copy inMLNFile to workingMLNFile & app '#include "evid.db"'
01007   string::size_type bslash = inMLNFile.rfind("/");
01008   string tmp = (bslash == string::npos) ? 
01009                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
01010   char buf[100];
01011   sprintf(buf, "%s%s", tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
01012   wkMLNFile = buf;
01013   copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
01014                           evidenceFilesArr, constFilesArr);
01015 
01016     // Parse wkMLNFile, and create the domain, MLN, database
01017   domain = new Domain;
01018   mln = new MLN();
01019   bool addUnitClauses = false;
01020   bool mustHaveWtOrFullStop = true;
01021   bool warnAboutDupGndPreds = true;
01022   bool flipWtsOfFlippedClause = true;
01023   //bool allPredsExceptQueriesAreCW = true;
01024   bool allPredsExceptQueriesAreCW = owPredNames.empty();
01025   Domain* forCheckingPlusTypes = NULL;
01026 
01027     // Parse as if lazy inference is set to true to set evidence atoms in DB
01028     // If lazy is not used, this is removed from DB
01029   if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW, 
01030                    &owPredNames, &queryPredNames, addUnitClauses, 
01031                    warnAboutDupGndPreds, 0, mustHaveWtOrFullStop, 
01032                    forCheckingPlusTypes, true, flipWtsOfFlippedClause))
01033   {
01034     unlink(wkMLNFile.c_str());
01035     return -1;
01036   }
01037 
01038   unlink(wkMLNFile.c_str());
01039 
01041 
01043   Array<int> allPredGndingsAreQueries;
01044 
01045     // Eager inference: Build the queries for the mrf
01046     // Lazy version evaluates the query string / file when printing out
01047   if (!aLazy)
01048   {
01049     if (queryFile.length() > 0)
01050     {
01051       cout << "Reading query predicates that are specified in query file..."
01052            << endl;
01053       bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
01054                                      &queries, &knownQueries);
01055       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01056     }
01057 
01058     allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
01059     if (queryPredsStr.length() > 0)
01060     {
01061       cout << "Creating query predicates that are specified on command line..." 
01062            << endl;
01063       bool ok = createComLineQueryPreds(queryPredsStr, domain, domain->getDB(), 
01064                                         &queries, &knownQueries, 
01065                                         &allPredGndingsAreQueries);
01066       if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01067     }
01068   }
01069 
01070     // Create inference algorithm and state based on queries and mln / domain
01071   bool markHardGndClauses = false;
01072   bool trackParentClauseWts = false;
01073     // Lazy version: queries and allPredGndingsAreQueries are empty,
01074     // markHardGndClause and trackParentClauseWts are not used
01075   VariableState* state = new VariableState(&queries, NULL, NULL,
01076                                            &allPredGndingsAreQueries,
01077                                            markHardGndClauses,
01078                                            trackParentClauseWts,
01079                                            mln, domain, aLazy);
01080   bool trackClauseTrueCnts = false;
01081     // MAP inference, MC-SAT, Gibbs or Sim. Temp.
01082   if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer)
01083   {
01084     if (amapPos || amapAll)
01085     { // MaxWalkSat
01086         // When standalone MWS, numSolutions is always 1
01087         // (maybe there is a better way to do this?)
01088       mwsparams->numSolutions = 1;
01089       inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
01090     }
01091     else if (amcsatInfer)
01092     { // MC-SAT
01093       inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
01094     }
01095     else if (asimtpInfer)
01096     { // Simulated Tempering
01097         // When MWS is used in Sim. Temp., numSolutions is always 1
01098         // (maybe there is a better way to do this?)
01099       mwsparams->numSolutions = 1;
01100       inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
01101                                          stparams);
01102     }
01103     else if (agibbsInfer)
01104     { // Gibbs sampling
01105         // When MWS is used in Gibbs, numSolutions is always 1
01106         // (maybe there is a better way to do this?)
01107       mwsparams->numSolutions = 1;
01108       inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
01109                                    gibbsparams);
01110     }
01111   }
01112   return 1;
01113 }
01114  
01115   // Typedefs. StringToStrArrayMap: hash_map keyed by string (hashed with HashString, compared with EqualString) whose values are pointers to const Array<const char*> (arrays of C strings); ownership of the pointed-to arrays is not established here -- confirm with users of this map.
01116 typedef hash_map<string, const Array<const char*>*, HashString, EqualString> 
01117 StringToStrArrayMap;
01118 
01119 #endif

Generated on Wed Feb 14 15:15:17 2007 for Alchemy by doxygen 1.5.1