00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef MRF_H_SEP_23_2005
00067 #define MRF_H_SEP_23_2005
00068
00069 #include <sys/times.h>
00070 #include <sys/time.h>
00071 #include <cstdlib>
00072 #include <cfloat>
00073 #include <fstream>
00074 #include "timer.h"
00075 #include "mln.h"
00076 #include "groundpredicate.h"
00077
// NOTE(review): MAX_LINE is not referenced anywhere in the visible part of
// this header; presumably a line/buffer length limit used by other files --
// confirm before removing.
00079 #define MAX_LINE 1000
00080
// Compile-time debug switch: when true, the MRF constructor prints every
// ground clause it produced to cout once grounding has finished.
00081 const bool mrfdebug = false;
00081
00083
00084
00085 struct AddGroundClauseStruct
00086 {
00087 AddGroundClauseStruct(const GroundPredicateSet* const & sseenPreds,
00088 GroundPredicateSet* const & uunseenPreds,
00089 GroundPredicateHashArray* const & ggndPreds,
00090 const Array<int>* const & aallPredGndingsAreQueries,
00091 GroundClauseSet* const & ggndClausesSet,
00092 Array<GroundClause*>* const & ggndClauses,
00093 const bool& mmarkHardGndClauses,
00094 const double* const & pparentWtPtr,
00095 const int & cclauseId)
00096 : seenPreds(sseenPreds), unseenPreds(uunseenPreds), gndPreds(ggndPreds),
00097 allPredGndingsAreQueries(aallPredGndingsAreQueries),
00098 gndClausesSet(ggndClausesSet),
00099 gndClauses(ggndClauses), markHardGndClauses(mmarkHardGndClauses),
00100 parentWtPtr(pparentWtPtr), clauseId(cclauseId) {}
00101
00102 ~AddGroundClauseStruct() {}
00103
00104 const GroundPredicateSet* seenPreds;
00105 GroundPredicateSet* unseenPreds;
00106 GroundPredicateHashArray* gndPreds;
00107 const Array<int>* allPredGndingsAreQueries;
00108 GroundClauseSet* gndClausesSet;
00109 Array<GroundClause*>* gndClauses;
00110 const bool markHardGndClauses;
00111 const double* parentWtPtr;
00112 const int clauseId;
00113 };
00114
00116
00117
00118 class MRF
00119 {
00120 public:
00121
00122
00123
00124
00125 MRF(const GroundPredicateHashArray* const& queries,
00126 const Array<int>* const & allPredGndingsAreQueries,
00127 const Domain* const & domain, const Database * const & db,
00128 const MLN* const & mln, const bool& markHardGndClauses,
00129 const bool& trackParentClauseWts, const int& memLimit)
00130 {
00131 cout << "creating mrf..." << endl;
00132 Timer timer;
00133 GroundPredicateSet unseenPreds, seenPreds;
00134 GroundPredicateToIntMap gndPredsMap;
00135 GroundClauseSet gndClausesSet;
00136 gndPreds_ = new GroundPredicateHashArray;
00137 gndClauses_ = new Array<GroundClause*>;
00138 blocks_ = new Array<Array<int> >;
00139 blocks_->growToSize(domain->getNumPredBlocks());
00140 blockEvidence_ = new Array<bool>(*(domain->getBlockEvidenceArray()));
00141 long double memNeeded = 0;
00142
00143
00144 for (int i = 0; i < queries->size(); i++)
00145 {
00146 GroundPredicate* gp = (*queries)[i];
00147 unseenPreds.insert(gp);
00148 int gndPredIdx = gndPreds_->append(gp);
00149 assert(gndPredsMap.find(gp) == gndPredsMap.end());
00150 gndPredsMap[gp] = gndPredIdx;
00151 }
00152
00153
00154 if (memLimit > 0)
00155 {
00156 memNeeded = sizeKB();
00157 if (memNeeded > memLimit)
00158 {
00159 for (int i = 0; i < gndClauses_->size(); i++)
00160 delete (*gndClauses_)[i];
00161 delete gndClauses_;
00162
00163 for (int i = 0; i < gndPreds_->size(); i++)
00164 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00165 delete gndPreds_;
00166
00167 for (int i = 0; i < blocks_->size(); i++)
00168 (*blocks_)[i].clearAndCompress();
00169 delete blocks_;
00170
00171 delete blockEvidence_;
00172
00173 throw 1;
00174 }
00175 }
00176
00177
00178 while (!unseenPreds.empty())
00179 {
00180 GroundPredicateSet::iterator predIt = unseenPreds.begin();
00181 GroundPredicate* pred = *predIt;
00182 unsigned int predId = pred->getId();
00183
00184
00185 bool genClausesForAllPredGndings = false;
00186
00187 if (allPredGndingsAreQueries && (*allPredGndingsAreQueries)[predId]>=1)
00188 {
00189
00190 if ((*allPredGndingsAreQueries)[predId] == 1)
00191 genClausesForAllPredGndings = true;
00192 else
00193 {
00194
00195
00196 unseenPreds.erase(predIt);
00197 seenPreds.insert(pred);
00198 continue;
00199 }
00200 }
00201
00202 const Array<IndexClause*>* clauses
00203 = mln->getClausesContainingPred(predId);
00204
00205
00206
00207 for (int i = 0; clauses && i < clauses->size(); i++)
00208 {
00209 Clause* c = (*clauses)[i]->clause;
00210
00211 const int clauseId = mln->findClauseIdx(c);
00212 assert(clauseId >= 0);
00213
00214
00215 if (c->getWt() == 0) continue;
00216
00217
00218 const double* parentWtPtr =
00219 (trackParentClauseWts) ? c->getWtPtr() : NULL;
00220 AddGroundClauseStruct agc(&seenPreds, &unseenPreds, gndPreds_,
00221 allPredGndingsAreQueries,
00222 &gndClausesSet, gndClauses_,
00223 markHardGndClauses, parentWtPtr,
00224 clauseId);
00225
00226 try
00227 {
00228 addUnknownGndClauses(pred, c, domain, db, genClausesForAllPredGndings,
00229 &agc);
00230 }
00231 catch (bad_alloc&)
00232 {
00233 throw 1;
00234 }
00235
00236
00237 if (memLimit > 0)
00238 {
00239 memNeeded = sizeKB();
00240
00241
00242
00243
00244 if (memNeeded > memLimit)
00245 {
00246 for (int i = 0; i < gndClauses_->size(); i++)
00247 delete (*gndClauses_)[i];
00248 delete gndClauses_;
00249
00250 for (int i = 0; i < gndPreds_->size(); i++)
00251 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00252 delete gndPreds_;
00253
00254 for (int i = 0; i < blocks_->size(); i++)
00255 (*blocks_)[i].clearAndCompress();
00256 delete blocks_;
00257
00258 delete blockEvidence_;
00259 throw 1;
00260 }
00261 }
00262 }
00263
00264
00265
00266
00267
00268
00269
00270 unseenPreds.erase(predIt);
00271 seenPreds.insert(pred);
00272 if (genClausesForAllPredGndings)
00273 {
00274 assert(allPredGndingsAreQueries &&
00275 (*allPredGndingsAreQueries)[predId]==1);
00276
00277 (*allPredGndingsAreQueries)[predId]++;
00278 }
00279 }
00280
00281 cout << "number of grounded predicates = " << gndPreds_->size() << endl;
00282 cout << "number of grounded clauses = " << gndClauses_->size() << endl;
00283 if (gndClauses_->size() == 0)
00284 cout<< "Markov blankets of query ground predicates are empty" << endl;
00285
00286 if (mrfdebug)
00287 {
00288 cout << "Clauses in MRF: " << endl;
00289 for (int i = 0; i < gndClauses_->size(); i++)
00290 {
00291 (*gndClauses_)[i]->print(cout, domain, gndPreds_);
00292 cout << endl;
00293 }
00294 }
00295
00296 for (int i = 0; i < gndPreds_->size(); i++)
00297 {
00298 (*gndPreds_)[i]->compress();
00299
00300 const Array<Array<Predicate*>*>* blocks = domain->getPredBlocks();
00301 for (int j = 0; j < blocks->size(); j++)
00302 {
00303 Array<Predicate*>* block = (*blocks)[j];
00304 for (int k = 0; k < block->size(); k++)
00305 {
00306 Predicate* pred = (*block)[k];
00307 if (pred->canBeGroundedAs((*gndPreds_)[i]))
00308 {
00309 (*blocks_)[j].append(i);
00310 }
00311 }
00312 }
00313 }
00314
00315
00316
00317 int i = 0;
00318 while (i < blocks_->size())
00319 {
00320 Array<int> block = (*blocks_)[i];
00321 if (block.empty())
00322 {
00323 blocks_->removeItem(i);
00324 blockEvidence_->removeItem(i);
00325 continue;
00326 }
00327 i++;
00328 }
00329
00330 gndPreds_->compress();
00331 gndClauses_->compress();
00332
00333 cout <<"Time taken to construct MRF = ";
00334 Timer::printTime(cout,timer.time());
00335 cout << endl;
00336 }
00337
00341 long double sizeKB()
00342 {
00343
00344
00345 long double size = 0;
00346 for (int i = 0; i < gndClauses_->size(); i++)
00347 size += (*gndClauses_)[i]->sizeKB();
00348 for (int i = 0; i < gndPreds_->size(); i++)
00349 size += (*gndPreds_)[i]->sizeKB();
00350
00351 return size;
00352 }
00353
00354
00355
00356 static void addUnknownGndClause(const AddGroundClauseStruct* const & agcs,
00357 const Clause* const & clause,
00358 const Clause* const & truncClause,
00359 const bool& isHardClause)
00360 {
00361 const GroundPredicateSet* seenPreds = agcs->seenPreds;
00362 GroundPredicateSet* unseenPreds = agcs->unseenPreds;
00363 GroundPredicateHashArray* gndPreds = agcs->gndPreds;
00364 const Array<int>* allGndingsAreQueries = agcs->allPredGndingsAreQueries;
00365 GroundClauseSet* gndClausesSet = agcs->gndClausesSet;
00366 Array<GroundClause*>* gndClauses = agcs->gndClauses;
00367 const bool markHardGndClauses = agcs->markHardGndClauses;
00368 const double* parentWtPtr = agcs->parentWtPtr;
00369 const int clauseId = agcs->clauseId;
00370
00371
00372
00373
00374
00375
00376
00377 bool seenBefore = false;
00378 for (int j = 0; j < clause->getNumPredicates(); j++)
00379 {
00380 Predicate* p = clause->getPredicate(j);
00381 GroundPredicate* gp = new GroundPredicate(p);
00382 if (seenPreds->find(gp) != seenPreds->end() ||
00383 (allGndingsAreQueries && (*allGndingsAreQueries)[gp->getId()] > 1) )
00384 {
00385 seenBefore = true;
00386 break;
00387 }
00388 delete gp;
00389 }
00390
00391
00392 if (seenBefore) return;
00393
00394 GroundClause* gndClause = new GroundClause(truncClause, gndPreds);
00395 if (markHardGndClauses && isHardClause) gndClause->setWtToHardWt();
00396 assert(gndClause->getWt() != 0);
00397
00398 GroundClauseSet::iterator iter = gndClausesSet->find(gndClause);
00399
00400 if (iter == gndClausesSet->end())
00401 {
00402 gndClausesSet->insert(gndClause);
00403 gndClauses->append(gndClause);
00404 gndClause->appendToGndPreds(gndPreds);
00405
00406 if (parentWtPtr)
00407 {
00408 gndClause->appendParentWtPtr(parentWtPtr);
00409 gndClause->incrementClauseFrequency(clauseId, 1);
00410 assert(gndClause->getWt() == *parentWtPtr);
00411 }
00412
00413
00414
00415 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00416 {
00417 GroundPredicate* gp =
00418 (GroundPredicate*)gndClause->getGroundPredicate(j, gndPreds);
00419 assert(seenPreds->find(gp) == seenPreds->end());
00420
00421 GroundPredicateSet::iterator it = unseenPreds->find(gp);
00422 if (it == unseenPreds->end())
00423 {
00424
00425
00426 unseenPreds->insert(gp);
00427 }
00428 }
00429 }
00430 else
00431 {
00432 (*iter)->addWt(gndClause->getWt());
00433
00434 if (parentWtPtr)
00435 {
00436 (*iter)->appendParentWtPtr(parentWtPtr);
00437 (*iter)->incrementClauseFrequency(clauseId, 1);
00438 }
00439
00440 delete gndClause;
00441 }
00442 }
00443
00444
00445
00446 ~MRF()
00447 {
00448 for (int i = 0; i < gndClauses_->size(); i++)
00449 if ((*gndClauses_)[i]) delete (*gndClauses_)[i];
00450 delete gndClauses_;
00451
00452 for (int i = 0; i < gndPreds_->size(); i++)
00453 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00454 delete gndPreds_;
00455
00456 for (int i = 0; i < blocks_->size(); i++)
00457 (*blocks_)[i].clearAndCompress();
00458 delete blocks_;
00459
00460 delete blockEvidence_;
00461 }
00462
00463 void deleteGndPredsGndClauseSets()
00464 {
00465 for (int i = 0; i < gndPreds_->size(); i++)
00466 (*gndPreds_)[i]->deleteGndClauseSet();
00467 }
00468
00469
00470 void setGndClausesWtsToSumOfParentWts()
00471 {
00472 for (int i = 0; i < gndClauses_->size(); i++)
00473 (*gndClauses_)[i]->setWtToSumOfParentWts();
00474 }
00475
00476 const GroundPredicateHashArray* getGndPreds() const { return gndPreds_; }
00477
00478 const Array<GroundClause*>* getGndClauses() const { return gndClauses_; }
00479
00480 private:
00481
00482 void addUnknownGndClauses(const GroundPredicate* const& queryGndPred,
00483 Clause* const & c, const Domain* const & domain,
00484 const Database* const & db,
00485 const bool& genClauseForAllPredGndings,
00486 const AddGroundClauseStruct* const & agcs)
00487 {
00488
00489 if (genClauseForAllPredGndings)
00490 c->addUnknownClauses(domain, db, -1, NULL, agcs);
00491 else
00492 {
00493 for (int i = 0; i < c->getNumPredicates(); i++)
00494 {
00495 if (c->getPredicate(i)->canBeGroundedAs(queryGndPred))
00496 c->addUnknownClauses(domain, db, i, queryGndPred, agcs);
00497 }
00498 }
00499 }
00500
00501 public:
00502
00503 const int getNumGndPreds()
00504 {
00505 return gndPreds_->size();
00506 }
00507
00508 const int getNumGndClauses()
00509 {
00510 return gndClauses_->size();
00511 }
00512
00513 Array<Array<int> >* getBlocks()
00514 {
00515 return blocks_;
00516 }
00517
00518 Array<bool>* getBlockEvidence()
00519 {
00520 return blockEvidence_;
00521 }
00522
00523 private:
00524 GroundPredicateHashArray* gndPreds_;
00525 Array<GroundClause*>* gndClauses_;
00526
00527 Array<Array<int> >* blocks_;
00528
00529 Array<bool>* blockEvidence_;
00530 };
00531
00532
00533 #endif