00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
#include "bpfactor.h"

#include <vector>
00067
00068
00069
00070
00071
00072
00073 void BPFactor::initFactorMesssages()
00074 {
00075 Predicate *pred;
00076 int numPreds = clause_->getNumPredicates();
00077 int stateCnt = (int)pow(2.0, numPreds);
00078 factorMsgs_ = new double[stateCnt];
00079 bool isSatisfied;
00080 for (int state = 0; state < stateCnt; state++)
00081 {
00082 isSatisfied = false;
00083 for (int predno = 0; predno < numPreds; predno++)
00084 {
00085 pred = clause_->getPredicate(predno);
00086 bool predBit = state & (1<<predno);
00087 if (pred->getSense() == predBit)
00088 {
00089 isSatisfied = true;
00090 break;
00091 }
00092 }
00093
00094 if (isSatisfied)
00095 {
00096
00097 factorMsgs_[state] = clause_->getWt();
00098 }
00099 else
00100 {
00101 factorMsgs_[state] = 0;
00102 }
00103 }
00104 }
00105
00106
00107 double* BPFactor::multiplyMessagesAndSumOut(int inpPredIndex)
00108 {
00109 int numPreds = clause_->getNumPredicates();
00110 int stateCnt = (int)pow(2.0, numPreds);
00111 double * prodMsgs = new double[stateCnt];
00112
00113 BPNode *node;
00114 double *gndNodeCnts = new double[numPreds];
00115
00116
00117
00118
00119 for (int predno = 0; predno < numPreds; predno++)
00120 gndNodeCnts[predno] = 0;
00121
00122 for (int lno = 0; lno < links_->size(); lno++)
00123 {
00124 int predIndex = (*links_)[lno]->getPredIndex();
00125 node = (*links_)[lno]->getNode();
00126 gndNodeCnts[predIndex] += node->getGroundNodeCount();
00127 }
00128
00129
00130 for (int state = 0; state < stateCnt; state++)
00131 prodMsgs[state] = factorMsgs_[state];
00132
00133 for (int lno = 0; lno < links_->size(); lno++)
00134 {
00135 int predIndex = (*links_)[lno]->getPredIndex();
00136 if (predIndex == inpPredIndex)
00137 continue;
00138
00139 node = (*links_)[lno]->getNode();
00140
00141
00142
00143 assert(gndNodeCnts[predIndex] != 0);
00144 double wt = node->getGroundNodeCount()/gndNodeCnts[predIndex];
00145
00146 for (int state = 0; state < stateCnt; state++)
00147 {
00148 bool predBit = state & (1<<predIndex);
00149 if (predBit)
00150 prodMsgs[state] += (*msgsArr_)[lno][1]*wt;
00151 else
00152 prodMsgs[state] += (*msgsArr_)[lno][0]*wt;
00153 }
00154 }
00155
00156
00157 double *outMsgs = new double[2];
00158 double maxMsgs[2];
00159 maxMsgs[0] = maxMsgs[1] = 0;
00160 bool firstTime[2];
00161 firstTime[0] = firstTime[1] = true;
00162
00163
00164 for (int state = 0; state < stateCnt; state++)
00165 {
00166 bool predBit = state & (1<<inpPredIndex);
00167 if (predBit && (maxMsgs[1] < prodMsgs[state] || firstTime[1]))
00168 {
00169 firstTime[1] = false;
00170 maxMsgs[1] = prodMsgs[state];
00171 }
00172
00173 if (!predBit && (maxMsgs[0] < prodMsgs[state] || firstTime[0]))
00174 {
00175 firstTime[0] = false;
00176 maxMsgs[0] = prodMsgs[state];
00177 }
00178 }
00179
00180 outMsgs[0] = outMsgs[1] = 0;
00181 for (int state = 0; state < stateCnt; state++)
00182 {
00183 bool predBit = state & (1<<inpPredIndex);
00184 if (predBit)
00185 outMsgs[1] += expl(prodMsgs[state] - maxMsgs[1]);
00186 else
00187 outMsgs[0] += expl(prodMsgs[state] - maxMsgs[0]);
00188 }
00189 outMsgs[1] = maxMsgs[1] + logl(outMsgs[1]);
00190 outMsgs[0] = maxMsgs[0] + logl(outMsgs[0]);
00191 outMsgs[1] = outMsgs[1] - outMsgs[0];
00192 outMsgs[0] = 0;
00193 delete [] prodMsgs;
00194 delete [] gndNodeCnts;
00195 return outMsgs;
00196 }
00197
00198
00199 void BPFactor::sendMessage()
00200 {
00201 double *outMsgs = NULL;
00202 BPLink *link;
00203 BPNode *node;
00204 int predIndex;
00205 for (int lindex = 0; lindex < links_->size(); lindex++)
00206 {
00207 link = (*links_)[lindex];
00208 node = link->getNode();
00209 predIndex = link->getPredIndex();
00210
00211 outMsgs = multiplyMessagesAndSumOut(predIndex);
00212
00213 node->receiveMessage(outMsgs, link);
00214 delete [] outMsgs;
00215 }
00216 }
00217
00218
00219 void BPFactor::moveToNextStep()
00220 {
00221 double * msgs;
00222 double * nextMsgs;
00223
00224
00225
00226 for (int lindex = 0;lindex < links_->size(); lindex++)
00227 {
00228 msgs = (*msgsArr_)[lindex];
00229 nextMsgs = (*nextMsgsArr_)[lindex];
00230 for (int i = 0; i < 2; i++)
00231 {
00232 msgs[i] = nextMsgs[i];
00233 nextMsgs[i] = 0;
00234 }
00235 }
00236 }
00237
00241 ostream& BPFactor::print(ostream& out)
00242 {
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255 for (int i = 0; i < links_->size(); i++)
00256 {
00257 (*links_)[i]->getNode()->print(out);
00258 if (i != links_->size() - 1) out << "/ ";
00259 }
00260
00261 return out;
00262 }
00263
00267 ostream& BPFactor::printWts(ostream& out)
00268 {
00269 for (int i = 0; i < (int)pow(2.0, clause_->getNumPredicates()); i++)
00270 {
00271 out << factorMsgs_[i] * outputWt_ << " ";
00272 }
00273
00274 return out;
00275 }
00276