mrf.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef MRF_H_SEP_23_2005
00067 #define MRF_H_SEP_23_2005
00068 
00069 #include <sys/times.h>
00070 #include <sys/time.h>
00071 #include <cstdlib>
00072 #include <cfloat>
00073 #include <fstream>
00074 #include "timer.h"
00075 #include "mln.h"
00076 #include "groundpredicate.h"
00077 
00078 #define MAX_LINE 1000
00079 
00080 const bool mrfdebug = false;
00081 
00083 
00084   // used as parameter of addGndClause()
00085 struct AddGroundClauseStruct
00086 {
00087   AddGroundClauseStruct(const GroundPredicateSet* const & sseenPreds,
00088                         GroundPredicateSet* const & uunseenPreds,
00089                         GroundPredicateHashArray* const & ggndPreds,
00090                         const Array<int>* const & aallPredGndingsAreQueries,
00091                         GroundClauseSet* const & ggndClausesSet,
00092                         Array<GroundClause*>* const & ggndClauses,
00093                         const bool& mmarkHardGndClauses,
00094                         const double* const & pparentWtPtr,
00095                         const int & cclauseId)
00096     : seenPreds(sseenPreds), unseenPreds(uunseenPreds), gndPreds(ggndPreds),
00097       allPredGndingsAreQueries(aallPredGndingsAreQueries),
00098       gndClausesSet(ggndClausesSet),
00099       gndClauses(ggndClauses), markHardGndClauses(mmarkHardGndClauses),
00100       parentWtPtr(pparentWtPtr), clauseId(cclauseId) {}
00101   
00102   ~AddGroundClauseStruct() {}
00103   
00104   const GroundPredicateSet* seenPreds;
00105   GroundPredicateSet* unseenPreds;
00106   GroundPredicateHashArray* gndPreds;
00107   const Array<int>* allPredGndingsAreQueries;
00108   GroundClauseSet* gndClausesSet;
00109   Array<GroundClause*>* gndClauses;
00110   const bool markHardGndClauses;
00111   const double* parentWtPtr;
00112   const int clauseId;
00113 };
00114 
00116 
00117 
00118 class MRF
00119 {
00120  public:
00121     //allPredGndingsAreQueries[p] is 1 (true) if all groundings of predicate p 
00122     //are in queries, otherwise it is 0 (false). 
00123     //allPredGndingsAreQueries can be
00124     //NULL if none of the predicates have all their groundings as queries.
00125   MRF(const GroundPredicateHashArray* const& queries, 
00126       const Array<int>* const & allPredGndingsAreQueries,
00127       const Domain* const & domain,  const Database * const & db, 
00128       const MLN* const & mln, const bool& markHardGndClauses,
00129       const bool& trackParentClauseWts, const int& memLimit)
00130   {
00131     cout << "creating mrf..." << endl; 
00132     Timer timer;
00133     GroundPredicateSet unseenPreds, seenPreds;
00134     GroundPredicateToIntMap gndPredsMap;
00135     GroundClauseSet gndClausesSet;
00136     gndPreds_ = new GroundPredicateHashArray;
00137     gndClauses_ = new Array<GroundClause*>;
00138     blocks_ = new Array<Array<int> >;
00139     blocks_->growToSize(domain->getNumPredBlocks());
00140     blockEvidence_ = new Array<bool>(*(domain->getBlockEvidenceArray()));
00141     long double memNeeded = 0;
00142         
00143       //add GroundPredicates in queries to unseenPreds
00144     for (int i = 0; i < queries->size(); i++)
00145     {
00146       GroundPredicate* gp = (*queries)[i];
00147       unseenPreds.insert(gp);
00148       int gndPredIdx = gndPreds_->append(gp);
00149       assert(gndPredsMap.find(gp) == gndPredsMap.end());
00150       gndPredsMap[gp] = gndPredIdx;
00151     }
00152 
00153       // If too much memory to build MRF then destroy it
00154     if (memLimit > 0)
00155     {
00156       memNeeded = sizeKB();
00157       if (memNeeded > memLimit)
00158       {
00159         for (int i = 0; i < gndClauses_->size(); i++)
00160           delete (*gndClauses_)[i];
00161         delete gndClauses_;    
00162 
00163         for (int i = 0; i < gndPreds_->size(); i++)
00164           if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00165         delete gndPreds_;
00166     
00167         for (int i = 0; i < blocks_->size(); i++)
00168           (*blocks_)[i].clearAndCompress();
00169         delete blocks_;
00170     
00171         delete blockEvidence_;
00172                     
00173         throw 1;
00174       }
00175     }
00176         
00177       //while there are still unknown preds we have not looked at
00178     while (!unseenPreds.empty())   
00179     {
00180       GroundPredicateSet::iterator predIt = unseenPreds.begin();
00181       GroundPredicate* pred = *predIt;
00182       unsigned int predId = pred->getId();
00183       //cout << "\tlooking at pred: ";  pred->print(cout, domain); cout << endl;
00184 
00185       bool genClausesForAllPredGndings = false;
00186         // if all groundings of predicate with predId are queries
00187       if (allPredGndingsAreQueries && (*allPredGndingsAreQueries)[predId]>=1)
00188       {
00189           // if we have not generated gnd clauses containing the queries before
00190         if ((*allPredGndingsAreQueries)[predId] == 1) 
00191           genClausesForAllPredGndings = true;
00192         else
00193         {   //we have dealt with predicate predId already          
00194           //cout << "\terasing at pred: ";  pred->print(cout, domain); 
00195           //cout<< endl;
00196           unseenPreds.erase(predIt);
00197           seenPreds.insert(pred);
00198           continue;
00199         }
00200       }
00201         //get all clauses that contains pred with predId
00202       const Array<IndexClause*>* clauses
00203         = mln->getClausesContainingPred(predId);
00204 
00205         //for each clause, ground it and find those with unknown truth values,
00206         //dropping ground preds which do not matter to the clause's truth value
00207       for (int i = 0; clauses && i < clauses->size(); i++)
00208       {
00209         Clause* c = (*clauses)[i]->clause;
00210 //cout << "\tIn clause c: ";  c->printWithWtAndStrVar(cout, domain); cout << endl;
00211                 const int clauseId = mln->findClauseIdx(c);  
00212                 assert(clauseId >= 0);
00213                 
00214                   //ignore clause with zero weight
00215         if (c->getWt() == 0) continue;
00216 
00217           //add gnd clauses with unknown truth values to gndClauses_
00218         const double* parentWtPtr =
00219           (trackParentClauseWts) ? c->getWtPtr() : NULL;
00220         AddGroundClauseStruct agc(&seenPreds, &unseenPreds, gndPreds_,
00221                                   allPredGndingsAreQueries,
00222                                   &gndClausesSet, gndClauses_,
00223                                   markHardGndClauses, parentWtPtr,
00224                                   clauseId);
00225 
00226         try
00227         {
00228           addUnknownGndClauses(pred, c, domain, db, genClausesForAllPredGndings,
00229                                &agc);
00230         }
00231         catch (bad_alloc&)
00232         {
00233           throw 1;
00234         }
00235 
00236           // If too much memory to build MRF then destroy it
00237         if (memLimit > 0)
00238         {
00239           memNeeded = sizeKB();
00240             //cout << "preds " << gndPreds_->size() << endl;
00241             //cout << "clauses " << gndClauses_->size() << endl;
00242             //cout << "memory " << memNeeded << endl;
00243             //cout << "limit " << memLimit << endl;
00244           if (memNeeded > memLimit)
00245           {
00246             for (int i = 0; i < gndClauses_->size(); i++)
00247               delete (*gndClauses_)[i];
00248             delete gndClauses_;    
00249 
00250             for (int i = 0; i < gndPreds_->size(); i++)
00251               if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00252             delete gndPreds_;
00253     
00254             for (int i = 0; i < blocks_->size(); i++)
00255               (*blocks_)[i].clearAndCompress();
00256             delete blocks_;
00257     
00258             delete blockEvidence_;
00259             throw 1;
00260           }
00261         }
00262       }
00263 
00264       //clauses with negative wts are handled by the inference algorithms
00265 
00266       //if all the gnd clauses that pred appears in have known truth value,
00267       //it is not added to gndPreds_ and excluded from the MCMC network
00268 
00269       //cout << "\terasing pred: ";  pred->print(cout, domain); cout << endl;
00270       unseenPreds.erase(predIt);
00271       seenPreds.insert(pred);
00272       if (genClausesForAllPredGndings)
00273       {
00274         assert(allPredGndingsAreQueries && 
00275                (*allPredGndingsAreQueries)[predId]==1);
00276           //indicate we have seen all groundings of pred predId
00277         (*allPredGndingsAreQueries)[predId]++;
00278       }
00279     }//while (!unseenPreds.empty())
00280 
00281     cout << "number of grounded predicates = " << gndPreds_->size() << endl;
00282     cout << "number of grounded clauses = " << gndClauses_->size() << endl;
00283     if (gndClauses_->size() == 0)
00284       cout<< "Markov blankets of query ground predicates are empty" << endl;
00285 
00286     if (mrfdebug)
00287     {
00288       cout << "Clauses in MRF: " << endl;
00289       for (int i = 0; i < gndClauses_->size(); i++)
00290       {
00291         (*gndClauses_)[i]->print(cout, domain, gndPreds_);
00292         cout << endl;
00293       }
00294     }
00295       // Compress preds and find blocks of preds
00296     for (int i = 0; i < gndPreds_->size(); i++)
00297     {
00298       (*gndPreds_)[i]->compress();
00299 
00300       const Array<Array<Predicate*>*>* blocks = domain->getPredBlocks();
00301       for (int j = 0; j < blocks->size(); j++)
00302       {
00303         Array<Predicate*>* block = (*blocks)[j];
00304         for (int k = 0; k < block->size(); k++)
00305         {
00306           Predicate* pred = (*block)[k];
00307           if (pred->canBeGroundedAs((*gndPreds_)[i]))
00308           {
00309             (*blocks_)[j].append(i);
00310           }
00311         }
00312       }
00313     }
00314 
00315       // Remove empty blocks (blocks generated in domain, but contain no query
00316       // atoms)
00317     int i = 0;
00318     while (i < blocks_->size())
00319     {
00320       Array<int> block = (*blocks_)[i];
00321       if (block.empty())
00322       {
00323         blocks_->removeItem(i);
00324         blockEvidence_->removeItem(i);
00325         continue;
00326       }
00327       i++;
00328     }
00329     
00330     gndPreds_->compress();
00331     gndClauses_->compress();
00332 
00333     cout <<"Time taken to construct MRF = ";
00334     Timer::printTime(cout,timer.time());
00335     cout << endl;
00336   }
00337 
00341   long double sizeKB()
00342   {
00343       // # of ground clauses times memory for a ground clause +
00344       // # of ground predicates times memory for a ground predicate
00345     long double size = 0;
00346     for (int i = 0; i < gndClauses_->size(); i++)
00347       size += (*gndClauses_)[i]->sizeKB();
00348     for (int i = 0; i < gndPreds_->size(); i++)
00349       size += (*gndPreds_)[i]->sizeKB();
00350 
00351     return size;    
00352   }
00353 
00354     // Do not delete the clause and truncClause argument.
00355     // This function is tightly bound to Clause::createAndAddUnknownClause().
00356   static void addUnknownGndClause(const AddGroundClauseStruct* const & agcs, 
00357                                   const Clause* const & clause,
00358                                   const Clause* const & truncClause,
00359                                   const bool& isHardClause)
00360   {
00361     const GroundPredicateSet* seenPreds     = agcs->seenPreds;
00362     GroundPredicateSet*       unseenPreds   = agcs->unseenPreds;
00363     GroundPredicateHashArray* gndPreds      = agcs->gndPreds;
00364     const Array<int>* allGndingsAreQueries  = agcs->allPredGndingsAreQueries;
00365     GroundClauseSet*          gndClausesSet = agcs->gndClausesSet;
00366     Array<GroundClause*>*     gndClauses    = agcs->gndClauses;
00367     const bool markHardGndClauses           = agcs->markHardGndClauses;
00368     const double* parentWtPtr               = agcs->parentWtPtr;
00369     const int clauseId                      = agcs->clauseId;
00370 
00371     // Check none of the grounded clause's predicates have been seen before.
00372     // If any of them have been seen before, this clause has been created 
00373     // before (for that seen predicate), and can be ignored
00374 
00375       // Check the untruncated ground clause whether any of its predicates
00376       // have been seen before
00377     bool seenBefore = false;
00378     for (int j = 0; j < clause->getNumPredicates(); j++)
00379     {
00380       Predicate* p = clause->getPredicate(j);
00381       GroundPredicate* gp = new GroundPredicate(p);
00382       if (seenPreds->find(gp) !=  seenPreds->end() ||
00383           (allGndingsAreQueries && (*allGndingsAreQueries)[gp->getId()] > 1) )
00384       { 
00385         seenBefore = true;
00386         break;
00387       }
00388       delete gp;
00389     }
00390 
00391     //delete gndClause;
00392     if (seenBefore) return;
00393 
00394     GroundClause* gndClause = new GroundClause(truncClause, gndPreds);
00395     if (markHardGndClauses && isHardClause) gndClause->setWtToHardWt();
00396     assert(gndClause->getWt() != 0);
00397 
00398     GroundClauseSet::iterator iter = gndClausesSet->find(gndClause);
00399       // If the unknown clause is not in gndClauses
00400     if (iter == gndClausesSet->end())
00401     {
00402       gndClausesSet->insert(gndClause);
00403       gndClauses->append(gndClause);
00404       gndClause->appendToGndPreds(gndPreds);
00405         // gndClause's wt is set when it was constructed
00406       if (parentWtPtr)
00407       { 
00408         gndClause->appendParentWtPtr(parentWtPtr);
00409         gndClause->incrementClauseFrequency(clauseId, 1);
00410         assert(gndClause->getWt() == *parentWtPtr);
00411       }      
00412 
00413         // Add the unknown predicates of the clause to unseenPreds if 
00414         // the predicates are already not in it
00415       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00416       {
00417         GroundPredicate* gp =
00418           (GroundPredicate*)gndClause->getGroundPredicate(j, gndPreds);
00419         assert(seenPreds->find(gp) == seenPreds->end());
00420           // if the ground predicate is not in unseenPreds
00421         GroundPredicateSet::iterator it = unseenPreds->find(gp);
00422         if (it == unseenPreds->end())
00423         {
00424           //cout << "\tinserting into unseen pred: ";  
00425           //pred->print(cout, domain); cout << endl;
00426           unseenPreds->insert(gp);
00427         }
00428       }
00429     }
00430     else
00431     {  // gndClause has appeared before, so just accumulate its weight
00432       (*iter)->addWt(gndClause->getWt());
00433 
00434       if (parentWtPtr)
00435       {
00436         (*iter)->appendParentWtPtr(parentWtPtr);
00437         (*iter)->incrementClauseFrequency(clauseId, 1);
00438       }
00439 
00440       delete gndClause;
00441     }
00442   } //addUnknownGndClause()
00443 
00444 
00445 
00446   ~MRF()
00447   {
00448     for (int i = 0; i < gndClauses_->size(); i++)
00449       if ((*gndClauses_)[i]) delete (*gndClauses_)[i];
00450     delete gndClauses_;    
00451 
00452     for (int i = 0; i < gndPreds_->size(); i++)
00453       if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00454     delete gndPreds_;
00455     
00456     for (int i = 0; i < blocks_->size(); i++)
00457       (*blocks_)[i].clearAndCompress();
00458     delete blocks_;
00459     
00460     delete blockEvidence_;
00461   }
00462 
00463   void deleteGndPredsGndClauseSets()
00464   {
00465     for (int i = 0; i < gndPreds_->size(); i++)
00466       (*gndPreds_)[i]->deleteGndClauseSet();
00467   }  
00468 
00469   
00470   void setGndClausesWtsToSumOfParentWts()
00471   {
00472     for (int i = 0; i < gndClauses_->size(); i++)
00473       (*gndClauses_)[i]->setWtToSumOfParentWts();
00474   }
00475 
00476   const GroundPredicateHashArray* getGndPreds() const { return gndPreds_; }
00477 
00478   const Array<GroundClause*>* getGndClauses() const { return gndClauses_; }
00479 
00480  private:
00481 
00482   void addUnknownGndClauses(const GroundPredicate* const& queryGndPred,
00483                             Clause* const & c, const Domain* const & domain, 
00484                             const Database* const & db, 
00485                             const bool& genClauseForAllPredGndings,
00486                             const AddGroundClauseStruct* const & agcs)
00487   {
00488     
00489     if (genClauseForAllPredGndings)
00490       c->addUnknownClauses(domain, db, -1, NULL, agcs);
00491     else
00492     {
00493       for (int i = 0; i < c->getNumPredicates(); i++)
00494       {
00495         if (c->getPredicate(i)->canBeGroundedAs(queryGndPred))
00496           c->addUnknownClauses(domain, db, i, queryGndPred, agcs);
00497       }
00498     }
00499   } 
00500 
00501  public:
00502 
00503   const int getNumGndPreds()
00504   {
00505     return gndPreds_->size();
00506   }
00507 
00508   const int getNumGndClauses()
00509   {
00510     return gndClauses_->size();
00511   }
00512 
00513   Array<Array<int> >* getBlocks()
00514   {
00515     return blocks_;
00516   }
00517   
00518   Array<bool>* getBlockEvidence()
00519   {
00520     return blockEvidence_;
00521   }
00522   
00523  private:
00524   GroundPredicateHashArray* gndPreds_;
00525   Array<GroundClause*>* gndClauses_;
00526     // Blocks of gndPred indices which belong together
00527   Array<Array<int> >* blocks_;
00528     // Flags indicating if block is fulfilled by evidence
00529   Array<bool>* blockEvidence_;
00530 };
00531 
00532 
00533 #endif

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1