bp.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef BP_H_
00068 #define BP_H_
00069 
#include <cmath>

#include "inference.h"
#include "bpparams.h"
#include "twowaymessage.h"
#include "superclause.h"
#include "auxfactor.h"
#include "node.h"
#include "factorgraph.h"
00077 
00078 const int bpdebug = true;
00079 
00084 class BP : public Inference
00085 {
00086  public:
00087 
00092   BP(FactorGraph* factorGraph, BPParams* bpParams,
00093      Array<Array<Predicate* >* >* queryFormulas = NULL)
00094     : Inference(NULL, -1, false, queryFormulas)
00095   {
00096     factorGraph_ = factorGraph;
00097     maxSteps_ = bpParams->maxSteps;
00098     maxSeconds_ = bpParams->maxSeconds;
00099     convergenceThresh_ = bpParams->convergenceThresh;
00100     convergeRequiredItrCnt_ = bpParams->convergeRequiredItrCnt;
00101     outputNetwork_ = bpParams->outputNetwork;
00102   }
00103 
00107   ~BP()
00108   {
00109   }
00110   
00114   void init()
00115   {
00116     Timer timer1;
00117     cout << "Initializing ";
00118     cout << "Belief Propagation..." << endl;
00119 
00120     factorGraph_->init();
00121     if (bpdebug)
00122     {
00123       cout << "[init] ";
00124       Timer::printTime(cout, timer1.time());
00125       cout << endl;
00126       timer1.reset();
00127     }
00128   }
00129 
00133   void infer()
00134   {
00135     Timer timer1;
00136 
00137     double oldProbs[2];
00138     double newProbs[2];
00139     double diff;
00140     double maxDiff;
00141     int maxDiffNodeIndex;
00142     int convergeItrCnt = 0;
00143     bool converged = false;
00144     int numFactors = factorGraph_->getNumFactors();
00145     int numNodes = factorGraph_->getNumNodes();
00146 
00147     cout << "factorcnt = " << numFactors
00148          << ", nodecnt = " << numNodes << endl;
00149 
00150     if (bpdebug)
00151     {
00152       cout << "factors:" << endl;
00153       for (int i = 0; i < numFactors; i++)
00154       {
00155         factorGraph_->getFactor(i)->print(cout);
00156         cout << endl;
00157       }
00158       cout << "nodes:" << endl;
00159       for (int i = 0; i < numNodes; i++)
00160       {
00161         factorGraph_->getNode(i)->print(cout);
00162         cout << endl;
00163       }
00164     }
00165     
00166       // Pass around (send) the messages
00167     int itr;
00168     
00169       // Move to next step to transfer the message in the nextMsgsArr in the
00170       // beginning to the the msgsArr 
00171     for (itr = 1; itr <= maxSteps_; itr++)
00172     {
00173       if (bpdebug)
00174       {
00175         cout<<"*************************************"<<endl;
00176         cout<<"Performing Iteration "<<itr<<" of BP"<<endl;
00177         cout<<"*************************************"<<endl;
00178       }
00179 
00180       for (int i = 0; i < numFactors; i++)
00181       {
00182         if (bpdebug)
00183         {
00184           cout << "Sending messages for Factor: ";
00185           factorGraph_->getFactor(i)->print(cout); cout << endl;
00186         }
00187         factorGraph_->getFactor(i)->sendMessage();
00188       }
00189       
00190       for (int i = 0; i < numNodes; i++)
00191       {
00192         if (bpdebug)
00193         {
00194           cout << "Sending messages for Node: ";
00195           factorGraph_->getNode(i)->print(cout); cout << endl;
00196         }
00197         factorGraph_->getNode(i)->sendMessage();
00198       }
00199 
00200       for (int i = 0; i < numFactors; i++)
00201       {
00202         factorGraph_->getFactor(i)->moveToNextStep();
00203         if (bpdebug)
00204         {
00205           cout << "BP-Factor Iteration " << itr << " => ";
00206           factorGraph_->getFactor(i)->print(cout); cout << endl;
00207         }
00208       }
00209           
00210       maxDiff = -1;
00211       maxDiffNodeIndex = -1;
00212       for (int i = 0; i < numNodes; i++)
00213       {
00214         if (bpdebug)
00215         {
00216           cout<<"************************************"<<endl;
00217           cout<<"Node "<<i<<":"<<endl;
00218           cout<<"************************************"<<endl;
00219           cout<<"Getting Old Probabilities =>"<<endl; 
00220           cout<<endl;
00221           cout<<"Moving to next step "<<endl;
00222           cout<<endl;
00223           cout<<"Getting New Probabilities =>"<<endl; 
00224         }
00225 
00226         factorGraph_->getNode(i)->getProbs(oldProbs);
00227         factorGraph_->getNode(i)->moveToNextStep();
00228         factorGraph_->getNode(i)->getProbs(newProbs);
00229 
00230         diff = abs(newProbs[1] - oldProbs[1]);
00231 
00232         if (bpdebug)
00233         {
00234           cout << endl << endl << "Final Probs : " << endl;
00235           cout << "Node " << i << ": probs[" << 0 << "] = " << newProbs[0]
00236                << ", probs[" << 1 << "] = " << newProbs[1] << endl;
00237           cout << "BP-Node Iteration " << itr << ": " << newProbs[0]
00238                << "  probs[" << 1 << "] = " << newProbs[1] << endl;
00239           cout << " : => ";
00240           factorGraph_->getNode(i)->print(cout);
00241           cout << endl;
00242         }
00243         
00244         if (maxDiff < diff)
00245         {
00246           maxDiff = diff;
00247           maxDiffNodeIndex = i;
00248         }
00249       }
00250           
00251       cout << "At Iteration " << itr << ": MaxDiff = " << maxDiff << endl;
00252       cout << endl;
00253            
00254         //check if BP has converged
00255       if (maxDiff < convergenceThresh_)
00256         convergeItrCnt++;
00257       else
00258         convergeItrCnt = 0;
00259 
00260         // Check if for N continuous iterations, maxDiff has been below the
00261         // threshold
00262       if (convergeItrCnt >= convergeRequiredItrCnt_)
00263       {
00264         converged = true;
00265         break;
00266       }
00267     }
00268 
00269     if (converged)
00270     {
00271       cout << "Converged in " << itr << " Iterations " << endl;
00272     }
00273     else
00274     {
00275       cout << "Did not converge in " << maxSteps_ << " (max allowed) Iterations"
00276            << endl;
00277     }
00278     
00279     if (queryFormulas_)
00280     {
00281       cout << "Computing probabilities of query formulas ..." << endl;
00282       for (int i = 0; i < numNodes; i++)
00283       {
00284         if (bpdebug)
00285         {
00286           cout << "Sending auxiliary messages for Node: ";
00287           factorGraph_->getNode(i)->print(cout); cout << endl;
00288         }
00289         factorGraph_->getNode(i)->sendAuxMessage();
00290           // Now, messages have been sent to the aux. factors
00291       }
00292       for (int j = 0; j < qfProbs_->size(); j++)
00293       {
00294         (*qfProbs_)[j] = factorGraph_->getAuxFactor(j)->getProb();
00295       }
00296     }
00297   }
00298 
00302   void printNetwork(ostream& out)
00303   {
00304     factorGraph_->printNetwork(out);
00305   }
00306 
00310   void printProbabilities(ostream& out)
00311   {
00312     double probs[2];
00313     Array<int>* constants;
00314     Predicate* pred;
00315     int predId;
00316     Node* node;
00317     double exp;
00318     Domain* domain = factorGraph_->getDomain();
00319     for (int i = 0; i < factorGraph_->getNumNodes(); i++)
00320     { 
00321       node = factorGraph_->getNode(i);
00322       predId = node->getPredId();
00323       node->getProbs(probs);
00324       exp = node->getExp();
00325       SuperPred * superPred = node->getSuperPred();
00326 
00327       if (superPred)
00328       {        
00329         for (int index = 0; index < superPred->getNumTuples(); index++)
00330         {
00331           constants = superPred->getConstantTuple(index);
00332           pred = domain->getPredicate(constants, predId);
00333           pred->printWithStrVar(out, domain);
00334           out << " " << probs[1] << endl;
00335           //out<<" "<<exp<<endl;
00336         }
00337       }
00338       else
00339       {
00340         constants = node->getConstants(); 
00341         assert(constants != NULL);
00342         pred = domain->getPredicate(constants, predId);
00343         pred->printWithStrVar(out, domain);
00344         out << " " << probs[1] << endl;
00345         //out<<" "<<exp<<endl;
00346       }
00347     }
00348   }
00349 
00364   void getChangedPreds(vector<string>& changedPreds, vector<float>& probs,
00365                        vector<float>& oldProbs, const float& probDelta)
00366   {
00367   }
00368 
00375   double getProbability(GroundPredicate* const& gndPred)
00376   {
00377     double probs[2];
00378     Array<int>* constants;
00379     Predicate* pred;
00380     unsigned int predId;
00381     Node* node;
00382     Domain* domain = factorGraph_->getDomain();
00383     bool found = false;
00384     for (int i = 0; i < factorGraph_->getNumNodes(); i++)
00385     { 
00386       node = factorGraph_->getNode(i);
00387       predId = node->getPredId();
00388       if (predId != gndPred->getId()) continue;
00389       node->getProbs(probs);         
00390       SuperPred * superPred = node->getSuperPred();
00391 
00392       if (superPred)
00393       {        
00394         for (int index = 0; index < superPred->getNumTuples(); index++)
00395         {
00396           constants = superPred->getConstantTuple(index);
00397           pred = domain->getPredicate(constants, predId);
00398           if (!pred->same(gndPred))
00399           {
00400             delete pred;
00401             continue;
00402           }
00403           delete pred;
00404           found = true;
00405           return probs[1];
00406         }
00407       }
00408       else
00409       {
00410         constants = node->getConstants(); 
00411         assert(constants != NULL);
00412         pred = domain->getPredicate(constants, predId);
00413         if (!pred->same(gndPred))
00414         {
00415           delete pred;
00416           continue;
00417         }
00418         delete pred;
00419         found = true;
00420         return probs[1];
00421       }
00422     }
00423     return 0.5;
00424   }
00425 
00432   double getProbabilityH(GroundPredicate* const& gndPred)
00433   {
00434     return 0.0;
00435   }
00436 
00441   void printTruePreds(ostream& out)
00442   {
00443   }
00444   
00449   void printTruePredsH(ostream& out)
00450   {
00451   }
00452 
00453  private:
00454     // Network on which BP is run
00455   FactorGraph* factorGraph_;
00456     // Max. no. of BP iterations to perform
00457   int maxSteps_;
00458     // Max. no. of seconds BP should run
00459   int maxSeconds_;
00460     // Maximum difference between probabilities must be less than this
00461     // in order to converge
00462   double convergenceThresh_;
00463     // Convergence must last this number of iterations
00464   int convergeRequiredItrCnt_;
00465     // No inference is run, rather the factor graph is built
00466   bool outputNetwork_;
00467 };
00468 
00469 #endif /*BP_H_*/

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by  doxygen 1.5.1