00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef BP_H_
00068 #define BP_H_
00069
#include <cmath>

#include "inference.h"
#include "bpparams.h"
#include "twowaymessage.h"
#include "superclause.h"
#include "auxfactor.h"
#include "node.h"
#include "factorgraph.h"
00077
// Enables verbose tracing in BP::init()/BP::infer(): every factor and node is
// printed before inference and again on every iteration. Declared bool because
// it is only ever used as a condition.
const bool bpdebug = true;
00079
00084 class BP : public Inference
00085 {
00086 public:
00087
00092 BP(FactorGraph* factorGraph, BPParams* bpParams,
00093 Array<Array<Predicate* >* >* queryFormulas = NULL)
00094 : Inference(NULL, -1, false, queryFormulas)
00095 {
00096 factorGraph_ = factorGraph;
00097 maxSteps_ = bpParams->maxSteps;
00098 maxSeconds_ = bpParams->maxSeconds;
00099 convergenceThresh_ = bpParams->convergenceThresh;
00100 convergeRequiredItrCnt_ = bpParams->convergeRequiredItrCnt;
00101 outputNetwork_ = bpParams->outputNetwork;
00102 }
00103
00107 ~BP()
00108 {
00109 }
00110
00114 void init()
00115 {
00116 Timer timer1;
00117 cout << "Initializing ";
00118 cout << "Belief Propagation..." << endl;
00119
00120 factorGraph_->init();
00121 if (bpdebug)
00122 {
00123 cout << "[init] ";
00124 Timer::printTime(cout, timer1.time());
00125 cout << endl;
00126 timer1.reset();
00127 }
00128 }
00129
00133 void infer()
00134 {
00135 Timer timer1;
00136
00137 double oldProbs[2];
00138 double newProbs[2];
00139 double diff;
00140 double maxDiff;
00141 int maxDiffNodeIndex;
00142 int convergeItrCnt = 0;
00143 bool converged = false;
00144 int numFactors = factorGraph_->getNumFactors();
00145 int numNodes = factorGraph_->getNumNodes();
00146
00147 cout << "factorcnt = " << numFactors
00148 << ", nodecnt = " << numNodes << endl;
00149
00150 if (bpdebug)
00151 {
00152 cout << "factors:" << endl;
00153 for (int i = 0; i < numFactors; i++)
00154 {
00155 factorGraph_->getFactor(i)->print(cout);
00156 cout << endl;
00157 }
00158 cout << "nodes:" << endl;
00159 for (int i = 0; i < numNodes; i++)
00160 {
00161 factorGraph_->getNode(i)->print(cout);
00162 cout << endl;
00163 }
00164 }
00165
00166
00167 int itr;
00168
00169
00170
00171 for (itr = 1; itr <= maxSteps_; itr++)
00172 {
00173 if (bpdebug)
00174 {
00175 cout<<"*************************************"<<endl;
00176 cout<<"Performing Iteration "<<itr<<" of BP"<<endl;
00177 cout<<"*************************************"<<endl;
00178 }
00179
00180 for (int i = 0; i < numFactors; i++)
00181 {
00182 if (bpdebug)
00183 {
00184 cout << "Sending messages for Factor: ";
00185 factorGraph_->getFactor(i)->print(cout); cout << endl;
00186 }
00187 factorGraph_->getFactor(i)->sendMessage();
00188 }
00189
00190 for (int i = 0; i < numNodes; i++)
00191 {
00192 if (bpdebug)
00193 {
00194 cout << "Sending messages for Node: ";
00195 factorGraph_->getNode(i)->print(cout); cout << endl;
00196 }
00197 factorGraph_->getNode(i)->sendMessage();
00198 }
00199
00200 for (int i = 0; i < numFactors; i++)
00201 {
00202 factorGraph_->getFactor(i)->moveToNextStep();
00203 if (bpdebug)
00204 {
00205 cout << "BP-Factor Iteration " << itr << " => ";
00206 factorGraph_->getFactor(i)->print(cout); cout << endl;
00207 }
00208 }
00209
00210 maxDiff = -1;
00211 maxDiffNodeIndex = -1;
00212 for (int i = 0; i < numNodes; i++)
00213 {
00214 if (bpdebug)
00215 {
00216 cout<<"************************************"<<endl;
00217 cout<<"Node "<<i<<":"<<endl;
00218 cout<<"************************************"<<endl;
00219 cout<<"Getting Old Probabilities =>"<<endl;
00220 cout<<endl;
00221 cout<<"Moving to next step "<<endl;
00222 cout<<endl;
00223 cout<<"Getting New Probabilities =>"<<endl;
00224 }
00225
00226 factorGraph_->getNode(i)->getProbs(oldProbs);
00227 factorGraph_->getNode(i)->moveToNextStep();
00228 factorGraph_->getNode(i)->getProbs(newProbs);
00229
00230 diff = abs(newProbs[1] - oldProbs[1]);
00231
00232 if (bpdebug)
00233 {
00234 cout << endl << endl << "Final Probs : " << endl;
00235 cout << "Node " << i << ": probs[" << 0 << "] = " << newProbs[0]
00236 << ", probs[" << 1 << "] = " << newProbs[1] << endl;
00237 cout << "BP-Node Iteration " << itr << ": " << newProbs[0]
00238 << " probs[" << 1 << "] = " << newProbs[1] << endl;
00239 cout << " : => ";
00240 factorGraph_->getNode(i)->print(cout);
00241 cout << endl;
00242 }
00243
00244 if (maxDiff < diff)
00245 {
00246 maxDiff = diff;
00247 maxDiffNodeIndex = i;
00248 }
00249 }
00250
00251 cout << "At Iteration " << itr << ": MaxDiff = " << maxDiff << endl;
00252 cout << endl;
00253
00254
00255 if (maxDiff < convergenceThresh_)
00256 convergeItrCnt++;
00257 else
00258 convergeItrCnt = 0;
00259
00260
00261
00262 if (convergeItrCnt >= convergeRequiredItrCnt_)
00263 {
00264 converged = true;
00265 break;
00266 }
00267 }
00268
00269 if (converged)
00270 {
00271 cout << "Converged in " << itr << " Iterations " << endl;
00272 }
00273 else
00274 {
00275 cout << "Did not converge in " << maxSteps_ << " (max allowed) Iterations"
00276 << endl;
00277 }
00278
00279 if (queryFormulas_)
00280 {
00281 cout << "Computing probabilities of query formulas ..." << endl;
00282 for (int i = 0; i < numNodes; i++)
00283 {
00284 if (bpdebug)
00285 {
00286 cout << "Sending auxiliary messages for Node: ";
00287 factorGraph_->getNode(i)->print(cout); cout << endl;
00288 }
00289 factorGraph_->getNode(i)->sendAuxMessage();
00290
00291 }
00292 for (int j = 0; j < qfProbs_->size(); j++)
00293 {
00294 (*qfProbs_)[j] = factorGraph_->getAuxFactor(j)->getProb();
00295 }
00296 }
00297 }
00298
00302 void printNetwork(ostream& out)
00303 {
00304 factorGraph_->printNetwork(out);
00305 }
00306
00310 void printProbabilities(ostream& out)
00311 {
00312 double probs[2];
00313 Array<int>* constants;
00314 Predicate* pred;
00315 int predId;
00316 Node* node;
00317 double exp;
00318 Domain* domain = factorGraph_->getDomain();
00319 for (int i = 0; i < factorGraph_->getNumNodes(); i++)
00320 {
00321 node = factorGraph_->getNode(i);
00322 predId = node->getPredId();
00323 node->getProbs(probs);
00324 exp = node->getExp();
00325 SuperPred * superPred = node->getSuperPred();
00326
00327 if (superPred)
00328 {
00329 for (int index = 0; index < superPred->getNumTuples(); index++)
00330 {
00331 constants = superPred->getConstantTuple(index);
00332 pred = domain->getPredicate(constants, predId);
00333 pred->printWithStrVar(out, domain);
00334 out << " " << probs[1] << endl;
00335
00336 }
00337 }
00338 else
00339 {
00340 constants = node->getConstants();
00341 assert(constants != NULL);
00342 pred = domain->getPredicate(constants, predId);
00343 pred->printWithStrVar(out, domain);
00344 out << " " << probs[1] << endl;
00345
00346 }
00347 }
00348 }
00349
00364 void getChangedPreds(vector<string>& changedPreds, vector<float>& probs,
00365 vector<float>& oldProbs, const float& probDelta)
00366 {
00367 }
00368
00375 double getProbability(GroundPredicate* const& gndPred)
00376 {
00377 double probs[2];
00378 Array<int>* constants;
00379 Predicate* pred;
00380 unsigned int predId;
00381 Node* node;
00382 Domain* domain = factorGraph_->getDomain();
00383 bool found = false;
00384 for (int i = 0; i < factorGraph_->getNumNodes(); i++)
00385 {
00386 node = factorGraph_->getNode(i);
00387 predId = node->getPredId();
00388 if (predId != gndPred->getId()) continue;
00389 node->getProbs(probs);
00390 SuperPred * superPred = node->getSuperPred();
00391
00392 if (superPred)
00393 {
00394 for (int index = 0; index < superPred->getNumTuples(); index++)
00395 {
00396 constants = superPred->getConstantTuple(index);
00397 pred = domain->getPredicate(constants, predId);
00398 if (!pred->same(gndPred))
00399 {
00400 delete pred;
00401 continue;
00402 }
00403 delete pred;
00404 found = true;
00405 return probs[1];
00406 }
00407 }
00408 else
00409 {
00410 constants = node->getConstants();
00411 assert(constants != NULL);
00412 pred = domain->getPredicate(constants, predId);
00413 if (!pred->same(gndPred))
00414 {
00415 delete pred;
00416 continue;
00417 }
00418 delete pred;
00419 found = true;
00420 return probs[1];
00421 }
00422 }
00423 return 0.5;
00424 }
00425
00432 double getProbabilityH(GroundPredicate* const& gndPred)
00433 {
00434 return 0.0;
00435 }
00436
00441 void printTruePreds(ostream& out)
00442 {
00443 }
00444
00449 void printTruePredsH(ostream& out)
00450 {
00451 }
00452
00453 private:
00454
00455 FactorGraph* factorGraph_;
00456
00457 int maxSteps_;
00458
00459 int maxSeconds_;
00460
00461
00462 double convergenceThresh_;
00463
00464 int convergeRequiredItrCnt_;
00465
00466 bool outputNetwork_;
00467 };
00468
00469 #endif