00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
#include "factor.h"

#include <vector>
00068
00069
00070
00071
00072
00073
00074 void Factor::initFactorMesssages()
00075 {
00076 Predicate *pred;
00077 int numPreds = clause_->getNumPredicates();
00078 int stateCnt = (int)pow(2.0, numPreds);
00079 factorMsgs_ = new double[stateCnt];
00080 bool isSatisfied;
00081 for (int state = 0; state < stateCnt; state++)
00082 {
00083 isSatisfied = false;
00084 for (int predno = 0; predno < numPreds; predno++)
00085 {
00086 pred = clause_->getPredicate(predno);
00087 bool predBit = state & (1<<predno);
00088 if (pred->getSense() == predBit)
00089 {
00090 isSatisfied = true;
00091 break;
00092 }
00093 }
00094
00095 if (isSatisfied)
00096 {
00097
00098
00099 factorMsgs_[state] = clause_->getWt();
00100 }
00101 else
00102 {
00103 factorMsgs_[state] = 0;
00104 }
00105 }
00106 }
00107
00108
00109 double* Factor::multiplyMessagesAndSumOut(int inpPredIndex)
00110 {
00111 int numPreds = clause_->getNumPredicates();
00112 int stateCnt = (int)pow(2.0, numPreds);
00113 double * prodMsgs = new double[stateCnt];
00114
00115 Node *node;
00116 double *gndNodeCnts = new double[numPreds];
00117
00118
00119
00120
00121 for (int predno = 0; predno < numPreds; predno++)
00122 gndNodeCnts[predno] = 0;
00123
00124 for (int lno = 0; lno < links_->size(); lno++)
00125 {
00126 int predIndex = (*links_)[lno]->getPredIndex();
00127 node = (*links_)[lno]->getNode();
00128 gndNodeCnts[predIndex] += node->getGroundNodeCount();
00129 }
00130
00131
00132 for (int state = 0; state < stateCnt; state++)
00133 prodMsgs[state] = factorMsgs_[state];
00134
00135 for (int lno = 0; lno < links_->size(); lno++)
00136 {
00137 int predIndex = (*links_)[lno]->getPredIndex();
00138 if (predIndex == inpPredIndex)
00139 continue;
00140
00141 node = (*links_)[lno]->getNode();
00142
00143
00144
00145 assert(gndNodeCnts[predIndex] != 0);
00146
00147 double cnt = node->getGroundNodeCount()/gndNodeCnts[predIndex];
00148
00149
00150
00151
00152 for (int state = 0; state < stateCnt; state++)
00153 {
00154 bool predBit = state & (1<<predIndex);
00155 if (predBit)
00156 prodMsgs[state] += (*msgsArr_)[lno][1]*cnt;
00157 else
00158 prodMsgs[state] += (*msgsArr_)[lno][0]*cnt;
00159
00160 }
00161 }
00162
00163
00164 double *outMsgs = new double[2];
00165 double maxMsgs[2];
00166 maxMsgs[0] = maxMsgs[1] = 0;
00167 bool firstTime[2];
00168 firstTime[0] = firstTime[1] = true;
00169
00170
00171 for (int state = 0; state < stateCnt; state++)
00172 {
00173 bool predBit = state & (1<<inpPredIndex);
00174 if (predBit && (maxMsgs[1] < prodMsgs[state] || firstTime[1]))
00175 {
00176 firstTime[1] = false;
00177 maxMsgs[1] = prodMsgs[state];
00178 }
00179
00180 if (!predBit && (maxMsgs[0] < prodMsgs[state] || firstTime[0]))
00181 {
00182 firstTime[0] = false;
00183 maxMsgs[0] = prodMsgs[state];
00184 }
00185 }
00186
00187 outMsgs[0] = outMsgs[1] = 0;
00188 for (int state = 0; state < stateCnt; state++)
00189 {
00190 bool predBit = state & (1<<inpPredIndex);
00191 if (predBit)
00192 outMsgs[1] += expl(prodMsgs[state] - maxMsgs[1]);
00193 else
00194 outMsgs[0] += expl(prodMsgs[state] - maxMsgs[0]);
00195 }
00196 outMsgs[1] = maxMsgs[1] + logl(outMsgs[1]);
00197 outMsgs[0] = maxMsgs[0] + logl(outMsgs[0]);
00198 outMsgs[1] = outMsgs[1] - outMsgs[0];
00199 outMsgs[0] = 0;
00200 delete [] prodMsgs;
00201 delete [] gndNodeCnts;
00202 return outMsgs;
00203 }
00204
00205
00206 void Factor::sendMessage()
00207 {
00208 double *outMsgs = NULL;
00209 Link *link;
00210 Node *node;
00211 int predIndex;
00212 for (int lindex = 0; lindex < links_->size(); lindex++)
00213 {
00214 link = (*links_)[lindex];
00215 node = link->getNode();
00216 predIndex = link->getPredIndex();
00217
00218 outMsgs = multiplyMessagesAndSumOut(predIndex);
00219
00220 node->receiveMessage(outMsgs, link);
00221 delete [] outMsgs;
00222 }
00223 }
00224
00225
00226 void Factor::moveToNextStep()
00227 {
00228 double * msgs;
00229 double * nextMsgs;
00230
00231
00232
00233 for (int lindex = 0;lindex < links_->size(); lindex++)
00234 {
00235 msgs = (*msgsArr_)[lindex];
00236 nextMsgs = (*nextMsgsArr_)[lindex];
00237 for (int i = 0; i < 2; i++)
00238 {
00239 msgs[i] = nextMsgs[i];
00240 nextMsgs[i] = 0;
00241 }
00242 }
00243 }
00244
00248 ostream& Factor::print(ostream& out)
00249 {
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262 for (int i = 0; i < links_->size(); i++)
00263 {
00264 (*links_)[i]->getNode()->print(out);
00265 if (i != links_->size() - 1) out << "/ ";
00266 }
00267
00268 return out;
00269 }
00270
00274 ostream& Factor::printWts(ostream& out)
00275 {
00276 for (int i = 0; i < (int)pow(2.0, clause_->getNumPredicates()); i++)
00277 {
00278 out << factorMsgs_[i] * outputWt_ << " ";
00279 }
00280
00281 return out;
00282 }
00283