00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef _NODE_H_Jan_2008
00068 #define _NODE_H_Jan_2008
00069
00070 #include <math.h>
00071 #include "util.h"
00072 #include "mrf.h"
00073 #include "array.h"
00074 #include "link.h"
00075 #include "superpred.h"
00076 #include "twowaymessage.h"
00077 #include "factor.h"
00078
/**
 * Lower clamp used by Node::getProbs(): exponents are restricted to the range
 * [MINEXP, -MINEXP], i.e. the exponentiated values stay within [MIN, 1/MIN].
 */
const double MIN = 1e-6;

/** Natural logarithm of MIN (a negative number). */
const double MINEXP = logl(MIN);

/**
 * Fraction of the normalisation constant that Node::getProbs() adds as
 * smoothing mass, split evenly between the two outcomes.
 */
const double SMOOTHTERM = 1e-6;
00083
00084 using namespace std;
00085 using namespace __gnu_cxx;
00086
00090 class Node
00091 {
00092 public:
00093
00097 Node(int predId, SuperPred * const & superPred,
00098 Array<int> * const & constants, Domain * const & domain)
00099 {
00100 predId_ = predId;
00101 superPred_ = superPred;
00102 constants_ = constants;
00103 domain_ = domain;
00104 links_ = new Array<Link *>();
00105 auxLinks_ = new Array<Link *>();
00106 msgsArr_ = new Array<double *>();
00107 nextMsgsArr_ = new Array<double *>();
00108 msgProds_[0] = msgProds_[1] = 0;
00109 }
00110
00114 ~Node()
00115 {
00116 for (int i = 0; i < links_->size(); i++)
00117 delete (*links_)[i];
00118 for (int i = 0; i < auxLinks_->size(); i++)
00119 delete (*auxLinks_)[i];
00120 for (int i = 0; i < msgsArr_->size(); i++)
00121 {
00122 delete (*msgsArr_)[i];
00123 delete (*nextMsgsArr_)[i];
00124 }
00125 delete links_;
00126 delete auxLinks_;
00127 delete msgsArr_;
00128 delete nextMsgsArr_;
00129 }
00130
00131 int getPredId() { return predId_;}
00132
00133 int getSuperPredId()
00134 {
00135 if (superPred_)
00136 return superPred_->getSuperPredId();
00137 else
00138 return -1;
00139 }
00140
00141 int getParentSuperPredId()
00142 {
00143 if (superPred_)
00144 return superPred_->getParentSuperPredId();
00145 else
00146 return -1;
00147 }
00148
00149 SuperPred* getSuperPred() {return superPred_;}
00150 Array<int> * getConstants() {return constants_;}
00151
00152 int getGroundNodeCount()
00153 {
00154 if (superPred_)
00155 {
00156 return superPred_->getNumTuples();
00157 }
00158 else
00159 return 1;
00160 }
00161
00166 int getNumLinks() { return links_->size();}
00167
00168 Link * getLink(int index) {return (*links_)[index];}
00169
00177 void getMessage(int index, double msgs[])
00178 {
00179 msgs[0] = (*msgsArr_)[index][0];
00180 msgs[1] = (*msgsArr_)[index][1];
00181 }
00182
00187 void addAuxLink(Link *link)
00188 {
00189 auxLinks_->append(link);
00190 }
00191
00196 void addLink(Link *link, double inpMsgs[2])
00197 {
00198 links_->append(link);
00199 double *msgs;
00200 msgs = new double[2];
00201
00202 double cnt = link->getCount();
00203 for (int i = 0; i < 2; i++)
00204 {
00205 if (inpMsgs)
00206 {
00207 msgs[i] = inpMsgs[i];
00208 }
00209 else
00210 {
00211 msgs[i] = 0;
00212 }
00213 msgProds_[i] = msgProds_[i] + cnt*msgs[i];
00214 }
00215
00216 msgsArr_->append(msgs);
00217 msgs = new double[2];
00218 nextMsgsArr_->append(msgs);
00219 }
00220
00225 void addFactors(Array<Factor *> * const & allFactors,
00226 LinkIdToTwoWayMessageMap* const & lidToTWMsg);
00227
00231 void receiveMessage(double* inpMsgs, Link * const & link)
00232 {
00233 double *nextMsgs;
00234 int reverseFactorIndex = link->getReverseFactorIndex();
00235 nextMsgs = (*nextMsgsArr_)[reverseFactorIndex];
00236 nextMsgs[0] = inpMsgs[0];
00237 nextMsgs[1] = inpMsgs[1];
00238 }
00239
00240 double getExp()
00241 {
00242 double exp = msgProds_[1] - msgProds_[0];
00243 return exp;
00244 }
00245
00250 double * getProbs(double * const & probs)
00251 {
00252 double exps[2];
00253 double norm, smooth;
00254 exps[0] = 0;
00255 exps[1] = msgProds_[1] - msgProds_[0];
00256
00257 for (int i = 0; i < 2; i++)
00258 {
00259 if(exps[i] < MINEXP)
00260 exps[i] = MINEXP;
00261 if(exps[i] > -MINEXP)
00262 exps[i] = -MINEXP;
00263 probs[i] = expl(exps[i]);
00264 }
00265
00266 norm = probs[0] + probs[1];
00267 smooth = norm*SMOOTHTERM;
00268 norm += smooth;
00269
00270 for (int i = 0; i < 2 ; i++)
00271 {
00272 probs[i] += smooth/2;
00273 probs[i] = probs[i]/norm;
00274 }
00275 return probs;
00276 }
00277
00281 void sendMessage();
00282
00286 void sendAuxMessage();
00287
00291 void moveToNextStep();
00292
00293 ostream& print(ostream& out)
00294 {
00295 out << predId_ << ": ";
00296 if (superPred_ != NULL)
00297 printArray(*(superPred_->getConstantTuple(0)),out);
00298 else
00299 printArray(*constants_,out);
00300 return out;
00301 }
00302
00303 private:
00304
00305 int predId_;
00306
00307
00308
00309 SuperPred * superPred_;
00310 Array<int> *constants_;
00311
00312 Domain * domain_;
00313
00314 Array<Link *> * links_;
00315
00316 Array<Link *> * auxLinks_;
00317
00318 Array<double *> * msgsArr_;
00319 Array<double *> * nextMsgsArr_;
00320
00321
00322
00323
00324
00325 double msgProds_[2];
00326 };
00327
00328 #endif
00329