bpnode.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-08] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef _BPNODE_H_Jan_2008
00067 #define _BPNODE_H_Jan_2008
00068 
00069 #include <math.h>
00070 #include "util.h"
00071 #include "mrf.h"
00072 #include "array.h"
00073 #include "bplink.h"
00074 #include "superpred.h"
00075 #include "twowaymessage.h"
00076 #include "bpfactor.h"
00077 
// Smallest value (in probability space) that getProbs() lets a clamped
// exponent represent.
const double MIN = 0.000001;

// Fraction of the normalizer that getProbs() spreads evenly over both states.
const double SMOOTHTERM = 0.000001;

// Log-domain counterpart of MIN; a negative number, so [MINEXP, -MINEXP] is
// the symmetric range exponents are clamped to before exponentiation.
const double MINEXP = logl(MIN);
00082 
00083 using namespace std;
00084 using namespace __gnu_cxx;
00085 
00086 class BPNode
00087 {
00088  public:
00089   BPNode(int predId, SuperPred * const & superPred, 
00090          Array<int> * const & constants, Domain * const & domain)
00091   {
00092     predId_ = predId;
00093     superPred_ = superPred;
00094     constants_ = constants;
00095     domain_ = domain;
00096     links_ = new Array<BPLink *>();
00097     msgsArr_ = new Array<double *>();
00098     nextMsgsArr_ = new Array<double *>();
00099     msgProds_[0] = msgProds_[1] = 0;
00100   }
00101 
00102   ~BPNode()
00103   {
00104     for (int i = 0; i < links_->size(); i++)
00105       delete (*links_)[i];
00106     for (int i = 0; i < msgsArr_->size(); i++)
00107     {
00108       delete (*msgsArr_)[i];
00109       delete (*nextMsgsArr_)[i];
00110     }
00111     delete links_;
00112     delete msgsArr_;
00113     delete nextMsgsArr_;
00114   }
00115 
00116   int getPredId() { return predId_;}
00117 
00118   int getSuperPredId()
00119   {
00120     if (superPred_) 
00121       return superPred_->getSuperPredId();
00122     else
00123       return -1;
00124   }
00125 
00126   int getParentSuperPredId()
00127   {
00128     if (superPred_) 
00129       return superPred_->getParentSuperPredId();
00130     else
00131       return -1;
00132   }
00133 
00134   SuperPred* getSuperPred() {return superPred_;}
00135   Array<int> * getConstants() {return constants_;}
00136 
00137   int getGroundNodeCount()
00138   {
00139     if (superPred_)
00140     {
00141       return superPred_->getNumTuples();
00142     }
00143     else
00144       return 1;
00145   }
00146 
00147   int getNumLinks() { return links_->size();}
00148   
00149   BPLink * getLink(int index) {return (*links_)[index];}
00150           
00151   void getMessage(int index, double msgs[])
00152   {
00153     msgs[0] = (*msgsArr_)[index][0];
00154     msgs[1] = (*msgsArr_)[index][1];
00155   }
00156 
00161   void addLink(BPLink *link, double inpMsgs[2])
00162   {
00163     links_->append(link);
00164     double *msgs;
00165     msgs = new double[2];
00166 
00167     double cnt = link->getCount(); 
00168     for (int i = 0; i < 2; i++)
00169     {
00170       if (inpMsgs)
00171       {
00172         msgs[i] = inpMsgs[i];
00173       }
00174       else
00175       {
00176         msgs[i] = 0;
00177       }
00178       msgProds_[i] = msgProds_[i] + cnt*msgs[i];
00179     }
00180 
00181     msgsArr_->append(msgs);
00182     msgs = new double[2];
00183     nextMsgsArr_->append(msgs);
00184   }
00185 
00186     //add the factors with appropriate counts, also add the node to the
00187     //corresponding factor
00188   void addFactors(Array<BPFactor *> * const & allFactors,
00189                   LinkIdToTwoWayMessageMap* const & lidToTWMsg);
00190 
00191     //receive the message send over a link
00192   void receiveMessage(double* inpMsgs, BPLink * const & link)
00193   {
00194     double *nextMsgs;
00195     int reverseFactorIndex = link->getReverseFactorIndex();
00196     nextMsgs = (*nextMsgsArr_)[reverseFactorIndex];
00197     nextMsgs[0] = inpMsgs[0];
00198     nextMsgs[1] = inpMsgs[1];
00199   }
00200          
00201   double getExp()
00202   {
00203     double exp = msgProds_[1] - msgProds_[0];
00204     return exp;
00205   }
00206 
00211   double * getProbs(double * const & probs)
00212   {
00213     double exps[2];
00214     double norm, smooth;
00215     exps[0] = 0; //msgProds_[0];
00216     exps[1] = msgProds_[1] - msgProds_[0];
00217 
00218     for (int i = 0; i < 2; i++)
00219     {
00220       if(exps[i] < MINEXP)
00221         exps[i] = MINEXP;
00222       if(exps[i] > -MINEXP)
00223         exps[i] = -MINEXP;
00224       probs[i] = expl(exps[i]);
00225     }
00226 
00227     norm = probs[0] + probs[1];
00228     smooth = norm*SMOOTHTERM;
00229     norm += smooth;
00230 
00231     for (int i = 0; i < 2 ; i++)
00232     {
00233       probs[i] += smooth/2;
00234       probs[i] = probs[i]/norm;
00235     }
00236     return probs;       
00237   }
00238 
00239     //send the messages to all the factor nodes connected to this node
00240   void sendMessage();
00241 
00242     //update the stored msgs and update the msgProduct
00243   void moveToNextStep();
00244 
00245   ostream& print(ostream& out)
00246   {
00247     out << predId_ << ": ";
00248     if (superPred_ != NULL)
00249       printArray(*(superPred_->getConstantTuple(0)),out);
00250     else
00251       printArray(*constants_,out);
00252     return out;
00253   }
00254 
00255  private:
00256 
00257   int predId_;
00258 
00259     //only one of the two below is used - superPred in case of lifted
00260     //inference and constants in case of ground inference
00261   SuperPred * superPred_;
00262   Array<int> *constants_;
00263 
00264   Domain * domain_;
00265 
00266   Array<BPLink *> * links_;
00267     //log of the actual message is stored
00268   Array<double *> * msgsArr_; 
00269   Array<double *> * nextMsgsArr_;
00270 
00271     //we store these so that we do not have to n^2 squared
00272     //computation while calculating the outgoing messages - we can 
00273     //just divide (subtract in the log domain) by the message from the 
00274     //factor node to which it is being sent
00275   double msgProds_[2];
00276 };
00277 
00278 #endif
00279 

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by  doxygen 1.5.1