00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef MCMC_H_
00067 #define MCMC_H_
00068
00069 #include "inference.h"
00070 #include "mcmcparams.h"
00071
00072
// Switch for the MCMC sampler's diagnostic tracing to cout. Every trace
// statement in class MCMC is guarded by `if (mcmcdebug)`, so with the value
// false none of them runs. Namespace-scope const, internal linkage per TU.
static const bool mcmcdebug = false;
00074
00079 class MCMC : public Inference
00080 {
00081 public:
00082
00089 MCMC(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090 MCMCParams* params)
00091 : Inference(state, seed, trackClauseTrueCnts)
00092 {
00093
00094 numChains_ = params->numChains;
00095 burnMinSteps_ = params->burnMinSteps;
00096 burnMaxSteps_ = params->burnMaxSteps;
00097 minSteps_ = params->minSteps;
00098 maxSteps_ = params->maxSteps;
00099 maxSeconds_ = params->maxSeconds;
00100 }
00101
00105 ~MCMC() {}
00106
00110 void printProbabilities(ostream& out)
00111 {
00112 for (int i = 0; i < state_->getNumAtoms(); i++)
00113 {
00114 double prob = getProbTrue(i);
00115
00116
00117 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00118 state_->printGndPred(i, out);
00119 out << " " << prob << endl;
00120 }
00121 }
00122
00129 double getProbability(GroundPredicate* const& gndPred)
00130 {
00131 int idx = state_->getGndPredIndex(gndPred);
00132 double prob = 0.0;
00133 if (idx >= 0) prob = getProbTrue(idx);
00134
00135 return (prob*10000 + 1/2.0)/(10000 + 1.0);
00136 }
00137
00141 void printTruePreds(ostream& out)
00142 {
00143 for (int i = 0; i < state_->getNumAtoms(); i++)
00144 {
00145 double prob = getProbTrue(i);
00146
00147
00148 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00149 if (prob >= 0.5) state_->printGndPred(i, out);
00150 }
00151 }
00152
00153 protected:
00154
00163 void initTruthValuesAndWts(const int& numChains)
00164 {
00165 int numPreds = state_->getNumAtoms();
00166 truthValues_.growToSize(numPreds);
00167 wtsWhenFalse_.growToSize(numPreds);
00168 wtsWhenTrue_.growToSize(numPreds);
00169 for (int i = 0; i < numPreds; i++)
00170 {
00171 truthValues_[i].growToSize(numChains, false);
00172 wtsWhenFalse_[i].growToSize(numChains, 0);
00173 wtsWhenTrue_[i].growToSize(numChains, 0);
00174 }
00175
00176 int numClauses = state_->getNumClauses();
00177 numTrueLits_.growToSize(numClauses);
00178 for (int i = 0; i < numClauses; i++)
00179 {
00180 numTrueLits_[i].growToSize(numChains, 0);
00181 }
00182 }
00183
00188 void initNumTrue()
00189 {
00190 int numPreds = state_->getNumAtoms();
00191 numTrue_.growToSize(numPreds, 0);
00192 }
00193
00200 void initNumTrueLits(const int& numChains)
00201 {
00202 for (int i = 0; i < state_->getNumClauses(); i++)
00203 {
00204 GroundClause* gndClause = state_->getGndClause(i);
00205 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00206 {
00207 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00208 for (int c = 0; c < numChains; c++)
00209 {
00210 if (truthValues_[atomIdx][c] == gndClause->getGroundPredicateSense(j))
00211 {
00212 numTrueLits_[i][c]++;
00213 assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00214 }
00215 }
00216 }
00217 }
00218 }
00219
00227 void randomInitGndPredsTruthValues(const int& numChains)
00228 {
00229 for (int c = 0; c < numChains; c++)
00230 {
00231
00232 for (int i = 0; i < state_->getNumBlocks(); i++)
00233 {
00234
00235 if (state_->getBlockEvidence(i))
00236 {
00237
00238 setOthersInBlockToFalse(c, -1, i);
00239 continue;
00240 }
00241
00242 Array<int>& block = state_->getBlockArray(i);
00243 int chosen = random() % block.size();
00244 truthValues_[block[chosen]][c] = true;
00245 setOthersInBlockToFalse(c, chosen, i);
00246 }
00247
00248
00249 for (int i = 0; i < truthValues_.size(); i++)
00250 {
00251
00252 if (state_->getBlockIndex(i) == -1)
00253 {
00254 bool tv = genTruthValueForProb(0.5);
00255 truthValues_[i][c] = tv;
00256 }
00257 }
00258 }
00259 }
00260
00267 bool genTruthValueForProb(const double& p)
00268 {
00269 if (p == 1.0) return true;
00270 if (p == 0.0) return false;
00271 bool r = random() <= p*RAND_MAX;
00272 return r;
00273 }
00274
00284 double getProbabilityOfPred(const int& predIdx, const int& chainIdx,
00285 const double& invTemp)
00286 {
00287
00288 if (numChains_ > 1)
00289 {
00290 return 1.0 /
00291 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] -
00292 wtsWhenTrue_[predIdx][chainIdx]) *
00293 invTemp));
00294 }
00295 else
00296 {
00297 GroundPredicate* gndPred = state_->getGndPred(predIdx);
00298 return 1.0 /
00299 ( 1.0 + exp((gndPred->getWtWhenFalse() -
00300 gndPred->getWtWhenTrue()) *
00301 invTemp));
00302 }
00303 }
00304
00313 void setOthersInBlockToFalse(const int& chainIdx, const int& atomIdx,
00314 const int& blockIdx)
00315 {
00316 Array<int>& block = state_->getBlockArray(blockIdx);
00317 for (int i = 0; i < block.size(); i++)
00318 {
00319 if (i != atomIdx)
00320 truthValues_[block[i]][chainIdx] = false;
00321 }
00322 }
00323
00334 void performGibbsStep(const int& chainIdx, const bool& burningIn,
00335 GroundPredicateHashArray& affectedGndPreds,
00336 Array<int>& affectedGndPredIndices)
00337 {
00338 if (mcmcdebug) cout << "Gibbs step" << endl;
00339
00340
00341 for (int i = 0; i < state_->getNumBlocks(); i++)
00342 {
00343
00344 if (state_->getBlockEvidence(i)) continue;
00345
00346 Array<int>& block = state_->getBlockArray(i);
00347
00348 int chosen = gibbsSampleFromBlock(chainIdx, block, 1);
00349
00350 bool truthValue;
00351 GroundPredicate* gndPred = state_->getGndPred(block[chosen]);
00352 if (numChains_ > 1) truthValue = truthValues_[block[chosen]][chainIdx];
00353 else truthValue = gndPred->getTruthValue();
00354
00355
00356 if (!truthValue)
00357 {
00358 for (int j = 0; j < block.size(); j++)
00359 {
00360
00361 bool otherTruthValue;
00362 GroundPredicate* otherGndPred = state_->getGndPred(block[j]);
00363 if (numChains_ > 1)
00364 otherTruthValue = truthValues_[block[j]][chainIdx];
00365 else
00366 otherTruthValue = otherGndPred->getTruthValue();
00367 if (otherTruthValue)
00368 {
00369
00370 if (numChains_ > 1)
00371 truthValues_[block[j]][chainIdx] = false;
00372 else
00373 otherGndPred->setTruthValue(false);
00374
00375 affectedGndPreds.clear();
00376 affectedGndPredIndices.clear();
00377 gndPredFlippedUpdates(block[j], chainIdx, affectedGndPreds,
00378 affectedGndPredIndices);
00379 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00380 chainIdx);
00381 }
00382 }
00383
00384
00385 if (numChains_ > 1) truthValues_[block[chosen]][chainIdx] = true;
00386 else gndPred->setTruthValue(true);
00387 affectedGndPreds.clear();
00388 affectedGndPredIndices.clear();
00389 gndPredFlippedUpdates(block[chosen], chainIdx, affectedGndPreds,
00390 affectedGndPredIndices);
00391 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00392 chainIdx);
00393 }
00394
00395
00396
00397 if (!burningIn) numTrue_[block[chosen]]++;
00398 }
00399
00400
00401 for (int i = 0; i < state_->getNumAtoms(); i++)
00402 {
00403
00404 if (state_->getBlockIndex(i) >= 0) continue;
00405
00406 if (mcmcdebug)
00407 {
00408 cout << "Chain " << chainIdx << ": Probability of pred "
00409 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00410 }
00411
00412 bool newAssignment
00413 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00414
00415
00416 bool truthValue;
00417 GroundPredicate* gndPred = state_->getGndPred(i);
00418 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00419 else truthValue = gndPred->getTruthValue();
00420
00421 if (newAssignment != truthValue)
00422 {
00423 if (mcmcdebug)
00424 {
00425 cout << "Chain " << chainIdx << ": Changing truth value of pred "
00426 << i << " to " << newAssignment << endl;
00427 }
00428
00429 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00430 else gndPred->setTruthValue(newAssignment);
00431 affectedGndPreds.clear();
00432 affectedGndPredIndices.clear();
00433 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00434 affectedGndPredIndices);
00435 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00436 chainIdx);
00437 }
00438
00439
00440
00441 if (!burningIn && newAssignment) numTrue_[i]++;
00442 }
00443
00444 if (!burningIn && trackClauseTrueCnts_)
00445 state_->getNumClauseGndings(clauseTrueCnts_, true);
00446
00447 if (mcmcdebug) cout << "End of Gibbs step" << endl;
00448 }
00449
00458 void updateWtsForGndPreds(GroundPredicateHashArray& gndPreds,
00459 Array<int>& gndPredIndices,
00460 const int& chainIdx)
00461 {
00462 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00463
00464 for (int g = 0; g < gndPreds.size(); g++)
00465 {
00466 double wtIfNoChange = 0, wtIfInverted = 0, wt;
00467
00468 Array<int>& negGndClauses =
00469 state_->getNegOccurenceArray(gndPredIndices[g] + 1);
00470 Array<int>& posGndClauses =
00471 state_->getPosOccurenceArray(gndPredIndices[g] + 1);
00472 int gndClauseIdx;
00473 bool sense;
00474
00475 if (mcmcdebug)
00476 {
00477 cout << "Ground clauses in which pred " << g << " occurs neg.: "
00478 << negGndClauses.size() << endl;
00479 cout << "Ground clauses in which pred " << g << " occurs pos.: "
00480 << posGndClauses.size() << endl;
00481 }
00482
00483 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00484 {
00485 if (i < negGndClauses.size())
00486 {
00487 gndClauseIdx = negGndClauses[i];
00488 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00489 sense = false;
00490 }
00491 else
00492 {
00493 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00494 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00495 sense = true;
00496 }
00497
00498 GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
00499 if (gndClause->isHardClause())
00500 wt = state_->getClauseCost(gndClauseIdx);
00501 else
00502 wt = gndClause->getWt();
00503
00504 int numSatLiterals;
00505 if (numChains_ > 1)
00506 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00507 else
00508 numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
00509 if (numSatLiterals > 1)
00510 {
00511
00512
00513 if (wt > 0)
00514 {
00515 wtIfNoChange += wt;
00516 wtIfInverted += wt;
00517 }
00518 }
00519 else
00520 if (numSatLiterals == 1)
00521 {
00522 if (wt > 0) wtIfNoChange += wt;
00523
00524 bool truthValue;
00525 if (numChains_ > 1)
00526 truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00527 else
00528 truthValue = gndPreds[g]->getTruthValue();
00529
00530 if (truthValue == sense)
00531 {
00532
00533 if (wt < 0) wtIfInverted += abs(wt);
00534 }
00535 else
00536 {
00537
00538 if (wt > 0) wtIfInverted += wt;
00539 }
00540 }
00541 else
00542 if (numSatLiterals == 0)
00543 {
00544
00545 if (wt > 0) wtIfInverted += wt;
00546 else if (wt < 0) wtIfNoChange += abs(wt);
00547 }
00548 }
00549
00550 if (mcmcdebug)
00551 {
00552 cout << "wtIfNoChange of pred " << g << ": "
00553 << wtIfNoChange << endl;
00554 cout << "wtIfInverted of pred " << g << ": "
00555 << wtIfInverted << endl;
00556 }
00557
00558
00559 if (numChains_ > 1)
00560 {
00561 if (truthValues_[gndPredIndices[g]][chainIdx])
00562 {
00563 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00564 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00565 }
00566 else
00567 {
00568 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
00569 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
00570 }
00571 }
00572 else
00573 {
00574 if (gndPreds[g]->getTruthValue())
00575 {
00576 gndPreds[g]->setWtWhenTrue(wtIfNoChange);
00577 gndPreds[g]->setWtWhenFalse(wtIfInverted);
00578 }
00579 else
00580 {
00581 gndPreds[g]->setWtWhenFalse(wtIfNoChange);
00582 gndPreds[g]->setWtWhenTrue(wtIfInverted);
00583 }
00584 }
00585 }
00586 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
00587 }
00588
00598 int gibbsSampleFromBlock(const int& chainIdx, const Array<int>& block,
00599 const double& invTemp)
00600 {
00601 Array<double> numerators;
00602 double denominator = 0;
00603
00604 for (int i = 0; i < block.size(); i++)
00605 {
00606 double prob = getProbabilityOfPred(block[i], chainIdx, invTemp);
00607 numerators.append(prob);
00608 denominator += prob;
00609 }
00610 double r = random();
00611 double numSum = 0.0;
00612 for (int i = 0; i < block.size(); i++)
00613 {
00614 numSum += numerators[i];
00615 if (r < ((numSum / denominator) * RAND_MAX))
00616 {
00617 return i;
00618 }
00619 }
00620 return block.size() - 1;
00621 }
00622
00631 void gndPredFlippedUpdates(const int& gndPredIdx, const int& chainIdx,
00632 GroundPredicateHashArray& affectedGndPreds,
00633 Array<int>& affectedGndPredIndices)
00634 {
00635 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
00636 int numAtoms = state_->getNumAtoms();
00637 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00638 affectedGndPreds.append(gndPred, numAtoms);
00639 affectedGndPredIndices.append(gndPredIdx);
00640 assert(affectedGndPreds.size() <= numAtoms);
00641
00642 Array<int>& negGndClauses =
00643 state_->getNegOccurenceArray(gndPredIdx + 1);
00644 Array<int>& posGndClauses =
00645 state_->getPosOccurenceArray(gndPredIdx + 1);
00646 int gndClauseIdx;
00647 GroundClause* gndClause;
00648 bool sense;
00649
00650
00651 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00652 {
00653 if (i < negGndClauses.size())
00654 {
00655 gndClauseIdx = negGndClauses[i];
00656 sense = false;
00657 }
00658 else
00659 {
00660 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00661 sense = true;
00662 }
00663 gndClause = state_->getGndClause(gndClauseIdx);
00664
00665
00666 if (numChains_ > 1)
00667 {
00668 if (truthValues_[gndPredIdx][chainIdx] == sense)
00669 numTrueLits_[gndClauseIdx][chainIdx]++;
00670 else
00671 numTrueLits_[gndClauseIdx][chainIdx]--;
00672 }
00673 else
00674 {
00675 if (gndPred->getTruthValue() == sense)
00676 state_->incrementNumTrueLits(gndClauseIdx);
00677 else
00678 state_->decrementNumTrueLits(gndClauseIdx);
00679 }
00680
00681 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00682 {
00683 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
00684 GroundPredicate* pred =
00685 (GroundPredicate*)gndClause->getGroundPredicate(j,
00686 (GroundPredicateHashArray*)gpha);
00687 affectedGndPreds.append(pred, numAtoms);
00688 affectedGndPredIndices.append(
00689 abs(gndClause->getGroundPredicateIndex(j)) - 1);
00690 assert(affectedGndPreds.size() <= numAtoms);
00691 }
00692 }
00693 if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
00694 }
00695
00696 double getProbTrue(const int& predIdx) const { return numTrue_[predIdx]; }
00697
00698 void setProbTrue(const int& predIdx, const double& p)
00699 {
00700 assert(p >= 0);
00701 numTrue_[predIdx] = p;
00702 }
00703
00710 void saveLowStateToChain(const int& chainIdx)
00711 {
00712 for (int i = 0; i < state_->getNumAtoms(); i++)
00713 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
00714 }
00715
00721 void setMCMCParameters(MCMCParams* params)
00722 {
00723
00724 numChains_ = params->numChains;
00725 burnMinSteps_ = params->burnMinSteps;
00726 burnMaxSteps_ = params->burnMaxSteps;
00727 minSteps_ = params->minSteps;
00728 maxSteps_ = params->maxSteps;
00729 maxSeconds_ = params->maxSeconds;
00730 }
00731
00732 protected:
00733
00735
00736 int numChains_;
00737
00738 int burnMinSteps_;
00739
00740 int burnMaxSteps_;
00741
00742 int minSteps_;
00743
00744 int maxSteps_;
00745
00746 int maxSeconds_;
00748
00749
00750 Array<Array<bool> > truthValues_;
00751
00752 Array<Array<double> > wtsWhenFalse_;
00753
00754 Array<Array<double> > wtsWhenTrue_;
00755
00756
00757
00758 Array<double> numTrue_;
00759
00760
00761
00762 Array<Array<int> > numTrueLits_;
00763 };
00764
00765 #endif