bpfactor.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-08] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
#include "bpfactor.h"

#include <vector>
00067  
00068 /*****************************************************************************/
00069 // Functions for class BPFactor
00070 /*****************************************************************************/
00071 
00072 //the contribution of the factor itself
00073 void BPFactor::initFactorMesssages()
00074 {
00075   Predicate *pred;
00076   int numPreds = clause_->getNumPredicates();
00077   int stateCnt = (int)pow(2.0, numPreds);
00078   factorMsgs_ = new double[stateCnt];
00079   bool isSatisfied;
00080   for (int state = 0; state < stateCnt; state++)
00081   {
00082     isSatisfied = false;
00083     for (int predno = 0; predno < numPreds; predno++)
00084     {
00085       pred = clause_->getPredicate(predno);
00086       bool predBit = state & (1<<predno);
00087       if (pred->getSense() == predBit)
00088       {
00089         isSatisfied = true;
00090         break;
00091       }
00092     }
00093 
00094     if (isSatisfied)
00095     {
00096         // Always 1
00097       factorMsgs_[state] = clause_->getWt();
00098     }
00099     else
00100     {
00101       factorMsgs_[state] = 0;
00102     }
00103   }
00104 }
00105 
00106   //find the outgoing message for the given inpPredIndex
00107 double* BPFactor::multiplyMessagesAndSumOut(int inpPredIndex)
00108 {
00109   int numPreds = clause_->getNumPredicates();
00110   int stateCnt = (int)pow(2.0, numPreds);
00111   double * prodMsgs = new double[stateCnt];
00112              
00113   BPNode *node;
00114   double *gndNodeCnts = new double[numPreds];
00115 
00116   // This is done to handle the case when BP is run
00117   // on supernodes/superfeatures which have not yet
00118   // reached an equilibrium state
00119   for (int predno = 0; predno < numPreds; predno++)
00120     gndNodeCnts[predno] = 0;
00121              
00122   for (int lno = 0; lno < links_->size(); lno++)
00123   {
00124     int predIndex = (*links_)[lno]->getPredIndex();
00125     node = (*links_)[lno]->getNode();
00126     gndNodeCnts[predIndex] += node->getGroundNodeCount();
00127   }
00128 
00129     //initialize the product
00130   for (int state = 0; state < stateCnt; state++)
00131     prodMsgs[state] = factorMsgs_[state];
00132 
00133   for (int lno = 0; lno < links_->size(); lno++)
00134   {
00135     int predIndex = (*links_)[lno]->getPredIndex();
00136     if (predIndex == inpPredIndex)
00137       continue;
00138 
00139     node = (*links_)[lno]->getNode();
00140       //wt must be equal to 1 in equilibrium. This step 
00141       //averages (weighted) out the messages from various 
00142       //supernodes (at this predIndex position).
00143     assert(gndNodeCnts[predIndex] != 0);
00144     double wt = node->getGroundNodeCount()/gndNodeCnts[predIndex];
00145 
00146     for (int state = 0; state < stateCnt; state++)
00147     {
00148       bool predBit = state & (1<<predIndex);
00149       if (predBit)
00150         prodMsgs[state] += (*msgsArr_)[lno][1]*wt;
00151       else
00152         prodMsgs[state] += (*msgsArr_)[lno][0]*wt;
00153     }
00154   }
00155 
00156     //caller is responsible for deleting it
00157   double *outMsgs = new double[2];    
00158   double maxMsgs[2];
00159   maxMsgs[0] = maxMsgs[1] = 0;
00160   bool firstTime[2];
00161   firstTime[0] = firstTime[1] = true;
00162              
00163     //now find the max messages 
00164   for (int state = 0; state < stateCnt; state++)
00165   {
00166     bool predBit = state & (1<<inpPredIndex);
00167     if (predBit && (maxMsgs[1] < prodMsgs[state] || firstTime[1]))
00168     {
00169       firstTime[1] = false;
00170       maxMsgs[1] = prodMsgs[state];
00171     }
00172 
00173     if (!predBit && (maxMsgs[0] < prodMsgs[state] || firstTime[0]))
00174     {
00175       firstTime[0] = false;
00176       maxMsgs[0] = prodMsgs[state];
00177     }
00178   }
00179 
00180   outMsgs[0] = outMsgs[1] = 0;
00181   for (int state = 0; state < stateCnt; state++)
00182   {
00183     bool predBit = state & (1<<inpPredIndex);
00184     if (predBit) 
00185       outMsgs[1] += expl(prodMsgs[state] - maxMsgs[1]);
00186     else
00187       outMsgs[0] += expl(prodMsgs[state] - maxMsgs[0]);
00188   }
00189   outMsgs[1] = maxMsgs[1] + logl(outMsgs[1]);
00190   outMsgs[0] = maxMsgs[0] + logl(outMsgs[0]);
00191   outMsgs[1] = outMsgs[1] - outMsgs[0];
00192   outMsgs[0] = 0;
00193   delete [] prodMsgs;
00194   delete [] gndNodeCnts;
00195   return outMsgs;
00196 }
00197 
00198   //send Message on all the links
00199 void BPFactor::sendMessage()
00200 {
00201   double *outMsgs = NULL;
00202   BPLink *link;
00203   BPNode *node;
00204   int predIndex;
00205   for (int lindex = 0; lindex < links_->size(); lindex++)
00206   {
00207     link = (*links_)[lindex];
00208     node = link->getNode();
00209     predIndex = link->getPredIndex();
00210 
00211     outMsgs = multiplyMessagesAndSumOut(predIndex);
00212       //Assumes pass by value copy of the messages
00213     node->receiveMessage(outMsgs, link);
00214     delete [] outMsgs;
00215   }
00216 }
00217 
00218   //update the stored msgs and update the msgProduct
00219 void BPFactor::moveToNextStep()
00220 {
00221   double * msgs;
00222   double * nextMsgs;
00223 
00224     //store the current messages in prevMessages array and
00225     //reinitialize the current message array
00226   for (int lindex = 0;lindex < links_->size(); lindex++)
00227   {
00228     msgs = (*msgsArr_)[lindex];
00229     nextMsgs = (*nextMsgsArr_)[lindex];
00230     for (int i = 0; i < 2; i++)
00231     {
00232       msgs[i] = nextMsgs[i];
00233       nextMsgs[i] = 0;
00234     }
00235   }
00236 }
00237 
00241 ostream& BPFactor::print(ostream& out)
00242 {
00243   /*
00244   if (superClause_ != NULL)
00245   {
00246     printArray(*(superClause_->getConstantTuple(0)), 1, out);
00247   }
00248   else
00249   {
00250     printArray(*constants_,1,out);
00251   }
00252   out << "// ";
00253   */
00254     
00255   for (int i = 0; i < links_->size(); i++)
00256   {
00257     (*links_)[i]->getNode()->print(out);
00258     if (i != links_->size() - 1) out << "/ ";
00259   }
00260   
00261   return out;
00262 }
00263 
00267 ostream& BPFactor::printWts(ostream& out)
00268 {
00269   for (int i = 0; i < (int)pow(2.0, clause_->getNumPredicates()); i++)
00270   {
00271     out << factorMsgs_[i] * outputWt_ << " ";
00272   }
00273   
00274   return out;
00275 }
00276 

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by  doxygen 1.5.1