mrf.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef MRF_H_SEP_23_2005
00067 #define MRF_H_SEP_23_2005
00068 
00069 #include <sys/times.h>
00070 #include <sys/time.h>
00071 #include <cstdlib>
00072 #include <cfloat>
00073 #include <fstream>
00074 #include "timer.h"
00075 #include "mln.h"
00076 #include "groundpredicate.h"
00077 
00078 #define MAX_LINE 1000
00079 
00080 const bool mrfdebug = false;
00081 
00083 
00084   // used as parameter of addGndClause()
00085 struct AddGroundClauseStruct
00086 {
00087   AddGroundClauseStruct(const GroundPredicateSet* const & sseenPreds,
00088                         GroundPredicateSet* const & uunseenPreds,
00089                         GroundPredicateHashArray* const & ggndPreds,
00090                         const Array<int>* const & aallPredGndingsAreQueries,
00091                         GroundClauseSet* const & ggndClausesSet,
00092                         Array<GroundClause*>* const & ggndClauses,
00093                         const bool& mmarkHardGndClauses,
00094                         const double* const & pparentWtPtr,
00095                         const int & cclauseId)
00096     : seenPreds(sseenPreds), unseenPreds(uunseenPreds), gndPreds(ggndPreds),
00097       allPredGndingsAreQueries(aallPredGndingsAreQueries),
00098       gndClausesSet(ggndClausesSet),
00099       gndClauses(ggndClauses), markHardGndClauses(mmarkHardGndClauses),
00100       parentWtPtr(pparentWtPtr), clauseId(cclauseId) {}
00101   
00102   ~AddGroundClauseStruct() {}
00103   
00104   const GroundPredicateSet* seenPreds;
00105   GroundPredicateSet* unseenPreds;
00106   GroundPredicateHashArray* gndPreds;
00107   const Array<int>* allPredGndingsAreQueries;
00108   GroundClauseSet* gndClausesSet;
00109   Array<GroundClause*>* gndClauses;
00110   const bool markHardGndClauses;
00111   const double* parentWtPtr;
00112   const int clauseId;
00113 };
00114 
00116 
00117 
00118 class MRF
00119 {
00120  public:
00121     //allPredGndingsAreQueries[p] is 1 (true) if all groundings of predicate p 
00122     //are in queries, otherwise it is 0 (false). 
00123     //allPredGndingsAreQueries can be
00124     //NULL if none of the predicates have all their groundings as queries.
00125   MRF(const GroundPredicateHashArray* const& queries, 
00126       const Array<int>* const & allPredGndingsAreQueries,
00127       const Domain* const & domain,  const Database * const & db, 
00128       const MLN* const & mln, const bool& markHardGndClauses,
00129       const bool& trackParentClauseWts, const int& memLimit)
00130   {
00131     cout << "creating mrf..." << endl; 
00132     Timer timer;
00133     GroundPredicateSet unseenPreds, seenPreds;
00134     GroundPredicateToIntMap gndPredsMap;
00135     GroundClauseSet gndClausesSet;
00136     gndPreds_ = new GroundPredicateHashArray;
00137     gndClauses_ = new Array<GroundClause*>;
00138     long double memNeeded = 0;
00139         
00140       //add GroundPredicates in queries to unseenPreds
00141     for (int i = 0; i < queries->size(); i++)
00142     {
00143       GroundPredicate* gp = (*queries)[i];
00144       unseenPreds.insert(gp);
00145       int gndPredIdx = gndPreds_->append(gp);
00146       assert(gndPredsMap.find(gp) == gndPredsMap.end());
00147       gndPredsMap[gp] = gndPredIdx;
00148     }
00149 
00150       // If too much memory to build MRF then destroy it
00151     if (memLimit > 0)
00152     {
00153       memNeeded = sizeKB();
00154       if (memNeeded > memLimit)
00155       {
00156         for (int i = 0; i < gndClauses_->size(); i++)
00157           delete (*gndClauses_)[i];
00158         delete gndClauses_;    
00159 
00160         for (int i = 0; i < gndPreds_->size(); i++)
00161           if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00162         delete gndPreds_;
00163                     
00164         throw 1;
00165       }
00166     }
00167         
00168       //while there are still unknown preds we have not looked at
00169     while (!unseenPreds.empty())   
00170     {
00171       GroundPredicateSet::iterator predIt = unseenPreds.begin();
00172       GroundPredicate* pred = *predIt;
00173       unsigned int predId = pred->getId();
00174       //cout << "\tlooking at pred: ";  pred->print(cout, domain); cout << endl;
00175 
00176       bool genClausesForAllPredGndings = false;
00177         // if all groundings of predicate with predId are queries
00178       if (allPredGndingsAreQueries && (*allPredGndingsAreQueries)[predId] >= 1)
00179       {
00180           // if we have not generated gnd clauses containing the queries before
00181         if ((*allPredGndingsAreQueries)[predId] == 1) 
00182           genClausesForAllPredGndings = true;
00183         else
00184         {   //we have dealt with predicate predId already          
00185           //cout << "\terasing at pred: ";  pred->print(cout, domain); 
00186           //cout<< endl;
00187           unseenPreds.erase(predIt);
00188           seenPreds.insert(pred);
00189           continue;
00190         }
00191       }
00192         //get all clauses that contains pred with predId
00193       const Array<IndexClause*>* clauses
00194         = mln->getClausesContainingPred(predId);
00195 
00196         //for each clause, ground it and find those with unknown truth values,
00197         //dropping ground preds which do not matter to the clause's truth value
00198       for (int i = 0; clauses && i < clauses->size(); i++)
00199       {
00200         Clause* c = (*clauses)[i]->clause;
00201 //cout << "\tIn clause c: ";  c->printWithWtAndStrVar(cout, domain); cout << endl;
00202                 const int clauseId = mln->findClauseIdx(c);  
00203                 assert(clauseId >= 0);
00204                 
00205                   //ignore clause with zero weight
00206         if (c->getWt() == 0) continue;
00207 
00208           //add gnd clauses with unknown truth values to gndClauses_
00209         const double* parentWtPtr =
00210           (trackParentClauseWts) ? c->getWtPtr() : NULL;
00211         AddGroundClauseStruct agc(&seenPreds, &unseenPreds, gndPreds_,
00212                                   allPredGndingsAreQueries,
00213                                   &gndClausesSet, gndClauses_,
00214                                   markHardGndClauses, parentWtPtr,
00215                                   clauseId);
00216 
00217         try
00218         {
00219           addUnknownGndClauses(pred, c, domain, db, genClausesForAllPredGndings,
00220                                &agc);
00221         }
00222         catch (bad_alloc&)
00223         {
00224           cout << "Bad alloc when adding unknown ground clauses to MRF!\n";
00225           cerr << "Bad alloc when adding unknown ground clauses to MRF!\n";
00226           throw 1;
00227         }
00228 
00229           // If too much memory to build MRF then destroy it
00230         if (memLimit > 0)
00231         {
00232           memNeeded = sizeKB();
00233             //cout << "preds " << gndPreds_->size() << endl;
00234             //cout << "clauses " << gndClauses_->size() << endl;
00235             //cout << "memory " << memNeeded << endl;
00236             //cout << "limit " << memLimit << endl;
00237           if (memNeeded > memLimit)
00238           {
00239             for (int i = 0; i < gndClauses_->size(); i++)
00240               delete (*gndClauses_)[i];
00241             delete gndClauses_;    
00242 
00243             for (int i = 0; i < gndPreds_->size(); i++)
00244               if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00245             delete gndPreds_;
00246     
00247             throw 1;
00248           }
00249         }
00250       }
00251 
00252       //clauses with negative wts are handled by the inference algorithms
00253 
00254       //if all the gnd clauses that pred appears in have known truth value,
00255       //it is not added to gndPreds_ and excluded from the MCMC network
00256 
00257       //cout << "\terasing pred: ";  pred->print(cout, domain); cout << endl;
00258       unseenPreds.erase(predIt);
00259       seenPreds.insert(pred);
00260       if (genClausesForAllPredGndings)
00261       {
00262         assert(allPredGndingsAreQueries && 
00263                (*allPredGndingsAreQueries)[predId]==1);
00264           //indicate we have seen all groundings of pred predId
00265         (*allPredGndingsAreQueries)[predId]++;
00266       }
00267     }//while (!unseenPreds.empty())
00268 
00269     cout << "number of grounded predicates = " << gndPreds_->size() << endl;
00270     cout << "number of grounded clauses = " << gndClauses_->size() << endl;
00271     if (gndClauses_->size() == 0)
00272       cout<< "Markov blankets of query ground predicates are empty" << endl;
00273 
00274     if (mrfdebug)
00275     {
00276       cout << "Clauses in MRF: " << endl;
00277       for (int i = 0; i < gndClauses_->size(); i++)
00278       {
00279         (*gndClauses_)[i]->print(cout, domain, gndPreds_);
00280         cout << endl;
00281       }
00282     }
00283       // Compress preds
00284     for (int i = 0; i < gndPreds_->size(); i++)
00285       (*gndPreds_)[i]->compress();
00286 
00287     gndPreds_->compress();
00288     gndClauses_->compress();
00289 
00290     cout <<"Time taken to construct MRF = ";
00291     Timer::printTime(cout,timer.time());
00292     cout << endl;
00293   }
00294 
00298   long double sizeKB()
00299   {
00300       // # of ground clauses times memory for a ground clause +
00301       // # of ground predicates times memory for a ground predicate
00302     long double size = 0;
00303     for (int i = 0; i < gndClauses_->size(); i++)
00304       size += (*gndClauses_)[i]->sizeKB();
00305     for (int i = 0; i < gndPreds_->size(); i++)
00306       size += (*gndPreds_)[i]->sizeKB();
00307 
00308     return size;    
00309   }
00310 
00311     // Do not delete the clause and truncClause argument.
00312     // This function is tightly bound to Clause::createAndAddUnknownClause().
00313   static void addUnknownGndClause(const AddGroundClauseStruct* const & agcs, 
00314                                   const Clause* const & clause,
00315                                   const Clause* const & truncClause,
00316                                   const bool& isHardClause)
00317   {
00318     const GroundPredicateSet* seenPreds     = agcs->seenPreds;
00319     GroundPredicateSet*       unseenPreds   = agcs->unseenPreds;
00320     GroundPredicateHashArray* gndPreds      = agcs->gndPreds;
00321     const Array<int>* allGndingsAreQueries  = agcs->allPredGndingsAreQueries;
00322     GroundClauseSet*          gndClausesSet = agcs->gndClausesSet;
00323     Array<GroundClause*>*     gndClauses    = agcs->gndClauses;
00324     const bool markHardGndClauses           = agcs->markHardGndClauses;
00325     const double* parentWtPtr               = agcs->parentWtPtr;
00326     const int clauseId                      = agcs->clauseId;
00327 
00328     // Check none of the grounded clause's predicates have been seen before.
00329     // If any of them have been seen before, this clause has been created 
00330     // before (for that seen predicate), and can be ignored
00331 
00332       // Check the untruncated ground clause whether any of its predicates
00333       // have been seen before
00334     bool seenBefore = false;
00335     for (int j = 0; j < clause->getNumPredicates(); j++)
00336     {
00337       Predicate* p = clause->getPredicate(j);
00338       GroundPredicate* gp = new GroundPredicate(p);
00339       if (seenPreds->find(gp) !=  seenPreds->end() ||
00340           (allGndingsAreQueries && (*allGndingsAreQueries)[gp->getId()] > 1) )
00341       { 
00342         seenBefore = true;
00343         delete gp;
00344         break;
00345       }
00346       delete gp;
00347     }
00348 
00349     if (seenBefore) return;
00350 
00351     GroundClause* gndClause = new GroundClause(truncClause, gndPreds);
00352     if (markHardGndClauses && isHardClause) gndClause->setWtToHardWt();
00353     assert(gndClause->getWt() != 0);
00354 
00355     bool invertWt = false;
00356       // We want to normalize soft unit clauses to all be positives
00357     if (!isHardClause && gndClause->getNumGroundPredicates() == 1 &&
00358         !gndClause->getGroundPredicateSense(0))
00359     {
00360       gndClause->setGroundPredicateSense(0, true);
00361       gndClause->setWt(-gndClause->getWt());
00362       invertWt = true;
00363     }
00364 
00365     GroundClauseSet::iterator iter = gndClausesSet->find(gndClause);
00366       // If the unknown clause is not in gndClauses
00367     if (iter == gndClausesSet->end())
00368     {
00369       gndClausesSet->insert(gndClause);
00370       gndClauses->append(gndClause);
00371       gndClause->appendToGndPreds(gndPreds);
00372         // gndClause's wt is set when it was constructed
00373       if (parentWtPtr)
00374         gndClause->incrementClauseFrequency(clauseId, 1, invertWt);
00375 
00376         // Add the unknown predicates of the clause to unseenPreds if 
00377         // the predicates are already not in it
00378       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00379       {
00380         GroundPredicate* gp =
00381           (GroundPredicate*)gndClause->getGroundPredicate(j, gndPreds);
00382         assert(seenPreds->find(gp) == seenPreds->end());
00383           // if the ground predicate is not in unseenPreds
00384         GroundPredicateSet::iterator it = unseenPreds->find(gp);
00385         if (it == unseenPreds->end())
00386         {
00387           //cout << "\tinserting into unseen pred: ";  
00388           //pred->print(cout, domain); cout << endl;
00389           unseenPreds->insert(gp);
00390         }
00391       }
00392     }
00393     else
00394     {  // gndClause has appeared before, so just accumulate its weight
00395       (*iter)->addWt(gndClause->getWt());
00396 
00397       if (parentWtPtr)
00398         (*iter)->incrementClauseFrequency(clauseId, 1, invertWt);
00399 
00400       delete gndClause;
00401     }
00402   } //addUnknownGndClause()
00403 
00404 
00405 
00406   ~MRF()
00407   {
00408     for (int i = 0; i < gndClauses_->size(); i++)
00409       if ((*gndClauses_)[i]) delete (*gndClauses_)[i];
00410     delete gndClauses_;    
00411 
00412     for (int i = 0; i < gndPreds_->size(); i++)
00413       if ((*gndPreds_)[i]) delete (*gndPreds_)[i];
00414     delete gndPreds_;
00415   }
00416 
00417   void deleteGndPredsGndClauseSets()
00418   {
00419     for (int i = 0; i < gndPreds_->size(); i++)
00420       (*gndPreds_)[i]->deleteGndClauseSet();
00421   }  
00422 
00423   const GroundPredicateHashArray* getGndPreds() const { return gndPreds_; }
00424 
00425   const Array<GroundClause*>* getGndClauses() const { return gndClauses_; }
00426 
00427  private:
00428 
00429   void addUnknownGndClauses(const GroundPredicate* const& queryGndPred,
00430                             Clause* const & c, const Domain* const & domain, 
00431                             const Database* const & db, 
00432                             const bool& genClauseForAllPredGndings,
00433                             const AddGroundClauseStruct* const & agcs)
00434   {
00435     
00436     if (genClauseForAllPredGndings)
00437       c->addUnknownClauses(domain, db, -1, NULL, agcs);
00438     else
00439     {
00440       for (int i = 0; i < c->getNumPredicates(); i++)
00441       {
00442         if (c->getPredicate(i)->canBeGroundedAs(queryGndPred))
00443           c->addUnknownClauses(domain, db, i, queryGndPred, agcs);
00444       }
00445     }
00446   } 
00447 
00448  public:
00449 
00450   const int getNumGndPreds()
00451   {
00452     return gndPreds_->size();
00453   }
00454 
00455   const int getNumGndClauses()
00456   {
00457     return gndClauses_->size();
00458   }
00459 
00460  private:
00461   GroundPredicateHashArray* gndPreds_;
00462   Array<GroundClause*>* gndClauses_;
00463 };
00464 
00465 
00466 #endif

Generated on Sun Jun 7 11:55:13 2009 for Alchemy by  doxygen 1.5.1