factor.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
#include "factor.h"

#include <vector>
00068  
00069 /*****************************************************************************/
00070 // Functions for class Factor
00071 /*****************************************************************************/
00072 
00073 //the contribution of the factor itself
00074 void Factor::initFactorMesssages()
00075 {
00076   Predicate *pred;
00077   int numPreds = clause_->getNumPredicates();
00078   int stateCnt = (int)pow(2.0, numPreds);
00079   factorMsgs_ = new double[stateCnt];
00080   bool isSatisfied;
00081   for (int state = 0; state < stateCnt; state++)
00082   {
00083     isSatisfied = false;
00084     for (int predno = 0; predno < numPreds; predno++)
00085     {
00086       pred = clause_->getPredicate(predno);
00087       bool predBit = state & (1<<predno);
00088       if (pred->getSense() == predBit)
00089       {
00090         isSatisfied = true;
00091         break;
00092       }
00093     }
00094 
00095     if (isSatisfied)
00096     {
00097         // Always 1
00098         // MS: should be weight of clause in ground network (super or ground)
00099       factorMsgs_[state] = clause_->getWt();
00100     }
00101     else
00102     {
00103       factorMsgs_[state] = 0;
00104     }
00105   }
00106 }
00107 
00108   //find the outgoing message for the given inpPredIndex
00109 double* Factor::multiplyMessagesAndSumOut(int inpPredIndex)
00110 {
00111   int numPreds = clause_->getNumPredicates();
00112   int stateCnt = (int)pow(2.0, numPreds);
00113   double * prodMsgs = new double[stateCnt];
00114              
00115   Node *node;
00116   double *gndNodeCnts = new double[numPreds];
00117 
00118   // This is done to handle the case when BP is run
00119   // on supernodes/superfeatures which have not yet
00120   // reached an equilibrium state
00121   for (int predno = 0; predno < numPreds; predno++)
00122     gndNodeCnts[predno] = 0;
00123              
00124   for (int lno = 0; lno < links_->size(); lno++)
00125   {
00126     int predIndex = (*links_)[lno]->getPredIndex();
00127     node = (*links_)[lno]->getNode();
00128     gndNodeCnts[predIndex] += node->getGroundNodeCount();
00129   }
00130 
00131     //initialize the product
00132   for (int state = 0; state < stateCnt; state++)
00133     prodMsgs[state] = factorMsgs_[state];
00134 
00135   for (int lno = 0; lno < links_->size(); lno++)
00136   {
00137     int predIndex = (*links_)[lno]->getPredIndex();
00138     if (predIndex == inpPredIndex)
00139       continue;
00140 
00141     node = (*links_)[lno]->getNode();
00142       //wt must be equal to 1 in equilibrium. This step 
00143       //averages (weighted) out the messages from various 
00144       //supernodes (at this predIndex position).
00145     assert(gndNodeCnts[predIndex] != 0);
00146       //MS: Ground: should be 1, super: should be # of ground clauses in super clause
00147     double cnt = node->getGroundNodeCount()/gndNodeCnts[predIndex];
00148     //double cnt = 1;
00149     //if (superClause_)
00150     //  cnt = superClause_->getNumTuplesIncludingImplicit();
00151 
00152     for (int state = 0; state < stateCnt; state++)
00153     {
00154       bool predBit = state & (1<<predIndex);
00155       if (predBit)
00156         prodMsgs[state] += (*msgsArr_)[lno][1]*cnt;
00157       else
00158         prodMsgs[state] += (*msgsArr_)[lno][0]*cnt;
00159 //cout << state << " state " << predBit << " " << prodMsgs[state] << endl;
00160     }
00161   }
00162 
00163     //caller is responsible for deleting it
00164   double *outMsgs = new double[2];    
00165   double maxMsgs[2];
00166   maxMsgs[0] = maxMsgs[1] = 0;
00167   bool firstTime[2];
00168   firstTime[0] = firstTime[1] = true;
00169              
00170     //now find the max messages 
00171   for (int state = 0; state < stateCnt; state++)
00172   {
00173     bool predBit = state & (1<<inpPredIndex);
00174     if (predBit && (maxMsgs[1] < prodMsgs[state] || firstTime[1]))
00175     {
00176       firstTime[1] = false;
00177       maxMsgs[1] = prodMsgs[state];
00178     }
00179 
00180     if (!predBit && (maxMsgs[0] < prodMsgs[state] || firstTime[0]))
00181     {
00182       firstTime[0] = false;
00183       maxMsgs[0] = prodMsgs[state];
00184     }
00185   }
00186 
00187   outMsgs[0] = outMsgs[1] = 0;
00188   for (int state = 0; state < stateCnt; state++)
00189   {
00190     bool predBit = state & (1<<inpPredIndex);
00191     if (predBit) 
00192       outMsgs[1] += expl(prodMsgs[state] - maxMsgs[1]);
00193     else
00194       outMsgs[0] += expl(prodMsgs[state] - maxMsgs[0]);
00195   }
00196   outMsgs[1] = maxMsgs[1] + logl(outMsgs[1]);
00197   outMsgs[0] = maxMsgs[0] + logl(outMsgs[0]);
00198   outMsgs[1] = outMsgs[1] - outMsgs[0];
00199   outMsgs[0] = 0;
00200   delete [] prodMsgs;
00201   delete [] gndNodeCnts;
00202   return outMsgs;
00203 }
00204 
00205   //send Message on all the links
00206 void Factor::sendMessage()
00207 {
00208   double *outMsgs = NULL;
00209   Link *link;
00210   Node *node;
00211   int predIndex;
00212   for (int lindex = 0; lindex < links_->size(); lindex++)
00213   {
00214     link = (*links_)[lindex];
00215     node = link->getNode();
00216     predIndex = link->getPredIndex();
00217 
00218     outMsgs = multiplyMessagesAndSumOut(predIndex);
00219       //Assumes pass by value copy of the messages
00220     node->receiveMessage(outMsgs, link);
00221     delete [] outMsgs;
00222   }
00223 }
00224 
00225   //update the stored msgs and update the msgProduct
00226 void Factor::moveToNextStep()
00227 {
00228   double * msgs;
00229   double * nextMsgs;
00230 
00231     //store the current messages in prevMessages array and
00232     //reinitialize the current message array
00233   for (int lindex = 0;lindex < links_->size(); lindex++)
00234   {
00235     msgs = (*msgsArr_)[lindex];
00236     nextMsgs = (*nextMsgsArr_)[lindex];
00237     for (int i = 0; i < 2; i++)
00238     {
00239       msgs[i] = nextMsgs[i];
00240       nextMsgs[i] = 0;
00241     }
00242   }
00243 }
00244 
00248 ostream& Factor::print(ostream& out)
00249 {
00250   /*
00251   if (superClause_ != NULL)
00252   {
00253     printArray(*(superClause_->getConstantTuple(0)), 1, out);
00254   }
00255   else
00256   {
00257     printArray(*constants_,1,out);
00258   }
00259   out << "// ";
00260   */
00261     
00262   for (int i = 0; i < links_->size(); i++)
00263   {
00264     (*links_)[i]->getNode()->print(out);
00265     if (i != links_->size() - 1) out << "/ ";
00266   }
00267   
00268   return out;
00269 }
00270 
00274 ostream& Factor::printWts(ostream& out)
00275 {
00276   for (int i = 0; i < (int)pow(2.0, clause_->getNumPredicates()); i++)
00277   {
00278     out << factorMsgs_[i] * outputWt_ << " ";
00279   }
00280   
00281   return out;
00282 }
00283 

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by doxygen 1.5.1