00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef MRF_H_SEP_23_2005
00067 #define MRF_H_SEP_23_2005
00068
00069 #include <sys/times.h>
00070 #include <sys/time.h>
00071 #include <cstdlib>
00072 #include <cfloat>
00073 #include <fstream>
00074 #include "timer.h"
00075 #include "mln.h"
00076 #include "groundpredicate.h"
00077
00078 #define MAX_LINE 1000
00079
00080 const bool mrfdebug = false;
00081
00083
00084
00085 struct AddGroundClauseStruct
00086 {
00087 AddGroundClauseStruct(const GroundPredicateSet* const & sseenPreds,
00088 GroundPredicateSet* const & uunseenPreds,
00089 GroundPredicateHashArray* const & ggndPreds,
00090 const Array<int>* const & aallPredGndingsAreQueries,
00091 GroundClauseSet* const & ggndClausesSet,
00092 Array<GroundClause*>* const & ggndClauses,
00093 const bool& mmarkHardGndClauses,
00094 const double* const & pparentWtPtr,
00095 const int & cclauseId)
00096 : seenPreds(sseenPreds), unseenPreds(uunseenPreds), gndPreds(ggndPreds),
00097 allPredGndingsAreQueries(aallPredGndingsAreQueries),
00098 gndClausesSet(ggndClausesSet),
00099 gndClauses(ggndClauses), markHardGndClauses(mmarkHardGndClauses),
00100 parentWtPtr(pparentWtPtr), clauseId(cclauseId) {}
00101
00102 ~AddGroundClauseStruct() {}
00103
00104 const GroundPredicateSet* seenPreds;
00105 GroundPredicateSet* unseenPreds;
00106 GroundPredicateHashArray* gndPreds;
00107 const Array<int>* allPredGndingsAreQueries;
00108 GroundClauseSet* gndClausesSet;
00109 Array<GroundClause*>* gndClauses;
00110 const bool markHardGndClauses;
00111 const double* parentWtPtr;
00112 const int clauseId;
00113 };
00114
00116
00117
00118 class MRF
00119 {
00120 public:
00121
00122
00123
00124
00125 MRF(const GroundPredicateHashArray* const& queries,
00126 const Array<int>* const & allPredGndingsAreQueries,
00127 const Domain* const & domain, const Database * const & db,
00128 const MLN* const & mln, const bool& markHardGndClauses,
00129 const bool& trackParentClauseWts, const int& memLimit)
00130 {
00131 cout << "creating mrf..." << endl;
00132 Timer timer;
00133 GroundPredicateSet unseenPreds, seenPreds;
00134 GroundPredicateToIntMap gndPredsMap;
00135 GroundClauseSet gndClausesSet;
00136 gndPreds_ = new GroundPredicateHashArray;
00137 gndClauses_ = new Array<GroundClause*>;
00138 long double memNeeded = 0;
00139
00140
00141 for (int i = 0; i < queries->size(); i++)
00142 {
00143 GroundPredicate* gp = (*queries)[i];
00144 unseenPreds.insert(gp);
00145 int gndPredIdx = gndPreds_->append(gp);
00146 assert(gndPredsMap.find(gp) == gndPredsMap.end());
00147 gndPredsMap[gp] = gndPredIdx;
00148 }
00149
00150
00151 if (memLimit > 0)
00152 {
00153 memNeeded = sizeKB();
00154 if (memNeeded > memLimit)
00155 {
00156 for (int i = 0; i < gndClauses_->size(); i++)
00157 delete (*gndClauses_)[i];
00158 delete gndClauses_;
00159
00160 for (int i = 0; i < gndPreds_->size(); i++)
00161 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00162 delete gndPreds_;
00163
00164 throw 1;
00165 }
00166 }
00167
00168
00169 while (!unseenPreds.empty())
00170 {
00171 GroundPredicateSet::iterator predIt = unseenPreds.begin();
00172 GroundPredicate* pred = *predIt;
00173 unsigned int predId = pred->getId();
00174
00175
00176 bool genClausesForAllPredGndings = false;
00177
00178 if (allPredGndingsAreQueries && (*allPredGndingsAreQueries)[predId] >= 1)
00179 {
00180
00181 if ((*allPredGndingsAreQueries)[predId] == 1)
00182 genClausesForAllPredGndings = true;
00183 else
00184 {
00185
00186
00187 unseenPreds.erase(predIt);
00188 seenPreds.insert(pred);
00189 continue;
00190 }
00191 }
00192
00193 const Array<IndexClause*>* clauses
00194 = mln->getClausesContainingPred(predId);
00195
00196
00197
00198 for (int i = 0; clauses && i < clauses->size(); i++)
00199 {
00200 Clause* c = (*clauses)[i]->clause;
00201
00202 const int clauseId = mln->findClauseIdx(c);
00203 assert(clauseId >= 0);
00204
00205
00206 if (c->getWt() == 0) continue;
00207
00208
00209 const double* parentWtPtr =
00210 (trackParentClauseWts) ? c->getWtPtr() : NULL;
00211 AddGroundClauseStruct agc(&seenPreds, &unseenPreds, gndPreds_,
00212 allPredGndingsAreQueries,
00213 &gndClausesSet, gndClauses_,
00214 markHardGndClauses, parentWtPtr,
00215 clauseId);
00216
00217 try
00218 {
00219 addUnknownGndClauses(pred, c, domain, db, genClausesForAllPredGndings,
00220 &agc);
00221 }
00222 catch (bad_alloc&)
00223 {
00224 cout << "Bad alloc when adding unknown ground clauses to MRF!\n";
00225 cerr << "Bad alloc when adding unknown ground clauses to MRF!\n";
00226 throw 1;
00227 }
00228
00229
00230 if (memLimit > 0)
00231 {
00232 memNeeded = sizeKB();
00233
00234
00235
00236
00237 if (memNeeded > memLimit)
00238 {
00239 for (int i = 0; i < gndClauses_->size(); i++)
00240 delete (*gndClauses_)[i];
00241 delete gndClauses_;
00242
00243 for (int i = 0; i < gndPreds_->size(); i++)
00244 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00245 delete gndPreds_;
00246
00247 throw 1;
00248 }
00249 }
00250 }
00251
00252
00253
00254
00255
00256
00257
00258 unseenPreds.erase(predIt);
00259 seenPreds.insert(pred);
00260 if (genClausesForAllPredGndings)
00261 {
00262 assert(allPredGndingsAreQueries &&
00263 (*allPredGndingsAreQueries)[predId]==1);
00264
00265 (*allPredGndingsAreQueries)[predId]++;
00266 }
00267 }
00268
00269 cout << "number of grounded predicates = " << gndPreds_->size() << endl;
00270 cout << "number of grounded clauses = " << gndClauses_->size() << endl;
00271 if (gndClauses_->size() == 0)
00272 cout<< "Markov blankets of query ground predicates are empty" << endl;
00273
00274 if (mrfdebug)
00275 {
00276 cout << "Clauses in MRF: " << endl;
00277 for (int i = 0; i < gndClauses_->size(); i++)
00278 {
00279 (*gndClauses_)[i]->print(cout, domain, gndPreds_);
00280 cout << endl;
00281 }
00282 }
00283
00284 for (int i = 0; i < gndPreds_->size(); i++)
00285 (*gndPreds_)[i]->compress();
00286
00287 gndPreds_->compress();
00288 gndClauses_->compress();
00289
00290 cout <<"Time taken to construct MRF = ";
00291 Timer::printTime(cout,timer.time());
00292 cout << endl;
00293 }
00294
00298 long double sizeKB()
00299 {
00300
00301
00302 long double size = 0;
00303 for (int i = 0; i < gndClauses_->size(); i++)
00304 size += (*gndClauses_)[i]->sizeKB();
00305 for (int i = 0; i < gndPreds_->size(); i++)
00306 size += (*gndPreds_)[i]->sizeKB();
00307
00308 return size;
00309 }
00310
00311
00312
00313 static void addUnknownGndClause(const AddGroundClauseStruct* const & agcs,
00314 const Clause* const & clause,
00315 const Clause* const & truncClause,
00316 const bool& isHardClause)
00317 {
00318 const GroundPredicateSet* seenPreds = agcs->seenPreds;
00319 GroundPredicateSet* unseenPreds = agcs->unseenPreds;
00320 GroundPredicateHashArray* gndPreds = agcs->gndPreds;
00321 const Array<int>* allGndingsAreQueries = agcs->allPredGndingsAreQueries;
00322 GroundClauseSet* gndClausesSet = agcs->gndClausesSet;
00323 Array<GroundClause*>* gndClauses = agcs->gndClauses;
00324 const bool markHardGndClauses = agcs->markHardGndClauses;
00325 const double* parentWtPtr = agcs->parentWtPtr;
00326 const int clauseId = agcs->clauseId;
00327
00328
00329
00330
00331
00332
00333
00334 bool seenBefore = false;
00335 for (int j = 0; j < clause->getNumPredicates(); j++)
00336 {
00337 Predicate* p = clause->getPredicate(j);
00338 GroundPredicate* gp = new GroundPredicate(p);
00339 if (seenPreds->find(gp) != seenPreds->end() ||
00340 (allGndingsAreQueries && (*allGndingsAreQueries)[gp->getId()] > 1) )
00341 {
00342 seenBefore = true;
00343 delete gp;
00344 break;
00345 }
00346 delete gp;
00347 }
00348
00349 if (seenBefore) return;
00350
00351 GroundClause* gndClause = new GroundClause(truncClause, gndPreds);
00352 if (markHardGndClauses && isHardClause) gndClause->setWtToHardWt();
00353 assert(gndClause->getWt() != 0);
00354
00355 bool invertWt = false;
00356
00357 if (!isHardClause && gndClause->getNumGroundPredicates() == 1 &&
00358 !gndClause->getGroundPredicateSense(0))
00359 {
00360 gndClause->setGroundPredicateSense(0, true);
00361 gndClause->setWt(-gndClause->getWt());
00362 invertWt = true;
00363 }
00364
00365 GroundClauseSet::iterator iter = gndClausesSet->find(gndClause);
00366
00367 if (iter == gndClausesSet->end())
00368 {
00369 gndClausesSet->insert(gndClause);
00370 gndClauses->append(gndClause);
00371 gndClause->appendToGndPreds(gndPreds);
00372
00373 if (parentWtPtr)
00374 gndClause->incrementClauseFrequency(clauseId, 1, invertWt);
00375
00376
00377
00378 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00379 {
00380 GroundPredicate* gp =
00381 (GroundPredicate*)gndClause->getGroundPredicate(j, gndPreds);
00382 assert(seenPreds->find(gp) == seenPreds->end());
00383
00384 GroundPredicateSet::iterator it = unseenPreds->find(gp);
00385 if (it == unseenPreds->end())
00386 {
00387
00388
00389 unseenPreds->insert(gp);
00390 }
00391 }
00392 }
00393 else
00394 {
00395 (*iter)->addWt(gndClause->getWt());
00396
00397 if (parentWtPtr)
00398 (*iter)->incrementClauseFrequency(clauseId, 1, invertWt);
00399
00400 delete gndClause;
00401 }
00402 }
00403
00404
00405
00406 ~MRF()
00407 {
00408 for (int i = 0; i < gndClauses_->size(); i++)
00409 if ((*gndClauses_)[i]) delete (*gndClauses_)[i];
00410 delete gndClauses_;
00411
00412 for (int i = 0; i < gndPreds_->size(); i++)
00413 if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00414 delete gndPreds_;
00415 }
00416
00417 void deleteGndPredsGndClauseSets()
00418 {
00419 for (int i = 0; i < gndPreds_->size(); i++)
00420 (*gndPreds_)[i]->deleteGndClauseSet();
00421 }
00422
00423 const GroundPredicateHashArray* getGndPreds() const { return gndPreds_; }
00424
00425 const Array<GroundClause*>* getGndClauses() const { return gndClauses_; }
00426
00427 private:
00428
00429 void addUnknownGndClauses(const GroundPredicate* const& queryGndPred,
00430 Clause* const & c, const Domain* const & domain,
00431 const Database* const & db,
00432 const bool& genClauseForAllPredGndings,
00433 const AddGroundClauseStruct* const & agcs)
00434 {
00435
00436 if (genClauseForAllPredGndings)
00437 c->addUnknownClauses(domain, db, -1, NULL, agcs);
00438 else
00439 {
00440 for (int i = 0; i < c->getNumPredicates(); i++)
00441 {
00442 if (c->getPredicate(i)->canBeGroundedAs(queryGndPred))
00443 c->addUnknownClauses(domain, db, i, queryGndPred, agcs);
00444 }
00445 }
00446 }
00447
00448 public:
00449
00450 const int getNumGndPreds()
00451 {
00452 return gndPreds_->size();
00453 }
00454
00455 const int getNumGndClauses()
00456 {
00457 return gndClauses_->size();
00458 }
00459
00460 private:
00461 GroundPredicateHashArray* gndPreds_;
00462 Array<GroundClause*>* gndClauses_;
00463 };
00464
00465
00466 #endif