00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef _INFER_H_OCT_30_2005
00067 #define _INFER_H_OCT_30_2005
00068
00073 #include "util.h"
00074 #include "fol.h"
00075 #include "mrf.h"
00076
00108 bool createComLineQueryPreds(const string& queryPredsStr,
00109 const Domain* const & domain,
00110 Database* const & db,
00111 GroundPredicateHashArray* const & queries,
00112 GroundPredicateHashArray* const & knownQueries,
00113 Array<int>* const & allPredGndingsAreQueries,
00114 bool printToFile, ostream& out, bool amapPos,
00115 const GroundPredicateHashArray* const & trueQueries,
00116 const Array<double>* const & trueProbs)
00117 {
00118 if (queryPredsStr.length() == 0) return true;
00119 string preds = Util::trim(queryPredsStr);
00120
00121
00122 int balparen = 0;
00123 for (unsigned int i = 0; i < preds.length(); i++)
00124 {
00125 if (preds.at(i)=='(') balparen++;
00126 else if (preds.at(i)==')') balparen--;
00127 else if (preds.at(i)==',' && balparen==0) preds.at(i)='\n';
00128 }
00129
00130 bool onlyPredName;
00131 bool ret = true;
00132 unsigned int cur;
00133 int termId, varIdCnt = 0;
00134 hash_map<string, int, HashString, EqualString> varToId;
00135 hash_map<string, int, HashString, EqualString>::iterator it;
00136 Array<VarsTypeId*>* vtiArr;
00137 string pred, predName, term;
00138 const PredicateTemplate* ptemplate;
00139 istringstream iss(preds);
00140 char delimit[2]; delimit[1] = '\0';
00141
00142
00143 while (getline(iss, pred))
00144 {
00145 onlyPredName = false;
00146 varToId.clear();
00147 varIdCnt = 0;
00148 cur = 0;
00149
00150
00151 if (!Util::substr(pred,cur,predName,"("))
00152 {
00153 predName = pred;
00154 onlyPredName = true;
00155 }
00156
00157
00158 ptemplate = domain->getPredicateTemplate(predName.c_str());
00159 if (ptemplate == NULL)
00160 {
00161 cout << "ERROR: Cannot find command line query predicate" << predName
00162 << " in domain." << endl;
00163 ret = false;
00164 continue;
00165 }
00166 Predicate ppred(ptemplate);
00167
00168
00169 if (!onlyPredName)
00170 {
00171
00172 for (int i = 0; i < 2; i++)
00173 {
00174 if (i == 0) delimit[0] = ',';
00175 else delimit[0] = ')';
00176 while(Util::substr(pred, cur, term, delimit))
00177 {
00178
00179 if (isupper(term.at(0)) || term.at(0) == '"')
00180 {
00181 termId = domain->getConstantId(term.c_str());
00182 if (termId < 0)
00183 {
00184 cout <<"ERROR: Cannot find constant "<<term<<" in database"<<endl;
00185 ret = false;
00186 }
00187 }
00188 else
00189 {
00190 if ((it=varToId.find(term)) == varToId.end())
00191 {
00192 termId = --varIdCnt;
00193 varToId[term] = varIdCnt;
00194 }
00195 else
00196 termId = (*it).second;
00197 }
00198 ppred.appendTerm(new Term(termId, (void*)&ppred, true));
00199 }
00200 }
00201 }
00202 else
00203 {
00204 (*allPredGndingsAreQueries)[ptemplate->getId()] = true;
00205 for (int i = 0; i < ptemplate->getNumTerms(); i++)
00206 ppred.appendTerm(new Term(--varIdCnt, (void*)&ppred, true));
00207 }
00208
00209
00210 if (ppred.getNumTerms() != ptemplate->getNumTerms())
00211 {
00212 cout << "ERROR: " << predName << " requires " << ptemplate->getNumTerms()
00213 << " terms but given " << ppred.getNumTerms() << endl;
00214 ret = false;
00215 }
00216 if (!ret) continue;
00217
00218
00220 vtiArr = NULL;
00221 ppred.createVarsTypeIdArr(vtiArr);
00222
00223
00224 if (vtiArr->size() <= 1)
00225 {
00226 assert(ppred.isGrounded());
00227 assert(!db->isClosedWorld(ppred.getId()));
00228 TruthValue tv = db->getValue(&ppred);
00229 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00230
00231
00232 if (printToFile) assert(tv != UNKNOWN);
00233 if (tv == UNKNOWN)
00234 {
00235 if (queries->append(gndPred) < 0) delete gndPred;
00236 }
00237 else
00238 {
00239
00240 if (printToFile)
00241 {
00242
00243 if (trueQueries)
00244 {
00245 double prob = 0.0;
00246 if (domain->getDB()->getEvidenceStatus(&ppred))
00247 {
00248
00249 continue;
00250
00251 }
00252 else
00253 {
00254 int found = trueQueries->find(gndPred);
00255 if (found >= 0) prob = (*trueProbs)[found];
00256 else
00257
00258 prob = (prob*10000+1/2.0)/(10000+1.0);
00259
00260 }
00261 gndPred->print(out, domain); out << " " << prob << endl;
00262 }
00263 else
00264 {
00265 if (amapPos)
00266 {
00267 if (tv == TRUE)
00268 {
00269 ppred.printWithStrVar(out, domain);
00270 out << endl;
00271 }
00272 }
00273 else
00274 {
00275 ppred.printWithStrVar(out, domain);
00276 out << " " << tv << endl;
00277 }
00278 }
00279 delete gndPred;
00280 }
00281 else
00282 {
00283
00284
00285
00286 if (knownQueries->append(gndPred) < 0) delete gndPred;
00287 }
00288 }
00289 }
00290 else
00291 {
00292 ArraysAccessor<int> acc;
00293 for (int i = 1; i < vtiArr->size(); i++)
00294 {
00295 const Array<int>* cons=domain->getConstantsByType((*vtiArr)[i]->typeId);
00296 acc.appendArray(cons);
00297 }
00298
00299
00300 Array<int> constIds;
00301 while (acc.getNextCombination(constIds))
00302 {
00303 assert(constIds.size() == vtiArr->size()-1);
00304 for (int j = 0; j < constIds.size(); j++)
00305 {
00306 Array<Term*>& terms = (*vtiArr)[j+1]->vars;
00307 for (int k = 0; k < terms.size(); k++)
00308 terms[k]->setId(constIds[j]);
00309 }
00310
00311
00312 assert(!db->isClosedWorld(ppred.getId()));
00313
00314 TruthValue tv = db->getValue(&ppred);
00315 GroundPredicate* gndPred = new GroundPredicate(&ppred);
00316
00317
00318 if (printToFile) assert(tv != UNKNOWN);
00319 if (tv == UNKNOWN)
00320 {
00321 if (queries->append(gndPred) < 0) delete gndPred;
00322 }
00323 else
00324 {
00325
00326 if (printToFile)
00327 {
00328
00329 if (trueQueries)
00330 {
00331 double prob = 0.0;
00332 if (domain->getDB()->getEvidenceStatus(&ppred))
00333 {
00334
00335 continue;
00336
00337 }
00338 else
00339 {
00340 int found = trueQueries->find(gndPred);
00341 if (found >= 0) prob = (*trueProbs)[found];
00342 else
00343
00344 prob = (prob*10000+1/2.0)/(10000+1.0);
00345 }
00346
00347
00348 gndPred->print(out, domain); out << " " << prob << endl;
00349 }
00350 else
00351 {
00352 if (amapPos)
00353 {
00354 if (tv == TRUE)
00355 {
00356 ppred.printWithStrVar(out, domain);
00357 out << endl;
00358 }
00359 }
00360 else
00361 {
00362 ppred.printWithStrVar(out, domain);
00363 out << " " << tv << endl;
00364 }
00365 }
00366 delete gndPred;
00367 }
00368 else
00369 {
00370
00371
00372 if (knownQueries->append(gndPred) < 0) delete gndPred;
00373 }
00374 }
00375 }
00376 }
00377
00378 ppred.deleteVarsTypeIdArr(vtiArr);
00379 }
00380
00381 if (!printToFile)
00382 {
00383 queries->compress();
00384 knownQueries->compress();
00385 }
00386
00387 return ret;
00388 }
00389
00394 bool createComLineQueryPreds(const string& queryPredsStr,
00395 const Domain* const & domain,
00396 Database* const & db,
00397 GroundPredicateHashArray* const & queries,
00398 GroundPredicateHashArray* const & knownQueries,
00399 Array<int>* const & allPredGndingsAreQueries)
00400 {
00401 return createComLineQueryPreds(queryPredsStr, domain, db,
00402 queries, knownQueries,
00403 allPredGndingsAreQueries,
00404 false, cout, false, NULL, NULL);
00405 }
00406
00417 bool extractPredNames(string preds, const string* queryFile,
00418 StringHashArray& predNames)
00419 {
00420 predNames.clear();
00421
00422
00423 string::size_type cur = 0, ws, ltparen;
00424 string qpred, predName;
00425
00426 if (preds.length() > 0)
00427 {
00428 preds.append(" ");
00429
00430
00431 int balparen = 0;
00432 for (unsigned int i = 0; i < preds.length(); i++)
00433 {
00434 if (preds.at(i) == '(') balparen++;
00435 else if (preds.at(i) == ')') balparen--;
00436 else if (preds.at(i) == ',' && balparen == 0) preds.at(i) = ' ';
00437 }
00438
00439 while (preds.at(cur) == ' ') cur++;
00440 while (true)
00441 {
00442 ws = preds.find(" ", cur);
00443 if (ws == string::npos) break;
00444 qpred = preds.substr(cur,ws-cur+1);
00445 cur = ws+1;
00446 while (cur < preds.length() && preds.at(cur) == ' ') cur++;
00447 ltparen = qpred.find("(",0);
00448
00449 if (ltparen == string::npos)
00450 {
00451 ws = qpred.find(" ");
00452 if (ws != string::npos) qpred = qpred.substr(0,ws);
00453 predName = qpred;
00454 }
00455 else
00456 predName = qpred.substr(0,ltparen);
00457
00458 predNames.append(predName);
00459 }
00460 }
00461
00462 if (queryFile == NULL || queryFile->length() == 0) return true;
00463
00464
00465 ifstream in((*queryFile).c_str());
00466 if (!in.good())
00467 {
00468 cout << "ERROR: unable to open " << *queryFile << endl;
00469 return false;
00470 }
00471 string buffer;
00472 while (getline(in, buffer))
00473 {
00474 cur = 0;
00475 while (cur < buffer.length() && buffer.at(cur) == ' ') cur++;
00476 ltparen = buffer.find("(", cur);
00477 if (ltparen == string::npos) continue;
00478 predName = buffer.substr(cur, ltparen-cur);
00479 predNames.append(predName);
00480 }
00481
00482 in.close();
00483 return true;
00484 }
00485
00490 char getTruthValueFirstChar(const TruthValue& tv)
00491 {
00492 if (tv == TRUE) return 'T';
00493 if (tv == FALSE) return 'F';
00494 if (tv == UNKNOWN) return 'U';
00495 assert(false);
00496 exit(-1);
00497 return '#';
00498 }
00499
/**
 * Seeds rand() from the wall clock. The seed is the low 7 bits of the
 * current seconds (0..127) times 10^6, plus the current microseconds.
 * NOTE(review): because seconds are masked to 7 bits, the seed sequence
 * repeats every 128 seconds.
 */
void setsrand()
{
  struct timeval tv;
  // POSIX leaves gettimeofday()'s behaviour unspecified when the timezone
  // argument is non-null (struct timezone is obsolete), so pass NULL.
  gettimeofday(&tv, NULL);
  unsigned int seed =
    (unsigned int)(((tv.tv_sec & 0177) * 1000000) + tv.tv_usec);
  srand(seed);
}
00511
00512
00513
00514 void copyFileAndAppendDbFile(const string& srcFile, string& dstFile,
00515 const Array<string>& dbFilesArr,
00516 const Array<string>& constFilesArr)
00517 {
00518 ofstream out(dstFile.c_str());
00519 ifstream in(srcFile.c_str());
00520 if (!out.good()) { cout<<"ERROR: failed to open "<<dstFile<<endl;exit(-1);}
00521 if (!in.good()) { cout<<"ERROR: failed to open "<<srcFile<<endl;exit(-1);}
00522
00523 string buffer;
00524 while(getline(in, buffer)) out << buffer << endl;
00525 in.close();
00526
00527 out << endl;
00528 for (int i = 0; i < constFilesArr.size(); i++)
00529 out << "#include \"" << constFilesArr[i] << "\"" << endl;
00530 out << endl;
00531 for (int i = 0; i < dbFilesArr.size(); i++)
00532 out << "#include \"" << dbFilesArr[i] << "\"" << endl;
00533 out.close();
00534 }
00535
00536
00537 bool checkQueryPredsNotInClosedWorldPreds(const StringHashArray& qpredNames,
00538 const StringHashArray& cwPredNames)
00539 {
00540 bool ok = true;
00541 for (int i = 0; i < qpredNames.size(); i++)
00542 if (cwPredNames.contains(qpredNames[i]))
00543 {
00544 cout << "ERROR: query predicate " << qpredNames[i]
00545 << " cannot be specified as closed world" << endl;
00546 ok = false;
00547 }
00548 return ok;
00549 }
00550
00579 bool createQueryFilePreds(const string& queryFile,
00580 const Domain* const & domain,
00581 Database* const & db,
00582 GroundPredicateHashArray* const &queries,
00583 GroundPredicateHashArray* const &knownQueries,
00584 bool printToFile, ostream& out, bool amapPos,
00585 const GroundPredicateHashArray* const &trueQueries,
00586 const Array<double>* const & trueProbs)
00587 {
00588 if (queryFile.length() == 0) return true;
00589
00590 bool ret = true;
00591 ifstream in(queryFile.c_str());
00592 unsigned int line = 0;
00593 unsigned int cur;
00594 int constId, predId;
00595 bool ok;
00596 string predStr, predName, constant;
00597 Array<int> constIds;
00598 const PredicateTemplate* ptemplate;
00599
00600 while (getline(in, predStr))
00601 {
00602 line++;
00603 cur = 0;
00604
00605
00606 ok = Util::substr(predStr, cur, predName, "(");
00607 if (!ok) continue;
00608
00609 predId = domain->getPredicateId(predName.c_str());
00610 ptemplate = domain->getPredicateTemplate(predId);
00611
00612 if (predId < 0 || ptemplate == NULL)
00613 {
00614 cout << "ERROR: Cannot find " << predName << " in domain on line "
00615 << line << " of query file " << queryFile << endl;
00616 ret = false;
00617 continue;
00618 }
00619
00620
00621 constIds.clear();
00622 while (Util::substr(predStr, cur, constant, ","))
00623 {
00624 constId = domain->getConstantId(constant.c_str());
00625 constIds.append(constId);
00626 if (constId < 0)
00627 {
00628 cout << "ERROR: Cannot find " << constant << " in database on line "
00629 << line << " of query file " << queryFile << endl;
00630 ret = false;
00631 }
00632 }
00633
00634
00635 ok = Util::substr(predStr, cur, constant, ")");
00636 if (!ok)
00637 {
00638 cout << "ERROR: Failed to parse ground predicate on line " << line
00639 << " of query file " << queryFile << endl;
00640 ret = false;
00641 continue;
00642 }
00643
00644 constId = domain->getConstantId(constant.c_str());
00645 constIds.append(constId);
00646 if (constId < 0)
00647 {
00648 cout << "ERROR: Cannot find " << constant << " in database on line "
00649 << line << " of query file " << queryFile << endl;
00650 ret = false;
00651 }
00652
00653 if (!ret) continue;
00655
00656
00657 if (constIds.size() != ptemplate->getNumTerms())
00658 {
00659 cout << "ERROR: incorrect number of terms for " << predName
00660 << ". Expected " << ptemplate->getNumTerms() << ", given "
00661 << constIds.size() << endl;
00662 ret = false;
00663 continue;
00664 }
00665
00666 Predicate pred(ptemplate);
00667 for (int i = 0; i < constIds.size(); i++)
00668 {
00669 if (pred.getTermTypeAsInt(i) != domain->getConstantTypeId(constIds[i]))
00670 {
00671 cout << "ERROR: wrong type for term "
00672 << domain->getConstantName(constIds[i]) << " for predicate "
00673 << predName << " on line " << line << " of query file "
00674 << queryFile << endl;
00675 ret = false;
00676 continue;
00677 }
00678 pred.appendTerm(new Term(constIds[i], (void*)&pred, true));
00679 }
00680 if (!ret) continue;
00681
00682 assert(!db->isClosedWorld(predId));
00683
00684 TruthValue tv = db->getValue(&pred);
00685 GroundPredicate* gndPred = new GroundPredicate(&pred);
00686
00687
00688 if (printToFile) assert(tv != UNKNOWN);
00689 if (tv == UNKNOWN)
00690 {
00691 if (queries->append(gndPred) < 0) delete gndPred;
00692 }
00693 else
00694 {
00695
00696 if (printToFile)
00697 {
00698
00699
00700 if (trueQueries)
00701 {
00702 double prob = 0.0;
00703 if (domain->getDB()->getEvidenceStatus(&pred))
00704 {
00705
00706 continue;
00707
00708 }
00709 else
00710 {
00711 int found = trueQueries->find(gndPred);
00712 if (found >= 0) prob = (*trueProbs)[found];
00713 else
00714
00715 prob = (prob*10000+1/2.0)/(10000+1.0);
00716 }
00717 gndPred->print(out, domain); out << " " << prob << endl;
00718 }
00719 else
00720 {
00721 if (amapPos)
00722 {
00723 if (tv == TRUE)
00724 {
00725 pred.printWithStrVar(out, domain);
00726 out << endl;
00727 }
00728 }
00729 else
00730 {
00731 pred.printWithStrVar(out, domain);
00732 out << " " << tv << endl;
00733 }
00734 }
00735 delete gndPred;
00736 }
00737 else
00738 {
00739
00740
00741 if (knownQueries->append(gndPred) < 0) delete gndPred;
00742 }
00743 }
00744 }
00745
00746 in.close();
00747 return ret;
00748 }
00749
00754 bool createQueryFilePreds(const string& queryFile, const Domain* const & domain,
00755 Database* const & db,
00756 GroundPredicateHashArray* const &queries,
00757 GroundPredicateHashArray* const &knownQueries)
00758 {
00759 return createQueryFilePreds(queryFile, domain, db, queries, knownQueries,
00760 false, cout, false, NULL, NULL);
00761 }
00762
00763 void readPredValuesAndSetToUnknown(const StringHashArray& predNames,
00764 Domain *domain,
00765 Array<Predicate *> &queryPreds,
00766 Array<TruthValue> &queryPredValues,
00767 bool isQueryEvidence)
00768 {
00769 Array<Predicate*> ppreds;
00770
00771
00772 queryPreds.clear();
00773 queryPredValues.clear();
00774 for (int predno = 0; predno < predNames.size(); predno++)
00775 {
00776 ppreds.clear();
00777 int predid = domain->getPredicateId(predNames[predno].c_str());
00778 Predicate::createAllGroundings(predid, domain, ppreds);
00779 for (int gpredno = 0; gpredno < ppreds.size(); gpredno++)
00780 {
00781 Predicate *pred = ppreds[gpredno];
00782 TruthValue tv = domain->getDB()->getValue(pred);
00783 if (tv == UNKNOWN)
00784 domain->getDB()->setValue(pred,FALSE);
00785
00786
00787
00788
00789 if (isQueryEvidence && tv == UNKNOWN)
00790 delete pred;
00791 else
00792 queryPreds.append(pred);
00793 }
00794 }
00795
00796
00797 domain->getDB()->setValuesToUnknown(&queryPreds, &queryPredValues);
00798 }
00799
00810 void setPredsToGivenValues(const StringHashArray& predNames, Domain *domain,
00811 Array<TruthValue> &gpredValues)
00812 {
00813 Array<Predicate*> gpreds;
00814 Array<Predicate*> ppreds;
00815 Array<TruthValue> tmpValues;
00816
00817
00818 gpreds.clear();
00819 tmpValues.clear();
00820 for (int predno = 0; predno < predNames.size(); predno++)
00821 {
00822 ppreds.clear();
00823 int predid = domain->getPredicateId(predNames[predno].c_str());
00824 Predicate::createAllGroundings(predid, domain, ppreds);
00825
00826 gpreds.append(ppreds);
00827 }
00828
00829 domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
00830 for (int gpredno = 0; gpredno < gpreds.size(); gpredno++)
00831 delete gpreds[gpredno];
00832 }
00833
00834
00835 typedef hash_map<string, const Array<const char*>*, HashString, EqualString>
00836 StringToStrArrayMap;
00837
00838 #endif