00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef _BPFACTOR_H_Jan_2008
00067 #define _BPFACTOR_H_Jan_2008
00068
00069 #include <math.h>
00070 #include "util.h"
00071 #include "mrf.h"
00072 #include "array.h"
00073 #include "bplink.h"
00074 #include "bpnode.h"
00075 #include "superclause.h"
00076
00077 using namespace std;
00078 using namespace __gnu_cxx;
00079
00083 class BPFactor
00084 {
00085 public:
00086 BPFactor(Clause * const & clause, SuperClause * const & superClause,
00087 Array<int> * const & constants, Domain * const & domain,
00088 double outputWt)
00089 {
00090 clause_ = clause;
00091 superClause_ = superClause;
00092 constants_ = constants;
00093 domain_ = domain;
00094
00095 msgsArr_ = new Array<double *>;
00096 nextMsgsArr_ = new Array<double *>;
00097 links_ = new Array<BPLink *>();
00098 outputWt_ = outputWt;
00099 initFactorMesssages();
00100 }
00101
00102 ~BPFactor()
00103 {
00104 for (int i = 0; i < msgsArr_->size(); i++)
00105 {
00106 delete (*msgsArr_)[i];
00107 delete (*nextMsgsArr_)[i];
00108 }
00109 delete links_;
00110 delete msgsArr_;
00111 delete nextMsgsArr_;
00112 }
00113
00114 void initFactorMesssages();
00115
00116 int getSuperClauseId()
00117 {
00118 if (superClause_)
00119 return superClause_->getSuperClauseId();
00120 else
00121 return -1;
00122 }
00123
00124 int getParentSuperClauseId()
00125 {
00126 if( superClause_)
00127 return superClause_->getParentSuperClauseId();
00128 else
00129 return -1;
00130 }
00131
00132 SuperClause* getSuperClause() {return superClause_;}
00133 Clause * getClause() {return clause_;}
00134 Domain *getDomain() {return domain_;}
00135 Array<int> * getConstants() {return constants_;}
00136 int getNumLinks() {return links_->size();}
00137
00138 void getMessage(int index, double msgs[])
00139 {
00140 msgs[0] = (*msgsArr_)[index][0];
00141 msgs[1] = (*msgsArr_)[index][1];
00142 }
00143
00144 void addLink(BPLink *link, double inpMsgs[2])
00145 {
00146 links_->append(link);
00147 double *msgs;
00148 msgs = new double[2];
00149
00150 if (inpMsgs)
00151 {
00152 msgs[0] = inpMsgs[0];
00153 msgs[1] = inpMsgs[1];
00154 }
00155 else
00156 {
00157 msgs[0] = msgs[1] = 0;
00158 }
00159 msgsArr_->append(msgs);
00160 msgs = new double[2];
00161 nextMsgsArr_->append(msgs);
00162 }
00163
00167 void receiveMessage(double* inpMsgs, BPLink *link)
00168 {
00169 double *nextMsgs;
00170 int reverseNodeIndex = link->getReverseNodeIndex();
00171 nextMsgs = (*nextMsgsArr_)[reverseNodeIndex];
00172 nextMsgs[0] = inpMsgs[0];
00173 nextMsgs[1] = inpMsgs[1];
00174 }
00175
00176
00177 double* multiplyMessagesAndSumOut(int predIndex);
00178
00179
00180 void sendMessage();
00181
00182
00183 void moveToNextStep();
00184
00185 ostream& print(ostream& out);
00186
00187 ostream& printWts(ostream& out);
00188
00189 private:
00190 Clause * clause_;
00191
00192
00193
00194 SuperClause * superClause_;
00195 Array<int> * constants_;
00196 Domain * domain_;
00197 Array<BPLink *> *links_;
00198 Array<double *> *msgsArr_;
00199 Array<double *> *nextMsgsArr_;
00200 double *factorMsgs_;
00201 double outputWt_;
00202 };
00203
00204 #endif
00205