00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef _INFER_H_OCT_30_2005
00068 #define _INFER_H_OCT_30_2005
00069
00074 #include "util.h"
00075 #include "fol.h"
00076 #include "mrf.h"
00077 #include "learnwts.h"
00078 #include "inferenceargs.h"
00079 #include "maxwalksat.h"
00080 #include "mcsat.h"
00081 #include "gibbssampler.h"
00082 #include "simulatedtempering.h"
00083 #include "bp.h"
00084 #include "variablestate.h"
00085 #include "hvariablestate.h"
00086 #include "hmcsat.h"
00087 #include "lbfgsp.h"
00088
00089
00090
// Global inference settings with their default values. Several of them are
// read by buildInference() further down in this header.
// NOTE(review): the 'a' prefix suggests these mirror command-line arguments;
// confirm against the argument table that fills them in.

// Evidence file list; buildInference() splits it with extractFileNames().
00091 char* aevidenceFiles = NULL;
// Not referenced in this header. NOTE(review): presumably the results output
// path -- confirm against the caller.
00092 char* aresultsFile = NULL;
// Query predicates as a string; buildInference() appends it to queryPredsStr.
00093 char* aqueryPredsStr = NULL;
// Query file name; buildInference() appends it to queryFile.
00094 char* aqueryFile = NULL;
00095
// File names used with the hybrid state in buildInference(): written when
// aGenRandom is set, loaded when aStartPt is set (atestcont for the continuous
// part, atestdis for the discrete part).
00096 char* atestcont = NULL;
00097 char* atestdis = NULL;
// Not referenced in this header.
00098 char* aHMWSDis = NULL;
00099 char* aMWSrst = NULL;
00100
// Passed together to HMaxWalkSat::SetNoisePra() in buildInference().
// NOTE(review): looks like a noise probability expressed as a fraction --
// confirm.
00101 int anumerator = 50;
00102 int adenominator = 100;
00103
// Time limit as text; converted with atof() and passed to SetMaxSeconds().
00104 char* aMaxSeconds = NULL;
// With aHybrid: generate a random state, write it out and return early.
00105 bool aGenRandom = false;
// Load initial values from atestdis / atestcont before inference is built.
00106 bool aStartPt = false;
00107
// Not referenced in this header.
00108 int aLineNum = -1;
00109 char* aLinePara = NULL;
00110 char* aLineName = NULL;
00111
// Passed to HMaxWalkSat::SetSAInterval() in the hybrid simulated-annealing
// branch of buildInference().
00112 int saInterval = 100;
00113
// Passed to HVariableState::WriteGndPredIdxMap() in buildInference().
00114 char* aGndPredIdxMapFile = NULL;
// Passed to HMCSAT::SetPrintVarsPerSample() in buildInference().
00115 bool aPrintSamplePerIteration = false;
00116
// String copies of aqueryPredsStr / aqueryFile, filled in by buildInference().
00117 string queryPredsStr, queryFile;
// Query ground predicates collected by buildInference(), and the query
// predicates whose truth value is already known in the database.
00118 GroundPredicateHashArray queries;
00119 GroundPredicateHashArray knownQueries;
00120
00149 bool createComLineQueryPreds(const string& queryPredsStr,
00150 const Domain* const & domain,
00151 Database* const & db,
00152 GroundPredicateHashArray* const & queries,
00153 GroundPredicateHashArray* const & knownQueries,
00154 Array<int>* const & allPredGndingsAreQueries,
00155 bool printToFile, ostream& out, bool amapPos,
00156 const GroundPredicateHashArray* const & trueQueries,
00157 const Array<double>* const & trueProbs,
00158 Array<Array<Predicate* >* >* queryConjs)
00159 {
00160 if (queryPredsStr.length() == 0) return true;
00161 string predConjs = Util::trim(queryPredsStr);
00162
00163
00164 int balparen = 0;
00165 for (unsigned int i = 0; i < predConjs.length(); i++)
00166 {
00167 if (predConjs.at(i)=='(') balparen++;
00168 else if (predConjs.at(i)==')') balparen--;
00169 else if ((predConjs.at(i)==';' || predConjs.at(i)==',') &&
00170 balparen==0) predConjs.at(i) = '\n';
00171 }
00172
00173 bool ret = true;
00174 string predConj;
00175 istringstream iss(predConjs);
00176 char delimit[2]; delimit[1] = '\0';
00177
00178
00179 while (getline(iss, predConj))
00180 {
00181
00182 int balparen = 0;
00183 for (unsigned int i = 0; i < predConj.length(); i++)
00184 {
00185 if (predConj.at(i)=='(') balparen++;
00186 else if (predConj.at(i)==')') balparen--;
00187 else if (predConj.at(i)=='^' && balparen==0) predConj.at(i) = '\n';
00188 }
00189 bool onlyPredName;
00190 unsigned int cur;
00191 int termId, varIdCnt = 0;
00192 hash_map<string, int, HashString, EqualString> varToId;
00193 hash_map<string, int, HashString, EqualString>::iterator it;
00194 Array<VarsTypeId*>* vtiArr;
00195 string pred, predName, term;
00196 istringstream iss2(predConj);
00197 const PredicateTemplate* ptemplate;
00198 int predicate = 0;
00199
00200 Array<Predicate* >* predArray = new Array<Predicate*>;
00201 if (queryConjs) queryConjs->append(predArray);
00202
00203 while (getline(iss2, pred))
00204 {
00205 pred = Util::trim(pred);
00206 onlyPredName = false;
00207 varToId.clear();
00208 varIdCnt = 0;
00209 cur = 0;
00210 bool negated = false;
00211
00212
00213 if (pred.at(0) == '!')
00214 {
00215 negated = true;
00216 pred.at(0) = ' ';
00217 pred = Util::trim(pred);
00218 }
00219
00220 if (!Util::substr(pred, cur, predName, "("))
00221 {
00222 predName = pred;
00223 onlyPredName = true;
00224 }
00225
00226
00227 ptemplate = domain->getPredicateTemplate(predName.c_str());
00228 if (ptemplate == NULL)
00229 {
00230 cout << "ERROR: Cannot find command line query predicate" << predName
00231 << " in domain." << endl;
00232 ret = false;
00233 continue;
00234 }
00235 Predicate ppred(ptemplate);
00236
00237
00238 if (!onlyPredName)
00239 {
00240
00241 for (int i = 0; i < 2; i++)
00242 {
00243 if (i == 0) delimit[0] = ',';
00244 else delimit[0] = ')';
00245 while (Util::substr(pred, cur, term, delimit))
00246 {
00247
00248 if (isupper(term.at(0)) || term.at(0) == '"' || isdigit(term.at(0)))
00249 {
00250 termId = domain->getConstantId(term.c_str());
00251 if (termId < 0)
00252 {
00253 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00254 ret = false;
00255 }
00256 }
00257 else
00258 {
00259 if ((it=varToId.find(term)) == varToId.end())
00260 {
00261 termId = --varIdCnt;
00262 varToId[term] = varIdCnt;
00263 }
00264 else
00265 termId = (*it).second;
00266 }
00267 ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00268 }
00269 }
00270 }
00271 else
00272 {
00273
00274
00275 for (int i = 0; i < ptemplate->getNumTerms(); i++)
00276 ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00277 }
00278
00279
00280 if (ppred.getNumTerms() != ptemplate->getNumTerms())
00281 {
00282 cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00283 << " terms but given " << ppred.getNumTerms() << endl;
00284 ret = false;
00285 }
00286 if (!ret) continue;
00287
00289 vtiArr = NULL;
00290 ppred.createVarsTypeIdArr(vtiArr);
00291
00292
00293 if (vtiArr->size() <= 1)
00294 {
00295 assert(ppred.isGrounded());
00296 assert(!db->isClosedWorld(ppred.getId()));
00297 TruthValue tv = db->getValue(&ppred);
00298
00299 if (negated) ppred.setSense(false);
00300 predArray->append(new Predicate(ppred));
00301
00302 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00303
00304 if (printToFile) assert(tv != UNKNOWN);
00305 if (tv == UNKNOWN)
00306 {
00307 if (queries->append(gndPred) < 0) delete gndPred;
00308 }
00309 else
00310 {
00311
00312 if (printToFile)
00313 {
00314
00315 if (trueQueries)
00316 {
00317 double prob = 0.0;
00318 if (domain->getDB()->getEvidenceStatus(&ppred))
00319 {
00320
00321 continue;
00322
00323 }
00324 else
00325 {
00326 int found = trueQueries->find(gndPred);
00327 if (found >= 0) prob = (*trueProbs)[found];
00328 else
00329
00330 prob = (prob*10000+1/2.0)/(10000+1.0);
00331
00332 }
00333 gndPred->print(out, domain); out << " " << prob << endl;
00334 }
00335 else
00336 {
00337 if (amapPos)
00338 {
00339 if (tv == TRUE)
00340 {
00341 ppred.printWithStrVar(out, domain);
00342 out << endl;
00343 }
00344 }
00345 else
00346 {
00347 ppred.printWithStrVar(out, domain);
00348 out << " " << tv << endl;
00349 }
00350 }
00351 delete gndPred;
00352 }
00353 else
00354 {
00355
00356
00357 if (knownQueries->append(gndPred) < 0) delete gndPred;
00358 }
00359 }
00360 }
00361 else
00362 {
00363 ArraysAccessor<int> acc;
00364 for (int i = 1; i < vtiArr->size(); i++)
00365 {
00366 const Array<int>* cons =
00367 domain->getConstantsByType((*vtiArr)[i]->typeId);
00368 acc.appendArray(cons);
00369 }
00370
00371
00372 Array<int> constIds;
00373 while (acc.getNextCombination(constIds))
00374 {
00375 assert(constIds.size() == vtiArr->size()-1);
00376 for (int j = 0; j < constIds.size(); j++)
00377 {
00378 Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00379 for (int k = 0; k < terms.size(); k++)
00380 terms[k]->setId(constIds[j]);
00381 }
00382
00383
00384 assert(!db->isClosedWorld(ppred.getId()));
00385
00386 TruthValue tv = db->getValue(&ppred);
00387 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00388
00389
00390 if (printToFile) assert(tv != UNKNOWN);
00391 if (tv == UNKNOWN)
00392 {
00393 if (queries->append(gndPred) < 0) delete gndPred;
00394 }
00395 else
00396 {
00397
00398 if (printToFile)
00399 {
00400
00401 if (trueQueries)
00402 {
00403 double prob = 0.0;
00404 if (domain->getDB()->getEvidenceStatus(&ppred))
00405 {
00406
00407 continue;
00408
00409 }
00410 else
00411 {
00412 int found = trueQueries->find(gndPred);
00413 if (found >= 0) prob = (*trueProbs)[found];
00414 else
00415
00416 prob = (prob*10000+1/2.0)/(10000+1.0);
00417 }
00418
00419
00420 gndPred->print(out, domain); out << " " << prob << endl;
00421 }
00422 else
00423 {
00424 if (amapPos)
00425 {
00426 if (tv == TRUE)
00427 {
00428 ppred.printWithStrVar(out, domain);
00429 out << endl;
00430 }
00431 }
00432 else
00433 {
00434 ppred.printWithStrVar(out, domain);
00435 out << " " << tv << endl;
00436 }
00437 }
00438 delete gndPred;
00439 }
00440 else
00441 {
00442
00443
00444 if (knownQueries->append(gndPred) < 0) delete gndPred;
00445 }
00446 }
00447 }
00448 }
00449
00450 ppred.deleteVarsTypeIdArr(vtiArr);
00451 predicate++;
00452 }
00453 if (predArray->size() == 0)
00454 {
00455 if (queryConjs) queryConjs->removeLastItem();
00456 delete predArray;
00457 }
00458 }
00459
00460 if (!printToFile)
00461 {
00462 queries->compress();
00463 knownQueries->compress();
00464 }
00465
00466 return ret;
00467 }
00468
00473 bool createComLineQueryPreds(const string& queryPredsStr,
00474 const Domain* const & domain,
00475 Database* const & db,
00476 GroundPredicateHashArray* const & queries,
00477 GroundPredicateHashArray* const & knownQueries,
00478 Array<int>* const & allPredGndingsAreQueries,
00479 Array<Array<Predicate* >* >* queryConjs)
00480 {
00481 return createComLineQueryPreds(queryPredsStr, domain, db,
00482 queries, knownQueries,
00483 allPredGndingsAreQueries,
00484 false, cout, false, NULL, NULL, queryConjs);
00485 }
00486
00497 bool extractPredNames(string preds, const string* queryFile,
00498 StringHashArray& predNames)
00499 {
00500 predNames.clear();
00501
00502
00503 string::size_type cur = 0, ws, ltparen;
00504 string qpred, predName;
00505
00506 if (preds.length() > 0)
00507 {
00508 preds.append(" ");
00509
00510
00511 int balparen = 0;
00512 for (unsigned int i = 0; i < preds.length(); i++)
00513 {
00514 if (preds.at(i) == '(') balparen++;
00515 else if (preds.at(i) == ')') balparen--;
00516 else if ((preds.at(i) == ',' || preds.at(i) == ';') &&
00517 balparen == 0) preds.at(i) = ' ';
00518 }
00519
00520 while (preds.at(cur) == ' ') cur++;
00521 while (true)
00522 {
00523 ws = preds.find(" ", cur);
00524 if (ws == string::npos) break;
00525 qpred = preds.substr(cur,ws-cur+1);
00526 cur = ws+1;
00527 while (cur < preds.length() &&
00528 (preds.at(cur) == ' ' || preds.at(cur) == '^' ||
00529 preds.at(cur) == '!')) cur++;
00530 ltparen = qpred.find("(",0);
00531
00532 if (ltparen == string::npos)
00533 {
00534 ws = qpred.find(" ");
00535 if (ws != string::npos) qpred = qpred.substr(0,ws);
00536 predName = qpred;
00537 }
00538 else
00539 predName = qpred.substr(0,ltparen);
00540
00541 predNames.append(predName);
00542 }
00543 }
00544
00545 if (queryFile == NULL || queryFile->length() == 0) return true;
00546
00547
00548 ifstream in((*queryFile).c_str());
00549 if (!in.good())
00550 {
00551 cout << "ERROR: unable to open " << *queryFile << endl;
00552 return false;
00553 }
00554 string buffer;
00555 while (getline(in, buffer))
00556 {
00557 cur = 0;
00558 while (cur < buffer.length() &&
00559 (buffer.at(cur) == ' ' || buffer.at(cur) == '^' ||
00560 buffer.at(cur) == '!')) cur++;
00561 ltparen = buffer.find("(", cur);
00562 if (ltparen == string::npos) continue;
00563 predName = buffer.substr(cur, ltparen-cur);
00564 predNames.append(predName);
00565 }
00566
00567 in.close();
00568 return true;
00569 }
00570
00575 char getTruthValueFirstChar(const TruthValue& tv)
00576 {
00577 if (tv == TRUE) return 'T';
00578 if (tv == FALSE) return 'F';
00579 if (tv == UNKNOWN) return 'U';
00580 assert(false);
00581 exit(-1);
00582 return '#';
00583 }
00584
/**
 * Seeds rand() from the current time of day: the low seven bits of the
 * seconds count scaled to microseconds, plus the microseconds part.
 */
void setsrand()
{
  struct timeval tv;
    // POSIX leaves gettimeofday() unspecified for a non-null timezone
    // argument, and the timezone was never used here, so pass NULL
  gettimeofday(&tv, NULL);
  unsigned int seed = (( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec;
  srand(seed);
}
00596
00597
00598
00599 void copyFileAndAppendDbFile(const string& srcFile, string& dstFile,
00600 const Array<string>& dbFilesArr,
00601 const Array<string>& constFilesArr)
00602 {
00603 ofstream out(dstFile.c_str());
00604 ifstream in(srcFile.c_str());
00605 if (!out.good()) { cout<<"ERROR: failed to open "<<dstFile<<endl;exit(-1);}
00606 if (!in.good()) { cout<<"ERROR: failed to open "<<srcFile<<endl;exit(-1);}
00607
00608 string buffer;
00609 while(getline(in, buffer)) out << buffer << endl;
00610 in.close();
00611
00612 out << endl;
00613 for (int i = 0; i < constFilesArr.size(); i++)
00614 out << "#include \"" << constFilesArr[i] << "\"" << endl;
00615 out << endl;
00616 for (int i = 0; i < dbFilesArr.size(); i++)
00617 out << "#include \"" << dbFilesArr[i] << "\"" << endl;
00618 out.close();
00619 }
00620
00621
00622 bool checkQueryPredsNotInClosedWorldPreds(const StringHashArray& qpredNames,
00623 const StringHashArray& cwPredNames)
00624 {
00625 bool ok = true;
00626 for (int i = 0; i < qpredNames.size(); i++)
00627 if (cwPredNames.contains(qpredNames[i]))
00628 {
00629 cout << "ERROR: query predicate " << qpredNames[i]
00630 << " cannot be specified as closed world" << endl;
00631 ok = false;
00632 }
00633 return ok;
00634 }
00635
00664 bool createQueryFilePreds(const string& queryFile,
00665 const Domain* const & domain,
00666 Database* const & db,
00667 GroundPredicateHashArray* const &queries,
00668 GroundPredicateHashArray* const &knownQueries,
00669 bool printToFile, ostream& out, bool amapPos,
00670 const GroundPredicateHashArray* const &trueQueries,
00671 const Array<double>* const & trueProbs,
00672 Array<Array<Predicate* >* >* queryConjs)
00673 {
00674 if (queryFile.length() == 0) return true;
00675
00676 bool ret = true;
00677 string predConj;
00678 ifstream in(queryFile.c_str());
00679 char delimit[2]; delimit[1] = '\0';
00680
00681
00682 while (getline(in, predConj))
00683 {
00684
00685 int balparen = 0;
00686 for (unsigned int i = 0; i < predConj.length(); i++)
00687 {
00688 if (predConj.at(i)=='(') balparen++;
00689 else if (predConj.at(i)==')') balparen--;
00690 else if (predConj.at(i)=='^' && balparen==0) predConj.at(i) = '\n';
00691 }
00692
00693 bool onlyPredName;
00694 unsigned int cur;
00695 int termId, varIdCnt = 0;
00696 hash_map<string, int, HashString, EqualString> varToId;
00697 hash_map<string, int, HashString, EqualString>::iterator it;
00698 Array<VarsTypeId*>* vtiArr;
00699 string pred, predName, term;
00700 istringstream iss2(predConj);
00701 const PredicateTemplate* ptemplate;
00702 int predicate = 0;
00703
00704 Array<Predicate* >* predArray = new Array<Predicate*>;
00705 queryConjs->append(predArray);
00706
00707 while (getline(iss2, pred))
00708 {
00709 pred = Util::trim(pred);
00710 onlyPredName = false;
00711 varToId.clear();
00712 varIdCnt = 0;
00713 cur = 0;
00714 bool negated = false;
00715
00716
00717 if (pred.at(0) == '!')
00718 {
00719 negated = true;
00720 pred.at(0) = ' ';
00721 pred = Util::trim(pred);
00722 }
00723
00724 if (!Util::substr(pred, cur, predName, "("))
00725 {
00726 predName = pred;
00727 onlyPredName = true;
00728 }
00729
00730
00731 ptemplate = domain->getPredicateTemplate(predName.c_str());
00732 if (ptemplate == NULL)
00733 {
00734 cout << "ERROR: Cannot find command line query predicate" << predName
00735 << " in domain." << endl;
00736 ret = false;
00737 continue;
00738 }
00739 Predicate ppred(ptemplate);
00740
00741
00742 if (!onlyPredName)
00743 {
00744
00745 for (int i = 0; i < 2; i++)
00746 {
00747 if (i == 0) delimit[0] = ',';
00748 else delimit[0] = ')';
00749 while (Util::substr(pred, cur, term, delimit))
00750 {
00751
00752 if (isupper(term.at(0)) || term.at(0) == '"' || isdigit(term.at(0)))
00753 {
00754 termId = domain->getConstantId(term.c_str());
00755 if (termId < 0)
00756 {
00757 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00758 ret = false;
00759 }
00760 }
00761 else
00762 {
00763 if ((it=varToId.find(term)) == varToId.end())
00764 {
00765 termId = --varIdCnt;
00766 varToId[term] = varIdCnt;
00767 }
00768 else
00769 termId = (*it).second;
00770 }
00771 ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00772 }
00773 }
00774 }
00775 else
00776 {
00777
00778
00779 for (int i = 0; i < ptemplate->getNumTerms(); i++)
00780 ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00781 }
00782
00783
00784 if (ppred.getNumTerms() != ptemplate->getNumTerms())
00785 {
00786 cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00787 << " terms but given " << ppred.getNumTerms() << endl;
00788 ret = false;
00789 }
00790 if (!ret) continue;
00791
00793 vtiArr = NULL;
00794 ppred.createVarsTypeIdArr(vtiArr);
00795
00796
00797 if (vtiArr->size() <= 1)
00798 {
00799 assert(ppred.isGrounded());
00800 assert(!db->isClosedWorld(ppred.getId()));
00801 TruthValue tv = db->getValue(&ppred);
00802
00803 if (negated) ppred.setSense(false);
00804 predArray->append(new Predicate(ppred));
00805
00806 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00807
00808 if (printToFile) assert(tv != UNKNOWN);
00809 if (tv == UNKNOWN)
00810 {
00811 if (queries->append(gndPred) < 0) delete gndPred;
00812 }
00813 else
00814 {
00815
00816 if (printToFile)
00817 {
00818
00819 if (trueQueries)
00820 {
00821 double prob = 0.0;
00822 if (domain->getDB()->getEvidenceStatus(&ppred))
00823 {
00824
00825 continue;
00826
00827 }
00828 else
00829 {
00830 int found = trueQueries->find(gndPred);
00831 if (found >= 0) prob = (*trueProbs)[found];
00832 else
00833
00834 prob = (prob*10000+1/2.0)/(10000+1.0);
00835
00836 }
00837 gndPred->print(out, domain); out << " " << prob << endl;
00838 }
00839 else
00840 {
00841 if (amapPos)
00842 {
00843 if (tv == TRUE)
00844 {
00845 ppred.printWithStrVar(out, domain);
00846 out << endl;
00847 }
00848 }
00849 else
00850 {
00851 ppred.printWithStrVar(out, domain);
00852 out << " " << tv << endl;
00853 }
00854 }
00855 delete gndPred;
00856 }
00857 else
00858 {
00859
00860
00861 if (knownQueries->append(gndPred) < 0) delete gndPred;
00862 }
00863 }
00864 }
00865 else
00866 {
00867 ArraysAccessor<int> acc;
00868 for (int i = 1; i < vtiArr->size(); i++)
00869 {
00870 const Array<int>* cons =
00871 domain->getConstantsByType((*vtiArr)[i]->typeId);
00872 acc.appendArray(cons);
00873 }
00874
00875
00876 Array<int> constIds;
00877 while (acc.getNextCombination(constIds))
00878 {
00879 assert(constIds.size() == vtiArr->size()-1);
00880 for (int j = 0; j < constIds.size(); j++)
00881 {
00882 Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00883 for (int k = 0; k < terms.size(); k++)
00884 terms[k]->setId(constIds[j]);
00885 }
00886
00887
00888 assert(!db->isClosedWorld(ppred.getId()));
00889
00890 TruthValue tv = db->getValue(&ppred);
00891 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00892
00893
00894 if (printToFile) assert(tv != UNKNOWN);
00895 if (tv == UNKNOWN)
00896 {
00897 if (queries->append(gndPred) < 0) delete gndPred;
00898 }
00899 else
00900 {
00901
00902 if (printToFile)
00903 {
00904
00905 if (trueQueries)
00906 {
00907 double prob = 0.0;
00908 if (domain->getDB()->getEvidenceStatus(&ppred))
00909 {
00910
00911 continue;
00912
00913 }
00914 else
00915 {
00916 int found = trueQueries->find(gndPred);
00917 if (found >= 0) prob = (*trueProbs)[found];
00918 else
00919
00920 prob = (prob*10000+1/2.0)/(10000+1.0);
00921 }
00922
00923
00924 gndPred->print(out, domain); out << " " << prob << endl;
00925 }
00926 else
00927 {
00928 if (amapPos)
00929 {
00930 if (tv == TRUE)
00931 {
00932 ppred.printWithStrVar(out, domain);
00933 out << endl;
00934 }
00935 }
00936 else
00937 {
00938 ppred.printWithStrVar(out, domain);
00939 out << " " << tv << endl;
00940 }
00941 }
00942 delete gndPred;
00943 }
00944 else
00945 {
00946
00947
00948 if (knownQueries->append(gndPred) < 0) delete gndPred;
00949 }
00950 }
00951 }
00952 }
00953
00954 ppred.deleteVarsTypeIdArr(vtiArr);
00955 predicate++;
00956 }
00957 if (predArray->size() == 0)
00958 {
00959 if (queryConjs) queryConjs->removeLastItem();
00960 delete predArray;
00961 }
00962 }
00963
00964 if (!printToFile)
00965 {
00966 queries->compress();
00967 knownQueries->compress();
00968 }
00969
00970 in.close();
00971 return ret;
00972 }
00973
00978 bool createQueryFilePreds(const string& queryFile, const Domain* const & domain,
00979 Database* const & db,
00980 GroundPredicateHashArray* const &queries,
00981 GroundPredicateHashArray* const &knownQueries,
00982 Array<Array<Predicate* >* >* queryConjs)
00983 {
00984 return createQueryFilePreds(queryFile, domain, db, queries, knownQueries,
00985 false, cout, false, NULL, NULL, queryConjs);
00986 }
00987
00988
00995 int buildInference(Inference*& inference, Domain*& domain,
00996 bool const &aisQueryEvidence, Array<Predicate *> &queryPreds,
00997 Array<TruthValue> &queryPredValues)
00998 {
00999 string inMLNFile, wkMLNFile, evidenceFile;
01000
01001 StringHashArray queryPredNames;
01002 StringHashArray owPredNames;
01003 StringHashArray cwPredNames;
01004 MLN* mln = NULL;
01005 Array<string> constFilesArr;
01006 Array<string> evidenceFilesArr;
01007
01008
01009
01010
01011
01012
01013
01014 extractFileNames(ainMLNFiles, constFilesArr);
01015 assert(constFilesArr.size() >= 1);
01016 inMLNFile.append(constFilesArr[0]);
01017 constFilesArr.removeItem(0);
01018 extractFileNames(aevidenceFiles, evidenceFilesArr);
01019
01020 if (aqueryPredsStr) queryPredsStr.append(aqueryPredsStr);
01021 if (aqueryFile) queryFile.append(aqueryFile);
01022
01023 if (queryPredsStr.length() == 0 && queryFile.length() == 0)
01024 { cout << "No query predicates specified" << endl; return -1; }
01025
01026 if (agibbsInfer && agibbsTestConvergence && amcmcNumChains < 2)
01027 {
01028 cout << "ERROR: If testing for convergence, there must be at least 2 "
01029 << "MCMC chains in Gibbs sampling"
01030 << endl; return -1;
01031 }
01032
01033 if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer &&
01034 !aHybrid && !aSA && !abpInfer && !aoutputNetwork)
01035 {
01036
01037 amcsatInfer = true;
01038 }
01039
01040
01041 if (queryPredsStr.length() > 0 || queryFile.length() > 0)
01042 {
01043 if (!extractPredNames(queryPredsStr, &queryFile, queryPredNames)) return -1;
01044 }
01045
01046 if (amwsMaxSteps <= 0)
01047 { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
01048
01049 if (amwsTries <= 0)
01050 { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
01051
01052
01053 if (aOpenWorldPredsStr)
01054 {
01055 if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
01056 return -1;
01057 assert(owPredNames.size() > 0);
01058 }
01059
01060
01061 if (aClosedWorldPredsStr)
01062 {
01063 if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
01064 return -1;
01065 assert(cwPredNames.size() > 0);
01066 if (!checkQueryPredsNotInClosedWorldPreds(queryPredNames, cwPredNames))
01067 return -1;
01068 }
01069
01070
01071
01072
01073
01074
01075
01076
01077
01078
01079
01080 SampleSatParams* ssparams = new SampleSatParams;
01081 ssparams->lateSa = assLateSa;
01082 ssparams->saRatio = assSaRatio;
01083 ssparams->saTemp = assSaTemp;
01084
01085
01086 MaxWalksatParams* mwsparams = new MaxWalksatParams;
01087 mwsparams->ssParams = ssparams;
01088 mwsparams->maxSteps = amwsMaxSteps;
01089 mwsparams->maxTries = amwsTries;
01090 mwsparams->targetCost = amwsTargetWt;
01091 mwsparams->hard = amwsHard;
01092
01093
01094 mwsparams->numSolutions = amwsNumSolutions;
01095 mwsparams->heuristic = amwsHeuristic;
01096 mwsparams->tabuLength = amwsTabuLength;
01097 mwsparams->lazyLowState = amwsLazyLowState;
01098
01099
01100 MCSatParams* msparams = new MCSatParams;
01101 msparams->mwsParams = mwsparams;
01102
01103 msparams->numChains = 1;
01104 msparams->burnMinSteps = amcmcBurnMinSteps;
01105 msparams->burnMaxSteps = amcmcBurnMaxSteps;
01106 msparams->minSteps = amcmcMinSteps;
01107 msparams->maxSteps = amcmcMaxSteps;
01108 msparams->maxSeconds = amcmcMaxSeconds;
01109
01110
01111
01112 GibbsParams* gibbsparams = new GibbsParams;
01113 gibbsparams->mwsParams = mwsparams;
01114 gibbsparams->numChains = amcmcNumChains;
01115 gibbsparams->burnMinSteps = amcmcBurnMinSteps;
01116 gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
01117 gibbsparams->minSteps = amcmcMinSteps;
01118 gibbsparams->maxSteps = amcmcMaxSteps;
01119 gibbsparams->maxSeconds = amcmcMaxSeconds;
01120
01121 gibbsparams->gamma = 1 - agibbsDelta;
01122 gibbsparams->epsilonError = agibbsEpsilonError;
01123 gibbsparams->fracConverged = agibbsFracConverged;
01124 gibbsparams->walksatType = agibbsWalksatType;
01125 gibbsparams->testConvergence = agibbsTestConvergence;
01126 gibbsparams->samplesPerTest = agibbsSamplesPerTest;
01127
01128
01129 SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
01130 stparams->mwsParams = mwsparams;
01131 stparams->numChains = amcmcNumChains;
01132 stparams->burnMinSteps = amcmcBurnMinSteps;
01133 stparams->burnMaxSteps = amcmcBurnMaxSteps;
01134 stparams->minSteps = amcmcMinSteps;
01135 stparams->maxSteps = amcmcMaxSteps;
01136 stparams->maxSeconds = amcmcMaxSeconds;
01137
01138 stparams->subInterval = asimtpSubInterval;
01139 stparams->numST = asimtpNumST;
01140 stparams->numSwap = asimtpNumSwap;
01141
01142
01143 BPParams* bpparams = new BPParams;
01144 bpparams->maxSteps = amcmcMaxSteps;
01145 bpparams->maxSeconds = amcmcMaxSeconds;
01146 bpparams->lifted = aliftedInfer;
01147 bpparams->convergenceThresh = abpConvergenceThresh;
01148 bpparams->convergeRequiredItrCnt = abpConvergeRequiredItrCnt;
01149 bpparams->implicitRep = !aexplicitRep;
01150 bpparams->outputNetwork = aoutputNetwork;
01151
01153
01154 cout << "Reading formulas and evidence predicates..." << endl;
01155
01156
01157 string::size_type bslash = inMLNFile.rfind("/");
01158 string tmp = (bslash == string::npos) ?
01159 inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
01160 char buf[100];
01161 sprintf(buf, "%s%d%s", tmp.c_str(), getpid(), ZZ_TMP_FILE_POSTFIX);
01162 wkMLNFile = buf;
01163 copyFileAndAppendDbFile(inMLNFile, wkMLNFile,
01164 evidenceFilesArr, constFilesArr);
01165
01166
01167 domain = new Domain;
01168 mln = new MLN();
01169 bool addUnitClauses = false;
01170 bool mustHaveWtOrFullStop = true;
01171 bool warnAboutDupGndPreds = true;
01172 bool flipWtsOfFlippedClause = true;
01173 bool allPredsExceptQueriesAreCW = false;
01174
01175 Domain* forCheckingPlusTypes = NULL;
01176
01177
01178
01179 if (!runYYParser(mln, domain, wkMLNFile.c_str(), allPredsExceptQueriesAreCW,
01180 &owPredNames, &cwPredNames, &queryPredNames, addUnitClauses,
01181 warnAboutDupGndPreds, 0, mustHaveWtOrFullStop,
01182 forCheckingPlusTypes, true, flipWtsOfFlippedClause))
01183 {
01184 unlink(wkMLNFile.c_str());
01185 return -1;
01186 }
01187
01188 unlink(wkMLNFile.c_str());
01189
01191
01193 Array<int> allPredGndingsAreQueries;
01194 Array<Array<Predicate* >* >* queryFormulas = new Array<Array<Predicate*> *>;
01195
01196
01197
01198 if (!aLazy)
01199 {
01200 if (queryFile.length() > 0)
01201 {
01202 cout << "Reading query predicates that are specified in query file..."
01203 << endl;
01204 bool ok = createQueryFilePreds(queryFile, domain, domain->getDB(),
01205 &queries, &knownQueries, queryFormulas);
01206 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01207 }
01208
01209 allPredGndingsAreQueries.growToSize(domain->getNumPredicates(), false);
01210 if (queryPredsStr.length() > 0)
01211 {
01212
01213
01214 GroundPredicateHashArray unePreds;
01215 GroundPredicateHashArray knePreds;
01216 bool ok = createComLineQueryPreds(queryPredsStr, domain,
01217 domain->getDB(), &unePreds, &knePreds,
01218 &allPredGndingsAreQueries, queryFormulas);
01219 if (!ok) { cout<<"Failed to create query predicates."<<endl; exit(-1); }
01220
01221 if (aisQueryEvidence)
01222 {
01223
01224
01225
01226
01227
01228
01229
01230
01231 queryPredValues.clear();
01232 for (int predno = 0; predno < unePreds.size(); predno++)
01233 domain->getDB()->setValue(unePreds[predno], FALSE);
01234 knownQueries = unePreds;
01235
01236
01237 for (int predno = 0; predno < knePreds.size(); predno++)
01238 {
01239 TruthValue origValue
01240 = domain->getDB()->setValue(knePreds[predno], UNKNOWN);
01241 queryPredValues.append(origValue);
01242 }
01243 queries = knePreds;
01244 }
01245 else
01246 {
01247 queries = unePreds;
01248 knownQueries = knePreds;
01249 }
01250 }
01251 }
01252
01253 bool trackClauseTrueCnts = false;
01254 VariableState* state = NULL;
01255 HVariableState* hstate = NULL;
01256 FactorGraph* factorGraph = NULL;
01257
01258 if (abpInfer || aoutputNetwork)
01259 {
01260 factorGraph = new FactorGraph(bpparams->lifted, mln, domain, queryFormulas);
01261 inference = new BP(factorGraph, bpparams, queryFormulas);
01262 }
01263 else
01264 {
01265
01266 bool markHardGndClauses = true;
01267 bool trackParentClauseWts = false;
01268 if (aHybrid)
01269 {
01270
01271 bool markHardGndClauses = true;
01272 bool trackParentClauseWts = false;
01273 hstate = new HVariableState(&queries, NULL, NULL,
01274 &allPredGndingsAreQueries,
01275 markHardGndClauses,
01276 trackParentClauseWts,
01277 mln, domain, aLazy);
01278
01279 hstate->LoadContGroundedMLN(aContPartFile);
01280 hstate->WriteGndPredIdxMap(aGndPredIdxMapFile);
01281 hstate->bMaxOnly_ = amapPos;
01282 }
01283 else
01284 {
01285
01286
01287 state = new VariableState(&queries, NULL, NULL,
01288 &allPredGndingsAreQueries,
01289 markHardGndClauses,
01290 trackParentClauseWts,
01291 mln, domain, aLazy);
01292 }
01293
01294 if (aGenRandom && aHybrid)
01295 {
01296 hstate->setProposalStdev(aProposalStdev);
01297 hstate->initRandom();
01298
01299 hstate->saveLowStateAll();
01300 ofstream oscont(atestcont), osdis(atestdis);
01301 hstate->printLowStateCont(oscont);
01302 hstate->printLowState(osdis);
01303 return 0;
01304 }
01305
01306 if (aStartPt && aHybrid)
01307 {
01308 hstate->LoadDisEviValuesFromRst(atestdis);
01309 hstate->LoadContEviValues(atestcont);
01310 }
01311
01312 if (aStartPt && !aHybrid)
01313 {
01314 state->LoadDisEviValuesFromRst(atestdis);
01315 }
01316
01317
01318 if (amapPos || amapAll || amcsatInfer || agibbsInfer || asimtpInfer ||
01319 aHybrid || aSA)
01320 {
01321 if ((amapPos || amapAll) && !aHybrid)
01322 {
01323
01324
01325 mwsparams->numSolutions = 1;
01326 inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts, mwsparams);
01327 }
01328 else if (amapPos && aHybrid && !amcsatInfer && !aSA)
01329 {
01330 hstate->setProposalStdev(aProposalStdev);
01331
01332
01333 hstate->optimizeIndividualNumTerm();
01334 mwsparams->numSolutions = 1;
01335 mwsparams->heuristic = HMWS;
01336 inference = new HMaxWalkSat(hstate, aSeed, trackClauseTrueCnts,
01337 mwsparams);
01338 HMaxWalkSat* p = (HMaxWalkSat*)inference;
01339 p->SetNoisePra(anumerator, adenominator);
01340 if (aMaxSeconds)
01341 {
01342 p->SetMaxSeconds(atof(aMaxSeconds));
01343 }
01344
01345 if (aStartPt)
01346 {
01347 hstate->setInitFromEvi(true);
01348 }
01349 }
01350 else if (amapPos && aHybrid && !amcsatInfer && aSA)
01351 {
01352 hstate->setProposalStdev(aProposalStdev);
01353 mwsparams->numSolutions = 1;
01354 mwsparams->heuristic = HSA;
01355 inference = new HMaxWalkSat(hstate, aSeed, trackClauseTrueCnts,
01356 mwsparams);
01357
01358 HMaxWalkSat* p = (HMaxWalkSat*)inference;
01359
01360 p->setHeuristic(HSA);
01361 p->setSATempDownRatio(aSATempDownRatio);
01362 p->SetMaxSeconds(atof(aMaxSeconds));
01363 p->SetSAInterval(saInterval);
01364 if (aStartPt)
01365 {
01366 hstate->setInitFromEvi(true);
01367 }
01368 }
01369 else if (aHybrid && amcsatInfer)
01370 {
01371 cout << "Creating HMCSAT instance." << endl;
01372 hstate->setProposalStdev(aProposalStdev);
01373 if (aContSamples == NULL)
01374 {
01375 cout << "Numeric sample file is error." <<endl;
01376 }
01377 inference = new HMCSAT(hstate, aSeed, trackClauseTrueCnts, msparams);
01378 HMCSAT* p = (HMCSAT*) inference;
01379 p->SetContSampleFile(aContSamples);
01380 p->SetPrintVarsPerSample(aPrintSamplePerIteration);
01381 }
01382 else if (amcsatInfer && !aHybrid)
01383 {
01384 inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams,
01385 queryFormulas);
01386 }
01387 else if (asimtpInfer && !aHybrid)
01388 {
01389
01390
01391 mwsparams->numSolutions = 1;
01392 inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
01393 stparams);
01394 }
01395 else if (agibbsInfer && !aHybrid)
01396 {
01397
01398
01399 mwsparams->numSolutions = 1;
01400 inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
01401 gibbsparams);
01402 }
01403 }
01404 }
01405 return 1;
01406 }
01407
01408
// Hash map from a string key to a pointer to a constant array of C strings,
// hashed and compared with HashString / EqualString.
01409 typedef hash_map<string, const Array<const char*>*, HashString, EqualString>
01410 StringToStrArrayMap;
01411
01412 #endif