factorgraph.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef FACTORGRAPH_H_
00068 #define FACTORGRAPH_H_
00069 
00070 #include "twowaymessage.h"
00071 #include "superclause.h"
00072 #include "auxfactor.h"
00073 #include "node.h"
00074 
// Compile-time switch for extra diagnostic output of the factor graph code
// (FactorGraph::init() prints its build time when this is true). Declared as
// bool because it is only ever used as a truth value.
const bool fgdebug = false;
00076 
00087 class FactorGraph
00088 {
00089  public:
00090 
00094   FactorGraph(bool lifted, MLN* mln, Domain* domain,
00095               Array<Array<Predicate* >* >* queryFormulas = NULL)
00096   {
00097     lifted_ = lifted;
00098     
00099     lidToTWMsg_ = new LinkIdToTwoWayMessageMap();
00100     superClausesArr_ = new Array<Array<SuperClause*>*>();
00101     factors_ = new Array<Factor*>();
00102     nodes_ = new Array<Node*>();
00103     mln_ = mln;
00104     domain_ = domain;
00105     
00106     auxFactors_ = NULL;
00107     if (queryFormulas)
00108     {
00109       auxFactors_ = new Array<AuxFactor*>();
00110       for (int i = 0; i < queryFormulas->size(); i++)
00111         auxFactors_->append(new AuxFactor((*queryFormulas)[i]));
00112     }
00113   }
00114 
00118   ~FactorGraph()
00119   {
00120     delete lidToTWMsg_;
00121     delete superClausesArr_;
00122     delete factors_;
00123     delete nodes_;
00124     if (auxFactors_) delete auxFactors_;
00125   }
00126   
00131   void init()
00132   {
00133     Timer timer1;
00134     cout << "Building ";
00135     if (lifted_) cout << "Lifted ";
00136     cout << "Factor Graph..." << endl;
00137 
00138     if (lifted_)
00139     {
00140       createSuper();
00141       createSuperNetwork();
00142     }
00143     else
00144     {
00145       createGround();
00146       createGroundNetwork();
00147     }
00148     
00149     if (fgdebug)
00150     {
00151       cout << "[init] ";
00152       Timer::printTime(cout, timer1.time());
00153       cout << endl;
00154       timer1.reset();
00155     }
00156   }
00157 
00161   void printNetwork(ostream& out)
00162   {
00163     out << "variables:" << endl;
00164     for (int i = 0; i < nodes_->size(); i++)
00165     {
00166       (*nodes_)[i]->print(out);
00167       out << endl;
00168     }
00169     out << "factors:" << endl;
00170     for (int i = 0; i < factors_->size(); i++)
00171     {
00172       (*factors_)[i]->print(out);
00173       out << "// ";
00174       (*factors_)[i]->printWts(out);
00175       out << endl;
00176     }
00177   }
00178 
00182   LinkIdToTwoWayMessageMap* getLinkIdToTwoWayMessageMap()
00183   {
00184     return lidToTWMsg_;
00185   }
00186 
00190   const int getNumNodes()
00191   {
00192     return nodes_->size();
00193   }
00194   
00198   const int getNumFactors()
00199   {
00200     return factors_->size();
00201   }
00202   
00206   const int getNumAuxFactors()
00207   {
00208     if (auxFactors_ == NULL) return 0;
00209     return auxFactors_->size();
00210   }
00211   
00217   Node* getNode(const int& index)
00218   {
00219     return (*nodes_)[index];
00220   }
00221 
00227   Factor* getFactor(const int& index)
00228   {
00229     return (*factors_)[index];
00230   }
00231 
00237   AuxFactor* getAuxFactor(const int& index)
00238   {
00239     return (*auxFactors_)[index];
00240   }
00241 
00245   Domain* getDomain()
00246   {
00247     return domain_;
00248   }
00249 
00250  private:
00251  
00256   void createGround()
00257   {
00258     MLN* mln = mln_;
00259     Domain* domain = domain_;
00260     Clause* mlnClause;
00261 
00262       //clause with all variables in it
00263     Clause *varClause;
00264 
00265     Array<Clause *> allClauses;
00266     Array<int> *mlnClauseTermIds;
00267   
00268     ClauseToSuperClauseMap *clauseToSuperClause;
00269     ClauseToSuperClauseMap::iterator clauseItr;
00270     int numClauses = mln->getNumClauses();
00271   
00272     PredicateTermToVariable *ptermToVar = NULL;
00273   
00274     double gndtime;
00275     Util::elapsed_seconds();
00276   
00277     clauseToSuperClause = new ClauseToSuperClauseMap();
00278     for (int i = 0; i < numClauses; i++)
00279     {
00280         //remove the unknown predicates
00281       mlnClause = (Clause*) mln->getClause(i);
00282       varClause = new Clause(*mlnClause);
00283       mlnClauseTermIds = varClause->updateToVarClause();
00284      
00285       mlnClause->getConstantTuples(domain, domain->getDB(), mlnClauseTermIds, 
00286                                    varClause, ptermToVar, clauseToSuperClause,
00287                                    false);
00288 
00289         //can delete the var Clause now
00290       delete varClause;
00291     }
00292 
00293     clauseToSuperClause = mergeSuperClauses(clauseToSuperClause);
00294     addSuperClauses(clauseToSuperClause);
00295   
00296     gndtime = Util::elapsed_seconds();
00297 
00298       //create the super preds  
00299     int totalGndClauseCnt = 0;
00300     for (int i = 0; i < superClausesArr_->size(); i++)
00301     {
00302       SuperClause *superClause = (*(*superClausesArr_)[i])[0];
00303       int gndCnt = superClause->getNumTuplesIncludingImplicit();
00304       totalGndClauseCnt += gndCnt;
00305     }
00306     cout<<"Total Number of Ground Clauses = "<<totalGndClauseCnt<<endl;
00307   }
00308  
00309  
00313   void createGroundNetwork()
00314   {
00315     Domain* domain = domain_;
00316     int domainPredCnt = domain->getNumPredicates();
00317     Array<IntArrayToIntMap*>* pconstantsToNodeIndexArr =
00318       new Array<IntArrayToIntMap*>();
00319     pconstantsToNodeIndexArr->growToSize(domainPredCnt);
00320      
00321       // Initialize the mappings from pred constants to the nodeIndices
00322     for(int i = 0; i < domainPredCnt; i++)
00323     {
00324       (*pconstantsToNodeIndexArr)[i] = new IntArrayToIntMap();
00325     }
00326      
00327     IntArrayToIntMap *pconstantsToNodeIndex;
00328     IntArrayToIntMap::iterator itr;
00329      
00330     Array<SuperClause*>* superClauses;
00331     SuperClause* superClause;
00332 
00333     Clause *clause;
00334     Predicate *pred;
00335     Array<int>* constants;
00336     Array<int>* pconstants;
00337     
00338     Factor* factor;
00339     Node* node;
00340 
00341     for (int arrIndex = 0; arrIndex < superClausesArr_->size(); arrIndex++)
00342     {
00343       superClauses = (*superClausesArr_)[arrIndex];
00344       for (int scindex = 0; scindex < superClauses->size(); scindex++)
00345       {
00346         superClause = (*superClauses)[scindex];
00347         clause = superClause->getClause();
00348         int numTuples = superClause->getNumTuples();
00349         for (int tindex = 0; tindex < numTuples; tindex++)
00350         {
00351             // Number of times this tuple appears in the superclause
00352             // (because of counting elimination)
00353           double tcnt = superClause->getTupleCount(tindex);
00354           constants = superClause->getConstantTuple(tindex);
00355 
00356             // MS: Set to real weight, not 1
00357           //clause->setWt(superClause->getOutputWt());
00358           factor = new Factor(clause, NULL, constants, domain,
00359                                 superClause->getOutputWt());
00360           factors_->append(factor);
00361           for (int pindex = 0; pindex < clause->getNumPredicates(); pindex++)
00362           {
00363             pred = clause->getPredicate(pindex);
00364             pconstants = pred->getPredicateConstants(constants);
00365             
00366             int predId = domain->getPredicateId(pred->getName());
00367             pconstantsToNodeIndex = (*pconstantsToNodeIndexArr)[predId];
00368             itr = pconstantsToNodeIndex->find(pconstants);
00369                  
00370             if (itr == pconstantsToNodeIndex->end())
00371             {
00372               int nodeIndex = nodes_->size();
00373               (*pconstantsToNodeIndex)[pconstants] = nodeIndex;
00374 
00375               node = new Node(predId, NULL, pconstants, domain);
00376               nodes_->append(node);
00377             }
00378             else
00379             {
00380               delete pconstants;
00381               int nodeIndex = itr->second;
00382               node = (*nodes_)[nodeIndex];
00383             }
00384 
00385               // Now, add the links to the node/factor
00386             int reverseNodeIndex = factor->getNumLinks();
00387               // Index where this factor would be stored in the list of nodes
00388             int reverseFactorIndex = node->getNumLinks();
00389             Link * link = new Link(node, factor, reverseNodeIndex,
00390                                        reverseFactorIndex, pindex, tcnt);
00391             node->addLink(link, NULL);
00392             factor->addLink(link, NULL);
00393           }
00394         }
00395       }
00396     }
00397 
00398       // If query formulas present, make aux. links to the nodes
00399     if (auxFactors_)
00400     {
00401       for (int i = 0; i < auxFactors_->size(); i++)
00402       {
00403         AuxFactor* auxFactor = (*auxFactors_)[i];
00404         Array<Predicate* >* formula = auxFactor->getFormula();
00405         for (int pindex = 0; pindex < formula->size(); pindex++)
00406         {
00407           double z = 1.0;
00408           pred = (*formula)[pindex];
00409           pconstants = pred->getPredicateConstants();
00410           int predId = domain->getPredicateId(pred->getName());
00411 
00412           pconstantsToNodeIndex = (*pconstantsToNodeIndexArr)[predId];
00413           itr = pconstantsToNodeIndex->find(pconstants);
00414 
00415           if (itr == pconstantsToNodeIndex->end())
00416           {
00417             cout << "ERROR: couldn't find predicate ";
00418             pred->printWithStrVar(cout, domain);
00419             cout << " from query formula in factor graph" << endl;
00420             exit(-1);
00421           }
00422           else
00423           {
00424             delete pconstants;
00425             int nodeIndex = itr->second;
00426             node = (*nodes_)[nodeIndex];
00427           }
00428 
00429             // Now, add the links to the node/factor
00430           int reverseNodeIndex = auxFactor->getNumLinks();
00431             // Index where this factor would be stored in the list of nodes
00432           int reverseFactorIndex = node->getNumLinks();
00433           Link* link = new Link(node, auxFactor, reverseNodeIndex,
00434                                     reverseFactorIndex, pindex, z);
00435           node->addAuxLink(link);
00436           auxFactor->addLink(link, NULL);
00437         }
00438       }
00439     }
00440 
00441       // clean up
00442     for (int predId = 0; predId < domainPredCnt; predId++)
00443     {
00444       pconstantsToNodeIndex = (*pconstantsToNodeIndexArr)[predId];
00445       pconstantsToNodeIndex->clear();
00446       delete pconstantsToNodeIndex;
00447     }
00448     delete pconstantsToNodeIndexArr;
00449     cout << "Created Ground Network" << endl;
00450   }
00451 
00456   void createSuper()
00457   {
00458     MLN* mln = mln_;
00459     Domain* domain = domain_;
00460     Clause* mlnClause;
00461 
00462       //clause with all variables in it
00463     Clause *varClause;
00464     Array<Clause *> allClauses;
00465     Array<int> *mlnClauseTermIds;
00466        
00467     ClauseToSuperClauseMap *clauseToSuperClause;
00468     ClauseToSuperClauseMap::iterator clauseItr;
00469 
00470     int numClauses = mln->getNumClauses();
00471        
00472     PredicateTermToVariable *ptermToVar = NULL;
00473     double gndtime, setuptime;
00474     Util::elapsed_seconds();
00475 
00476     if (implicitRep_)
00477     {
00478       ptermToVar = getPredicateTermToVariableMap(mln, domain);
00479       getIndexedConstants(ptermToVar, mln, domain);
00480     }
00481       
00482     clauseToSuperClause = new ClauseToSuperClauseMap();
00483     for (int i = 0; i < numClauses; i++)
00484     {
00485         //remove the unknown predicates
00486       mlnClause = (Clause*) mln->getClause(i);
00487       varClause = new Clause(*mlnClause);
00488       mlnClauseTermIds = varClause->updateToVarClause();
00489           
00490       mlnClause->getConstantTuples(domain, domain->getDB(), mlnClauseTermIds, 
00491                                    varClause, ptermToVar, clauseToSuperClause,
00492                                    implicitRep_);
00493           
00494         //can delete the var Clause now
00495       delete varClause;
00496     }
00497       
00498     clauseToSuperClause = mergeSuperClauses(clauseToSuperClause);
00499     addSuperClauses(clauseToSuperClause);
00500 
00501     gndtime = Util::elapsed_seconds();
00502 
00503     cout << endl << endl;
00504     cout << "*****************************************************************"
00505          << endl << endl;
00506     cout << "Now, starting the iterations of creating supernodes/superfeatures"
00507          << endl;
00508       //now create the super preds corresponding to the current set of
00509       //superclauses
00510 
00511     SuperClause *superClause;
00512     int totalTupleCnt = 0;
00513     int totalGndTupleCnt = 0;
00514 
00515     cout << "Counts in the beginning:" << endl;
00516     for (int i = 0; i < superClausesArr_->size(); i++)
00517     {
00518       superClause = (*(*superClausesArr_)[i])[0];
00519          
00520       int cnt = superClause->getNumTuples();
00521       int gndCnt = superClause->getNumTuplesIncludingImplicit();
00522 
00523       totalTupleCnt += cnt;
00524       totalGndTupleCnt += gndCnt;
00525     }
00526     cout << "Total Number of Ground Tuples = " << totalGndTupleCnt << endl;
00527     cout << "Total Number of Tuples Created = " << totalTupleCnt << endl;
00528 
00529     /*************************************************************************
00530      * Start the Iterations now */
00531     /*************************************************************************/
00532        
00533     int newSuperClauseCnt = getNumArrayArrayElements(*superClausesArr_);
00534     int superClauseCnt = newSuperClauseCnt;
00535     int itr = 1;
00536     cout << "********************************************************"
00537          << endl << endl;
00538     setuptime = 0;
00539 
00540       // For creation of the Network
00541     Array<Factor *> * factors = new Array<Factor *>();
00542     Array<Node *> * nodes = new Array<Node *>();
00543 
00544     while (newSuperClauseCnt != superClauseCnt || itr <= 2)
00545     {
00546       superClauseCnt = newSuperClauseCnt;
00547       cout << "***************************************************************"
00548            << endl;
00549       cout << "Iteration: " << itr << endl;
00550            
00551         //for iteration 1, superclauses have already been created
00552       if (itr > 1)
00553       {
00554         cout << "Creating Super Clauses.. " << endl;
00555         createSuperClauses(superClausesArr_, domain);
00556         newSuperClauseCnt = getNumArrayArrayElements(*superClausesArr_);
00557       }
00558        
00559       cout << "Creating New Super Preds.. " << endl;
00560       createSuperPreds(superClausesArr_, domain);
00561        
00562       cout << "Number of superclauses after this iteration is = "
00563            << newSuperClauseCnt << endl;
00564       itr++;
00565     }
00566 
00567     superClauseCnt = getNumArrayArrayElements(*superClausesArr_);
00568     cout << "***************************************************************"
00569          << endl;
00570     cout << "Total Number of Super Clauses = " << superClauseCnt << endl;
00571   
00572     int predCnt = domain->getNumPredicates();
00573     const PredicateTemplate *ptemplate;
00574     for (int predId = 0; predId < predCnt; predId++)
00575     {
00576       ptemplate = domain->getPredicateTemplate(predId);
00577       if (ptemplate->isEqualPredicateTemplate())
00578         continue;
00579       int cnt = SuperPred::getSuperPredCount(predId);
00580       if (cnt > 0)
00581       {
00582         cout<<"SuperPred count for pred: ";
00583         ptemplate->print(cout);
00584         cout << " = " << cnt << endl;
00585       }
00586     }
00587 
00588     for (int i = 0; i < nodes->size(); i++) delete (*nodes)[i];
00589     for (int i = 0; i < factors->size(); i++) delete (*factors)[i];
00590     nodes->clear();
00591     factors->clear();
00592   }
00593 
00594 
00598   void createSuperNetwork()
00599   {
00600     Domain* domain = domain_;
00601     Array<SuperPred*> * superPreds;
00602     Array<SuperClause*> *superClauses;
00603   
00604     SuperClause *superClause;
00605     SuperPred *superPred;
00606 
00607     Factor *factor;
00608     Node *node;
00609     Clause *clause;
00610     Array<int>* constants = NULL;
00611 
00612       //create the factor (superclause) nodes
00613     for (int arrIndex = 0; arrIndex < superClausesArr_->size(); arrIndex++)
00614     {
00615       superClauses = (*superClausesArr_)[arrIndex];
00616       for (int scindex = 0; scindex < superClauses->size(); scindex++)
00617       {
00618         superClause = (*superClauses)[scindex];
00619         clause = superClause->getClause();
00620           // MS: Set to real weight, not 1
00621         //clause->setWt(superClause->getOutputWt());
00622         factor = new Factor(clause, superClause, constants, domain,
00623                               superClause->getOutputWt());
00624         factors_->append(factor);
00625       }
00626     }
00627 
00628       //create the variable (superpreds) nodes
00629     int predCnt = domain->getNumPredicates();
00630     for (int predId = 0;predId < predCnt; predId++)
00631     {
00632       superPreds = SuperPred::getSuperPreds(predId);
00633       for (int spindex = 0; spindex < superPreds->size(); spindex++)
00634       {
00635         superPred = (*superPreds)[spindex];
00636         node = new Node(predId, superPred, constants, domain);
00637         node->addFactors(factors_, getLinkIdToTwoWayMessageMap());
00638         nodes_->append(node);
00639       }
00640     }
00641   }
00642   
00646   void updateLinkIdToTwoWayMessageMap()
00647   {
00648     Node *node;
00649     Link *link;
00650     Factor *factor;
00651     double nodeToFactorMsgs[2];
00652     double factorToNodeMsgs[2];
00653 
00654     LinkId *lid;
00655     TwoWayMessage *tmsg;
00656     LinkIdToTwoWayMessageMap::iterator lidToTMsgItr;
00657      
00658       // Delete the old values
00659     Array<LinkId*> keysArr;
00660     for (lidToTMsgItr = lidToTWMsg_->begin();
00661          lidToTMsgItr != lidToTWMsg_->end();
00662          lidToTMsgItr++)
00663     {
00664       keysArr.append(lidToTMsgItr->first);
00665       tmsg = lidToTMsgItr->second;
00666       delete tmsg;
00667     }
00668                        
00669     for (int i = 0; i < keysArr.size(); i++)
00670     {
00671       delete keysArr[i];
00672     }
00673     lidToTWMsg_->clear();
00674 
00675       // Now populate
00676     for (int i = 0; i < nodes_->size(); i++)
00677     {
00678       node = (*nodes_)[i];
00679       for (int j = 0; j < node->getNumLinks(); j++)
00680       {
00681         link = node->getLink(j);
00682         factor = link->getFactor();
00683               
00684         int predId = node->getPredId();
00685         int superPredId = node->getSuperPredId();
00686         int superClauseId = factor->getSuperClauseId();
00687         int predIndex = link->getPredIndex(); 
00688               
00689         lid = new LinkId(predId, superPredId, superClauseId, predIndex);
00690 
00691         int reverseFactorIndex = link->getReverseFactorIndex();
00692         node->getMessage(reverseFactorIndex, nodeToFactorMsgs);
00693 
00694         int reverseNodeIndex = link->getReverseNodeIndex();
00695         factor->getMessage(reverseNodeIndex, factorToNodeMsgs);
00696  
00697         tmsg = new TwoWayMessage(nodeToFactorMsgs,factorToNodeMsgs);
00698         (*lidToTWMsg_)[lid] = tmsg;
00699       }
00700     }
00701   }
00702   
00706   void addSuperClauses(ClauseToSuperClauseMap* const & clauseToSuperClause)
00707   {
00708     ClauseToSuperClauseMap::iterator clauseItr;
00709     Array<SuperClause *> * superClauses;
00710     SuperClause *superClause;
00711      
00712     for(clauseItr = clauseToSuperClause->begin();
00713         clauseItr != clauseToSuperClause->end(); 
00714         clauseItr++)
00715     {
00716       superClauses = new Array<SuperClause *>();
00717       superClausesArr_->append(superClauses);
00718       superClause = clauseItr->second;
00719       superClauses->append(superClause);
00720     }
00721   }
00722   
00723   
00727   ClauseToSuperClauseMap*
00728   mergeSuperClauses(ClauseToSuperClauseMap* const & clauseToSuperClause)
00729   {
00730     Domain* domain = domain_;
00731     ClauseToSuperClauseMap *mergedClauseToSuperClause =
00732       new ClauseToSuperClauseMap();
00733     SuperClause *superClause, *mergedSuperClause;
00734     Clause *keyClause;
00735     ClauseToSuperClauseMap::iterator itr, mergedItr;
00736     Array<int> * constants;
00737     double tcnt;
00738     for (itr = clauseToSuperClause->begin();
00739          itr != clauseToSuperClause->end();
00740          itr++)
00741     {
00742       superClause = itr->second;
00743       keyClause = superClause->getClause();
00744       mergedItr = mergedClauseToSuperClause->find(keyClause);
00745       if (mergedItr != mergedClauseToSuperClause->end())
00746       {
00747         mergedSuperClause = mergedItr->second;
00748         for (int tindex = 0; tindex < superClause->getNumTuples(); tindex++)
00749         {
00750           constants = superClause->getConstantTuple(tindex);
00751           tcnt = superClause->getTupleCount(tindex);
00752           mergedSuperClause->addNewConstantsAndIncrementCount(constants, tcnt);
00753           delete constants;
00754         }
00755         delete superClause;
00756       }
00757       else
00758       {
00759         (*mergedClauseToSuperClause)[keyClause] = superClause;
00760         keyClause->print(cout, domain);
00761         cout << endl;
00762       }
00763     }
00764     return mergedClauseToSuperClause;
00765   }
00766   
00771   PredicateTermToVariable* getPredicateTermToVariableMap(MLN * const & mln,
00772                                                          Domain* const & domain)
00773   {
00774     Clause *clause;
00775     Predicate *pred;
00776     const Term* term;
00777     Array<Variable *> *eqVars = new Array<Variable *>();
00778     int eqClassId = 0;
00779     Variable *var;
00780     const PredicateTemplate *ptemplate;
00781     const Array<int>* constants;
00782 
00783     for (int clauseno = 0; clauseno < mln->getNumClauses(); clauseno++)
00784     {
00785       clause = (Clause *)mln->getClause(clauseno);
00786       for (int predno = 0; predno < clause->getNumPredicates(); predno++)
00787       {
00788         pred = clause->getPredicate(predno);
00789         ptemplate = pred->getTemplate();
00790         for (int termno = 0; termno < pred->getNumTerms(); termno++)
00791         {
00792           term = pred->getTerm(termno);
00793           int varId = term->getId();
00794           int varTypeId = ptemplate->getTermTypeAsInt(termno);
00795           constants = domain->getConstantsByType(varTypeId); 
00796           var = new Variable(clause,varId,pred,termno,eqClassId,constants);
00797           eqVars->append(var);
00798           eqClassId++;
00799         }
00800       }
00801     }
00802     
00803     Variable  *var1, *var2;
00804     for (int i = 0; i < eqVars->size(); i++)
00805     {
00806       var1 = (*eqVars)[i];
00807       for (int j = i + 1; j < eqVars->size(); j++)
00808       {
00809         var2 = (*eqVars)[j];
00810         if (var1->same(var2))
00811         {
00812           var1->merge(var2); 
00813         }
00814       }
00815     }
00816 
00817       //now populate the map
00818     PredicateTermToVariable *ptermToVar = new PredicateTermToVariable();
00819     PredicateTermToVariable::iterator itr;
00820     PredicateTerm *pterm;
00821     Variable *tiedVar;
00822     int uniqueCnt = 0;
00823     for (int i = 0; i < eqVars->size(); i++)
00824     {
00825       var = (*eqVars)[i];
00826       if (var->isRepresentative())
00827       {
00828         uniqueCnt++;
00829         for (int j = 0; j < var->getNumTiedVariables(); j++)
00830         {
00831           tiedVar = var->getTiedVariable(j);
00832           int predId = tiedVar->getPredId();
00833           int termno = tiedVar->getTermno();
00834           pterm = new PredicateTerm(predId,termno);
00835           itr = ptermToVar->find(pterm);
00836           if (itr == ptermToVar->end())
00837           {
00838             (*ptermToVar)[pterm] = var;
00839           }
00840           else
00841           {
00842             delete pterm;
00843           }
00844         }
00845       }
00846     }
00847 
00848     cout << "size of PtermToVarMap is " << ptermToVar->size() << endl;
00849     cout << "count of Variable Eq Classes (Unique) is = " << uniqueCnt << endl;
00850     return ptermToVar;
00851   }
00852 
00853 
00854   void getIndexedConstants(PredicateTermToVariable * const & ptermToVar, 
00855                            MLN * const & mln, 
00856                            Domain * const & domain)
00857   {
00858     Predicate *pred, *gndPred;
00859     IntHashArray seenPredIds;
00860     const Clause *clause;
00861     const Term *term;
00862      
00863     Array<Predicate *> * indexedGndings;
00864     PredicateTerm *pterm;
00865     Database * db;
00866      
00867     int predId, termId, constantId;
00868     bool ignoreActivePreds = true;
00869 
00870     PredicateTermToVariable::iterator itr;
00871     Variable * var;
00872 
00873     indexedGndings = new Array<Predicate *>();
00874     db = domain->getDB();
00875     cout << "size of PtermToVarMap is " << ptermToVar->size() << endl;
00876      
00877     Clause *varClause;
00878     for (int clauseno = 0; clauseno < mln->getNumClauses(); clauseno++)
00879     {
00880       clause = mln->getClause(clauseno);    
00881       varClause = new Clause(*clause);
00882       varClause->updateToVarClause();
00883 
00884         //to make sure that we do not use clause
00885       clause = NULL;
00886 
00887       for (int predno = 0; predno < varClause->getNumPredicates(); predno++)
00888       {
00889         pred = varClause->getPredicate(predno);
00890         predId = pred->getId();
00891 
00892         if (seenPredIds.append(predId) < 0)
00893           continue;
00894         indexedGndings->clear();
00895           //Note: we assume that every predicate is indexable
00896         if(db->isClosedWorld(predId))
00897         {
00898             //precidate is closed world - rettrieve only true groundings
00899           db->getIndexedGndings(indexedGndings,pred,ignoreActivePreds,true);
00900         }
00901         else
00902         {
00903             //predicate is open world - retrieve both true and false groundings  
00904           db->getIndexedGndings(indexedGndings,pred,ignoreActivePreds,true);
00905           db->getIndexedGndings(indexedGndings,pred,ignoreActivePreds,false);
00906         }
00907         
00908         for (int gndno = 0; gndno < indexedGndings->size(); gndno++)
00909         {
00910           gndPred = (*indexedGndings)[gndno];
00911 
00912           for (int termno = 0; termno < gndPred->getNumTerms(); termno++)
00913           {
00914             pterm = new PredicateTerm(predId, termno);
00915             itr = ptermToVar->find(pterm);
00916             assert(itr != ptermToVar->end());
00917             var = itr->second;
00918             term = gndPred->getTerm(termno);
00919             constantId = term->getId();
00920             var->removeImplicit(constantId);
00921             delete pterm;
00922           }
00923           delete (*indexedGndings)[gndno];
00924         }
00925       }
00926       delete varClause;
00927     }
00928    
00929       //now explicitly handle the constants appearing in the clause
00930     for (int clauseno = 0; clauseno < mln->getNumClauses(); clauseno++)
00931     {
00932       clause = mln->getClause(clauseno);    
00933       for (int predno = 0; predno < clause->getNumPredicates(); predno++)
00934       {
00935         pred = clause->getPredicate(predno);
00936         predId = pred->getId();
00937         for (int termno = 0; termno < pred->getNumTerms(); termno++)
00938         {
00939           term = pred->getTerm(termno);
00940           termId = term->getId();
00941             // if it is a variable, nothing to do
00942           if (termId < 0) continue;
00943             // else, this constant also should be added the list of
00944             // indexed constants
00945           pterm = new PredicateTerm(predId, termno);
00946           itr = ptermToVar->find(pterm);
00947           assert(itr != ptermToVar->end());
00948           var = itr->second;
00949           var->removeImplicit(termId);
00950           delete pterm;
00951         }
00952       }
00953     }
00954 
00955     IntHashArray *seenEqClassIds = new IntHashArray();
00956     cout << "Implicit Set of constants are: " << endl;
00957     for (itr = ptermToVar->begin(); itr != ptermToVar->end(); itr++)
00958     {
00959       pterm = itr->first;
00960       var = itr->second;
00961       int eqClassId = var->getEqClassId();
00962       if (seenEqClassIds->find(eqClassId) >= 0)
00963         continue;
00964       seenEqClassIds->append(eqClassId);
00965       cout << "Implicit Constants for Eq class " << eqClassId << endl;
00966       cout << "Count =  " << var->getNumImplicitConstants() << " => " << endl;
00967       var->printImplicitConstants(cout, domain);
00968       cout << endl << endl << endl;
00969     }
00970     delete seenEqClassIds;
00971   }
00972   
00973  private:
00974     // Indicates if lifted inference will be run
00975   bool lifted_;
00976     // Indicates if implicit representation is to be used
00977   bool implicitRep_;
00978 
00979   LinkIdToTwoWayMessageMap* lidToTWMsg_;
00980 
00981   Array<Array<SuperClause*>*>* superClausesArr_;
00982 
00983     // Factors in the graph  
00984   Array<Factor*>* factors_;
00985     // Nodes in the graph
00986   Array<Node*>* nodes_;
00987 
00988     // MLN from which the factor graph is built
00989   MLN* mln_;
00990     // Domain containing the constants from which the factor graph is built
00991   Domain* domain_;
00992   
00993     // Stores auxiliary factors used for query formulas
00994   Array<AuxFactor*>* auxFactors_;
00995 };
00996 
00997 #endif /*FACTORGRAPH_H_*/

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by  doxygen 1.5.1