00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef _BPNODE_H_Jan_2008
00067 #define _BPNODE_H_Jan_2008
00068
00069 #include <math.h>
00070 #include "util.h"
00071 #include "mrf.h"
00072 #include "array.h"
00073 #include "bplink.h"
00074 #include "superpred.h"
00075 #include "twowaymessage.h"
00076 #include "bpfactor.h"
00077
// MIN is a small positive floor; its natural log, MINEXP (about -13.82),
// is the bound BPNode::getProbs() clamps exponents to: [MINEXP, -MINEXP].
const double MIN(1e-6);
const double MINEXP(logl(MIN));

// Fraction of the total (unnormalized) mass that BPNode::getProbs() spreads
// evenly over the two probabilities before normalizing them.
const double SMOOTHTERM(1e-6);
00082
00083 using namespace std;
00084 using namespace __gnu_cxx;
00085
00086 class BPNode
00087 {
00088 public:
00089 BPNode(int predId, SuperPred * const & superPred,
00090 Array<int> * const & constants, Domain * const & domain)
00091 {
00092 predId_ = predId;
00093 superPred_ = superPred;
00094 constants_ = constants;
00095 domain_ = domain;
00096 links_ = new Array<BPLink *>();
00097 msgsArr_ = new Array<double *>();
00098 nextMsgsArr_ = new Array<double *>();
00099 msgProds_[0] = msgProds_[1] = 0;
00100 }
00101
00102 ~BPNode()
00103 {
00104 for (int i = 0; i < links_->size(); i++)
00105 delete (*links_)[i];
00106 for (int i = 0; i < msgsArr_->size(); i++)
00107 {
00108 delete (*msgsArr_)[i];
00109 delete (*nextMsgsArr_)[i];
00110 }
00111 delete links_;
00112 delete msgsArr_;
00113 delete nextMsgsArr_;
00114 }
00115
00116 int getPredId() { return predId_;}
00117
00118 int getSuperPredId()
00119 {
00120 if (superPred_)
00121 return superPred_->getSuperPredId();
00122 else
00123 return -1;
00124 }
00125
00126 int getParentSuperPredId()
00127 {
00128 if (superPred_)
00129 return superPred_->getParentSuperPredId();
00130 else
00131 return -1;
00132 }
00133
00134 SuperPred* getSuperPred() {return superPred_;}
00135 Array<int> * getConstants() {return constants_;}
00136
00137 int getGroundNodeCount()
00138 {
00139 if (superPred_)
00140 {
00141 return superPred_->getNumTuples();
00142 }
00143 else
00144 return 1;
00145 }
00146
00147 int getNumLinks() { return links_->size();}
00148
00149 BPLink * getLink(int index) {return (*links_)[index];}
00150
00151 void getMessage(int index, double msgs[])
00152 {
00153 msgs[0] = (*msgsArr_)[index][0];
00154 msgs[1] = (*msgsArr_)[index][1];
00155 }
00156
00161 void addLink(BPLink *link, double inpMsgs[2])
00162 {
00163 links_->append(link);
00164 double *msgs;
00165 msgs = new double[2];
00166
00167 double cnt = link->getCount();
00168 for (int i = 0; i < 2; i++)
00169 {
00170 if (inpMsgs)
00171 {
00172 msgs[i] = inpMsgs[i];
00173 }
00174 else
00175 {
00176 msgs[i] = 0;
00177 }
00178 msgProds_[i] = msgProds_[i] + cnt*msgs[i];
00179 }
00180
00181 msgsArr_->append(msgs);
00182 msgs = new double[2];
00183 nextMsgsArr_->append(msgs);
00184 }
00185
00186
00187
00188 void addFactors(Array<BPFactor *> * const & allFactors,
00189 LinkIdToTwoWayMessageMap* const & lidToTWMsg);
00190
00191
00192 void receiveMessage(double* inpMsgs, BPLink * const & link)
00193 {
00194 double *nextMsgs;
00195 int reverseFactorIndex = link->getReverseFactorIndex();
00196 nextMsgs = (*nextMsgsArr_)[reverseFactorIndex];
00197 nextMsgs[0] = inpMsgs[0];
00198 nextMsgs[1] = inpMsgs[1];
00199 }
00200
00201 double getExp()
00202 {
00203 double exp = msgProds_[1] - msgProds_[0];
00204 return exp;
00205 }
00206
00211 double * getProbs(double * const & probs)
00212 {
00213 double exps[2];
00214 double norm, smooth;
00215 exps[0] = 0;
00216 exps[1] = msgProds_[1] - msgProds_[0];
00217
00218 for (int i = 0; i < 2; i++)
00219 {
00220 if(exps[i] < MINEXP)
00221 exps[i] = MINEXP;
00222 if(exps[i] > -MINEXP)
00223 exps[i] = -MINEXP;
00224 probs[i] = expl(exps[i]);
00225 }
00226
00227 norm = probs[0] + probs[1];
00228 smooth = norm*SMOOTHTERM;
00229 norm += smooth;
00230
00231 for (int i = 0; i < 2 ; i++)
00232 {
00233 probs[i] += smooth/2;
00234 probs[i] = probs[i]/norm;
00235 }
00236 return probs;
00237 }
00238
00239
00240 void sendMessage();
00241
00242
00243 void moveToNextStep();
00244
00245 ostream& print(ostream& out)
00246 {
00247 out << predId_ << ": ";
00248 if (superPred_ != NULL)
00249 printArray(*(superPred_->getConstantTuple(0)),out);
00250 else
00251 printArray(*constants_,out);
00252 return out;
00253 }
00254
00255 private:
00256
00257 int predId_;
00258
00259
00260
00261 SuperPred * superPred_;
00262 Array<int> *constants_;
00263
00264 Domain * domain_;
00265
00266 Array<BPLink *> * links_;
00267
00268 Array<double *> * msgsArr_;
00269 Array<double *> * nextMsgsArr_;
00270
00271
00272
00273
00274
00275 double msgProds_[2];
00276 };
00277
00278 #endif
00279