00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef MCMC_H_
00068 #define MCMC_H_
00069
00070 #include "inference.h"
00071 #include "mcmcparams.h"
00072
00073
// Global switch for the diagnostic tracing scattered through MCMC: every
// "if (mcmcdebug)" guard is dead code while this stays false.
const bool mcmcdebug = false;
00075
00080 class MCMC : public Inference
00081 {
00082 public:
00083
00090 MCMC(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00091 MCMCParams* params, Array<Array<Predicate* >* >* queryFormulas = NULL)
00092 : Inference(state, seed, trackClauseTrueCnts, queryFormulas)
00093 {
00094
00095 numChains_ = params->numChains;
00096 burnMinSteps_ = params->burnMinSteps;
00097 burnMaxSteps_ = params->burnMaxSteps;
00098 minSteps_ = params->minSteps;
00099 maxSteps_ = params->maxSteps;
00100 maxSeconds_ = params->maxSeconds;
00101 }
00102
00103
00104 MCMC(HVariableState* state, long int seed, const bool& trackClauseTrueCnts,
00105 MCMCParams* params)
00106 : Inference(state, seed, trackClauseTrueCnts)
00107 {
00108
00109 numChains_ = params->numChains;
00110 burnMinSteps_ = params->burnMinSteps;
00111 burnMaxSteps_ = params->burnMaxSteps;
00112 minSteps_ = params->minSteps;
00113 maxSteps_ = params->maxSteps;
00114 maxSeconds_ = params->maxSeconds;
00115 }
00116
00120 ~MCMC() {}
00121
00122
00123
00124
00125
00126
00127 double computeHybridClauseValue(int clauseIdx, int c)
00128 {
00129 double contClauseContPartValue = HybridClauseContPartValue(clauseIdx, c);
00130 double contClauseDisPartValue = HybridClauseDisPartValue(clauseIdx, c);
00131 return hstate_->hybridWts_[clauseIdx] * contClauseContPartValue * contClauseDisPartValue;
00132 }
00133
00134 double HybridClauseContPartValue(int contClauseIdx, int c)
00135 {
00136 PolyNomial& pl = hstate_->GetHybridClausePolynomial(contClauseIdx);
00137
00138 assert(hstate_->hybridContClause_[contClauseIdx].size() == pl.GetVarNum());
00139
00140 Array<double> arVar;
00141 for(int i = 0; i < hstate_->hybridContClause_[contClauseIdx].size(); ++i)
00142 {
00143 arVar.append(truthValuesCont_[hstate_->hybridContClause_[contClauseIdx][i] - 1][c]);
00144 }
00145 double v = pl.ComputePlValue(arVar);
00146
00147 return v;
00148 }
00149
00150 double HybridClauseDisPartValue(int contClauseIdx, int c)
00151 {
00152 bool bAndOr = hstate_->hybridConjunctionDisjunction_[contClauseIdx];
00153 int numTrueLits = 0;
00154 int numFalseLits = 0;
00155 for(int j = 0; j < hstate_->hybridDisClause_[contClauseIdx].size(); ++j)
00156 {
00157 int atomIdx = hstate_->hybridDisClause_[contClauseIdx][j];
00158 if((atomIdx > 0) == truthValues_[abs(atomIdx)-1][c])
00159 {
00160 numTrueLits ++;
00161 if(!bAndOr)
00162 {
00163 break;
00164 }
00165 }
00166 else
00167 {
00168 numFalseLits ++;
00169 if(bAndOr)
00170 {
00171 break;
00172 }
00173 }
00174 }
00175 if(!bAndOr)
00176 {
00177 return (numTrueLits > 0)?1.0:0.0;
00178 }
00179 else
00180 {
00181 return (numFalseLits > 0)?0.0:1.0;
00182 }
00183 return 0.0;
00184 }
00185
00186 void updateDisPredValue(int predIdx, int chainIdx, bool updateValue)
00187 {
00188 bool bBak = truthValues_[predIdx][chainIdx];
00189 Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00190 Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00191 Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00192
00193 if(updateValue)
00194 {
00195 if(!bBak)
00196 {
00197
00198 for(int j = 0; j < occDisPos.size(); ++j)
00199 {
00200 int disClauseIdx = occDisPos[j];
00201 numTrueLits_[disClauseIdx][chainIdx] += 1;
00202 }
00203
00204 for( int j = 0; j < occDisNeg.size(); ++j)
00205 {
00206 int disClauseIdx = occDisNeg[j];
00207 numTrueLits_[disClauseIdx][chainIdx] -= 1;
00208 }
00209 }
00210 }
00211 else
00212 {
00213 if(bBak)
00214 {
00215 for( int j = 0; j < occDisPos.size(); ++j)
00216 {
00217 int disClauseIdx = occDisPos[j];
00218 numTrueLits_[disClauseIdx][chainIdx] -= 1;
00219 }
00220
00221 for( int j = 0; j < occDisNeg.size(); ++j)
00222 {
00223 int disClauseIdx = occDisNeg[j];
00224 numTrueLits_[disClauseIdx][chainIdx] += 1;
00225 }
00226 }
00227 }
00228
00229 truthValues_[predIdx][chainIdx] = updateValue;
00230
00231 if(bBak != updateValue)
00232 {
00233 for( int j = 0; j < occCont.size(); ++j)
00234 {
00235 int hybridClauseIdx = occCont[j];
00236 hybridClauseDisPartValueMCMC_[hybridClauseIdx] = HybridClauseDisPartValue(hybridClauseIdx, chainIdx)==1.0?true:false;
00237 }
00238 }
00239 }
00240
00241 void updateProposalContValue(int contPredIdx, int chainIdx)
00242 {
00243 Array<int>& occContCont = hstate_->hybridContOccurrence_[contPredIdx + 1];
00244 PolyNomial pl;
00245 for(int j = 0; j < occContCont.size(); j++)
00246 {
00247 int contClauseIdx = occContCont[j];
00248 PolyNomial pltmp = hstate_->hybridPls_[contClauseIdx];
00249
00250 int inIdx = -1;
00251 Array<double> arVars;
00252 for(int k = 0; k < hstate_->hybridContClause_[contClauseIdx].size(); k++)
00253 {
00254 arVars.append(truthValuesCont_[hstate_->hybridContClause_[contClauseIdx][k]-1][chainIdx]);
00255 if (hstate_->hybridContClause_[contClauseIdx][k] == contPredIdx+1)
00256 {
00257
00258 inIdx = k;
00259 }
00260 }
00261 if (inIdx == -1)
00262 {
00263 cout << "faint" << endl;
00264 }
00265
00266 pltmp.ReduceToOneVar(arVars, inIdx);
00267 pl.AddPl(pltmp);
00268 }
00269
00270 double miu = 0, stdev = 0;
00271 pl.GetGaussianPara(&miu, &stdev);
00272 truthValuesCont_[contPredIdx][chainIdx] = ExtRandom::gaussRandom(miu, stdev);
00273
00274 for(int j = 0; j < occContCont.size(); j++)
00275 {
00276 int hybridClauseIdx = occContCont[j];
00277 hybridClauseContPartValueMCMC_[hybridClauseIdx][chainIdx] = HybridClauseContPartValue(hybridClauseIdx,chainIdx);
00278 }
00279 }
00280
00281
00282 double getProbabilityOfPredH(const int& predIdx, const int& chainIdx, const double& invTemp)
00283 {
00284
00285 if (numChains_ > 1)
00286 {
00287 double wtDisAsTrue = 0, wtDisAsFalse = 0;
00288 bool bBak = truthValues_[predIdx][chainIdx];
00289
00290 Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00291 Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00292 Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00293
00294
00295 for(int j = 0; j < occDisPos.size(); j++)
00296 {
00297 int disClauseIdx = occDisPos[j];
00298 wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00299
00300 if(numTrueLits_[disClauseIdx][chainIdx] > 1)
00301 {
00302 wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00303 }
00304 else if(numTrueLits_[disClauseIdx][chainIdx] == 1 && !bBak)
00305 {
00306 wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00307 }
00308 }
00309
00310
00311 for(int j = 0; j < occDisNeg.size(); j++)
00312 {
00313 int disClauseIdx = occDisNeg[j];
00314 if(numTrueLits_[disClauseIdx][chainIdx] > 1)
00315 {
00316 wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00317 }
00318 else if(numTrueLits_[disClauseIdx][chainIdx] == 1 && bBak)
00319 {
00320 wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00321 }
00322 wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00323 }
00324
00325 for (int j = 0; j < occCont.size(); j++) {
00326 int hybridClauseIdx = occCont[j];
00327 double contPart = hybridClauseContPartValueMCMC_[hybridClauseIdx][chainIdx];
00328 double wt = hstate_->hybridWts_[hybridClauseIdx];
00329
00330 truthValues_[predIdx][chainIdx] = true;
00331 double disPartAsTrue = HybridClauseDisPartValue(hybridClauseIdx,chainIdx);
00332 truthValues_[predIdx][chainIdx] = false;
00333 double disPartAsFalse = HybridClauseDisPartValue(hybridClauseIdx,chainIdx);
00334 truthValues_[predIdx][chainIdx] = bBak;
00335
00336 wtDisAsTrue += wt*disPartAsTrue*contPart;
00337 wtDisAsFalse += wt*disPartAsFalse*contPart;
00338 }
00339
00340
00341 double wtDiff = (wtDisAsFalse - wtDisAsTrue) * invTemp;
00342 double prob;
00343 if (wtDiff > 403.429)
00344 {
00345 prob = 0;
00346 }
00347 else if (wtDiff < -403.429)
00348 {
00349 prob = 1;
00350 }
00351 else
00352 {
00353 prob = 1 / ( 1 + exp(wtDiff));
00354 }
00355 return prob;
00356 }
00357 else
00358 {
00359 GroundPredicate* gndPred = hstate_->getGndPred(predIdx);
00360 double wtDisAsTrue = 0, wtDisAsFalse = 0;
00361 bool bBak = gndPred->getTruthValue();
00362
00363 Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00364 Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00365 Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00366
00367 gndPred->setTruthValue(true);
00368
00369 for(int j = 0; j < occDisPos.size(); j++)
00370 {
00371 int disClauseIdx = occDisPos[j];
00372 wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00373 }
00374
00375 for(int j = 0; j < occDisNeg.size(); j++)
00376 {
00377 int disClauseIdx = occDisNeg[j];
00378 for(int k = 0; k < hstate_->clause_[disClauseIdx].size(); k++)
00379 {
00380 int lit = hstate_->clause_[disClauseIdx][k];
00381 GroundPredicate* gndPredtmp = hstate_->getGndPred(abs(lit)-1);
00382 if ((lit > 0) == gndPredtmp->getTruthValue())
00383 {
00384 wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00385 break;
00386 }
00387 }
00388 }
00389
00390 for (int j = 0; j < occCont.size(); j++)
00391 {
00392 int contClauseIdx = occCont[j];
00393 wtDisAsTrue += hstate_->HybridClauseValue(contClauseIdx);
00394
00395 }
00396
00397
00398 gndPred->setTruthValue(false);
00399 for(int j = 0; j < occDisNeg.size(); j++)
00400 {
00401 int disClauseIdx = occDisNeg[j];
00402 wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00403 }
00404
00405 for(int j = 0; j < occDisPos.size(); j++)
00406 {
00407 int disClauseIdx = occDisPos[j];
00408 for(int k = 0; k < hstate_->clause_[disClauseIdx].size(); k++)
00409 {
00410 int lit = hstate_->clause_[disClauseIdx][k];
00411 GroundPredicate* gndPredtmp = hstate_->getGndPred(abs(lit)-1);
00412 if ((lit > 0) == gndPredtmp->getTruthValue())
00413 {
00414 wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00415 break;
00416 }
00417 }
00418 }
00419
00420 for (int j = 0; j < occCont.size(); j++)
00421 {
00422 int contClauseIdx = occCont[j];
00423 wtDisAsFalse += hstate_->HybridClauseValue(contClauseIdx);
00424 }
00425
00426 gndPred->setTruthValue(bBak);
00427
00428 return 1.0 / ( 1.0 + exp((wtDisAsFalse - wtDisAsTrue) * invTemp));
00429 }
00430 }
00431
00435 virtual void printNetwork(ostream& out)
00436 {
00437 }
00438
00442 void printProbabilities(ostream& out)
00443 {
00444 for (int i = 0; i < state_->getNumAtoms(); i++)
00445 {
00446 double prob = getProbTrue(i);
00447
00448
00449 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00450 state_->printGndPred(i, out);
00451 out << " " << prob << endl;
00452 }
00453 }
00454
00468 void getChangedPreds(vector<string>& changedPreds, vector<float>& probs,
00469 vector<float>& oldProbs, const float& probDelta)
00470 {
00471 changedPreds.clear();
00472 probs.clear();
00473 int numAtoms = state_->getNumAtoms();
00474
00475 oldProbs.resize(numAtoms, 0);
00476 for (int i = 0; i < numAtoms; i++)
00477 {
00478 double prob = getProbTrue(i);
00479 if (abs(prob - oldProbs[i]) > probDelta)
00480 {
00481
00482
00483 oldProbs[i] = prob;
00484
00485 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00486 ostringstream oss(ostringstream::out);
00487 state_->printGndPred(i, oss);
00488 changedPreds.push_back(oss.str());
00489 probs.push_back(prob);
00490 }
00491 }
00492 }
00493
00494
00495 double getProbabilityH(GroundPredicate* const& gndPred)
00496 {
00497 int idx = hstate_->getGndPredIndex(gndPred);
00498 double prob = 0.0;
00499 if (idx >= 0) prob = getProbTrue(idx);
00500
00501 return (prob*10000 + 1/2.0)/(10000 + 1.0);
00502 }
00503
00510 double getProbability(GroundPredicate* const& gndPred)
00511 {
00512 int idx = state_->getGndPredIndex(gndPred);
00513 double prob = 0.0;
00514 if (idx >= 0) prob = getProbTrue(idx);
00515
00516 return (prob*10000 + 1/2.0)/(10000 + 1.0);
00517 }
00518
00522 void printTruePreds(ostream& out)
00523 {
00524 for (int i = 0; i < state_->getNumAtoms(); i++)
00525 {
00526 double prob = getProbTrue(i);
00527
00528
00529 prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00530 if (prob >= 0.5) state_->printGndPred(i, out);
00531 }
00532 }
00533
00534 void printTruePredsH(ostream& out)
00535 {
00536 for (int i = 0; i < hstate_->getNumAtoms(); i++)
00537 {
00538 double prob = getProbTrue(i);
00539
00540
00541 prob = (prob * 10000 + 1/2.0) / (10000 + 1.0);
00542 if (prob >= 0.5) hstate_->printGndPred(i, out);
00543 }
00544 }
00545
00546
00547 protected:
00548
00557 void initTruthValuesAndWts(const int& numChains)
00558 {
00559 int numPreds = state_->getNumAtoms();
00560 truthValues_.growToSize(numPreds);
00561 wtsWhenFalse_.growToSize(numPreds);
00562 wtsWhenTrue_.growToSize(numPreds);
00563 for (int i = 0; i < numPreds; i++)
00564 {
00565 truthValues_[i].growToSize(numChains, false);
00566 wtsWhenFalse_[i].growToSize(numChains, 0);
00567 wtsWhenTrue_[i].growToSize(numChains, 0);
00568 }
00569
00570 int numClauses = state_->getNumClauses();
00571 numTrueLits_.growToSize(numClauses);
00572 for (int i = 0; i < numClauses; i++)
00573 {
00574 numTrueLits_[i].growToSize(numChains, 0);
00575 }
00576 }
00577
00582 void initNumTrue()
00583 {
00584 int numPreds = state_->getNumAtoms();
00585 numTrue_.growToSize(numPreds);
00586 for (int i = 0; i < numTrue_.size(); i++)
00587 numTrue_[i] = 0;
00588 }
00589
00596 void initNumTrueLits(const int& numChains)
00597 {
00598
00599 if (numChains == 1) state_->resetMakeBreakCostWatch();
00600 for (int i = 0; i < state_->getNumClauses(); i++)
00601 {
00602 GroundClause* gndClause = state_->getGndClause(i);
00603 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00604 {
00605 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00606 const bool sense = gndClause->getGroundPredicateSense(j);
00607 if (numChains > 1)
00608 {
00609 for (int c = 0; c < numChains; c++)
00610 {
00611 if (truthValues_[atomIdx][c] == sense)
00612 {
00613 numTrueLits_[i][c]++;
00614 assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00615 }
00616 }
00617 }
00618 else
00619 {
00620 GroundPredicate* gndPred = state_->getGndPred(atomIdx);
00621 if (gndPred->getTruthValue() == sense)
00622 state_->incrementNumTrueLits(i);
00623 assert(state_->getNumTrueLits(i) <= state_->getNumAtoms());
00624 }
00625 }
00626 }
00627 }
00628
00636 void randomInitGndPredsTruthValues(const int& numChains)
00637 {
00638 for (int c = 0; c < numChains; c++)
00639 {
00640 if (mcmcdebug) cout << "Chain " << c << ":" << endl;
00641
00642 for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00643 {
00644
00645 if (state_->getDomain()->getBlockEvidence(i))
00646 {
00647
00648 setOthersInBlockToFalse(c, -1, i);
00649 continue;
00650 }
00651
00652 bool ok = false;
00653 while (!ok)
00654 {
00655 const Predicate* pred = state_->getDomain()->getRandomPredInBlock(i);
00656 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00657 int idx = state_->getIndexOfGroundPredicate(gndPred);
00658
00659 delete gndPred;
00660 delete pred;
00661
00662 if (idx >= 0)
00663 {
00664
00665 if (numChains_ > 1)
00666 truthValues_[idx][c] = true;
00667 else
00668 {
00669 GroundPredicate* gndPred = state_->getGndPred(i);
00670 gndPred->setTruthValue(true);
00671 }
00672 setOthersInBlockToFalse(c, idx, i);
00673 ok = true;
00674 }
00675 }
00676 }
00677
00678
00679 for (int i = 0; i < truthValues_.size(); i++)
00680 {
00681
00682 if (state_->getBlockIndex(i) == -1)
00683 {
00684 bool tv = genTruthValueForProb(0.5);
00685
00686 if (numChains_ > 1)
00687 truthValues_[i][c] = tv;
00688 else
00689 {
00690 GroundPredicate* gndPred = state_->getGndPred(i);
00691 gndPred->setTruthValue(tv);
00692 }
00693 if (mcmcdebug) cout << "Pred " << i << " set to " << tv << endl;
00694 }
00695 }
00696 }
00697 }
00698
/**
 * Draws a Boolean that is true with probability p.
 *
 * p == 1.0 and p == 0.0 are decided without consuming a random number.
 * Otherwise random() is compared against p * RAND_MAX.
 */
bool genTruthValueForProb(const double& p)
{
  if (p == 1.0 || p == 0.0) return p == 1.0;
  return random() <= p * RAND_MAX;
}
00712
00722 double getProbabilityOfPred(const int& predIdx, const int& chainIdx,
00723 const double& invTemp)
00724 {
00725
00726 if (numChains_ > 1)
00727 {
00728 return 1.0 /
00729 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] -
00730 wtsWhenTrue_[predIdx][chainIdx]) *
00731 invTemp));
00732 }
00733 else
00734 {
00735 GroundPredicate* gndPred = state_->getGndPred(predIdx);
00736 return 1.0 /
00737 ( 1.0 + exp((gndPred->getWtWhenFalse() -
00738 gndPred->getWtWhenTrue()) *
00739 invTemp));
00740 }
00741 }
00742
00751 void setOthersInBlockToFalse(const int& chainIdx, const int& atomIdx,
00752 const int& blockIdx)
00753 {
00754 int blockSize = state_->getDomain()->getBlockSize(blockIdx);
00755 for (int i = 0; i < blockSize; i++)
00756 {
00757 const Predicate* pred = state_->getDomain()->getPredInBlock(i, blockIdx);
00758 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00759 int idx = state_->getIndexOfGroundPredicate(gndPred);
00760
00761 delete gndPred;
00762 delete pred;
00763
00764
00765 if (idx >= 0 && idx != atomIdx)
00766 truthValues_[idx][chainIdx] = false;
00767 }
00768 }
00769
00780 void performGibbsStep(const int& chainIdx, const bool& burningIn,
00781 GroundPredicateHashArray& affectedGndPreds,
00782 Array<int>& affectedGndPredIndices)
00783 {
00784 if (mcmcdebug) cout << "Gibbs step" << endl;
00785
00786
00787 for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00788 {
00789
00790 if (state_->getDomain()->getBlockEvidence(i)) continue;
00791
00792 int chosen = gibbsSampleFromBlock(chainIdx, i, 1);
00793
00794 bool truthValue;
00795 const Predicate* pred = state_->getDomain()->getPredInBlock(chosen, i);
00796 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00797 int idx = state_->getIndexOfGroundPredicate(gndPred);
00798
00799 delete gndPred;
00800 delete pred;
00801
00802
00803 if (idx >= 0)
00804 {
00805 gndPred = state_->getGndPred(idx);
00806 if (numChains_ > 1) truthValue = truthValues_[idx][chainIdx];
00807 else truthValue = gndPred->getTruthValue();
00808
00809
00810 if (!truthValue)
00811 {
00812 int blockSize = state_->getDomain()->getBlockSize(i);
00813 for (int j = 0; j < blockSize; j++)
00814 {
00815
00816 bool otherTruthValue;
00817 const Predicate* otherPred =
00818 state_->getDomain()->getPredInBlock(j, i);
00819 GroundPredicate* otherGndPred =
00820 new GroundPredicate((Predicate*)otherPred);
00821 int otherIdx = state_->getIndexOfGroundPredicate(gndPred);
00822
00823 delete otherGndPred;
00824 delete otherPred;
00825
00826
00827 if (otherIdx >= 0)
00828 {
00829 otherGndPred = state_->getGndPred(otherIdx);
00830 if (numChains_ > 1)
00831 otherTruthValue = truthValues_[otherIdx][chainIdx];
00832 else
00833 otherTruthValue = otherGndPred->getTruthValue();
00834 if (otherTruthValue)
00835 {
00836
00837 if (numChains_ > 1)
00838 truthValues_[otherIdx][chainIdx] = false;
00839 else
00840 otherGndPred->setTruthValue(false);
00841
00842 affectedGndPreds.clear();
00843 affectedGndPredIndices.clear();
00844 gndPredFlippedUpdates(otherIdx, chainIdx, affectedGndPreds,
00845 affectedGndPredIndices);
00846 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00847 chainIdx);
00848 }
00849 }
00850 }
00851
00852
00853 if (numChains_ > 1) truthValues_[idx][chainIdx] = true;
00854 else gndPred->setTruthValue(true);
00855 affectedGndPreds.clear();
00856 affectedGndPredIndices.clear();
00857 gndPredFlippedUpdates(idx, chainIdx, affectedGndPreds,
00858 affectedGndPredIndices);
00859 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00860 chainIdx);
00861 }
00862
00863
00864 if (!burningIn) numTrue_[idx]++;
00865 }
00866 }
00867
00868
00869 for (int i = 0; i < state_->getNumAtoms(); i++)
00870 {
00871
00872 if (state_->getBlockIndex(i) >= 0) continue;
00873
00874 if (mcmcdebug)
00875 {
00876 cout << "Chain " << chainIdx << ": Probability of pred "
00877 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00878 }
00879
00880 bool newAssignment
00881 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00882
00883
00884 bool truthValue;
00885 GroundPredicate* gndPred = state_->getGndPred(i);
00886 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00887 else truthValue = gndPred->getTruthValue();
00888
00889 if (newAssignment != truthValue)
00890 {
00891 if (mcmcdebug)
00892 {
00893 cout << "Chain " << chainIdx << ": Changing truth value of pred "
00894 << i << " to " << newAssignment << endl;
00895 }
00896
00897 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00898 else gndPred->setTruthValue(newAssignment);
00899 affectedGndPreds.clear();
00900 affectedGndPredIndices.clear();
00901 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00902 affectedGndPredIndices);
00903 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00904 chainIdx);
00905 }
00906
00907
00908
00909 if (!burningIn && newAssignment) numTrue_[i]++;
00910 }
00911
00912 if (!burningIn && trackClauseTrueCnts_)
00913 state_->getNumClauseGndings(clauseTrueCnts_, true);
00914
00915 if (mcmcdebug) cout << "End of Gibbs step" << endl;
00916 }
00917
00926 void updateWtsForGndPredsH(GroundPredicateHashArray& gndPreds,
00927 Array<int>& gndPredIndices, const int& chainIdx)
00928 {
00929 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00930
00931 for (int g = 0; g < gndPreds.size(); g++)
00932 {
00933 double wtIfNoChange = 0, wtIfInverted = 0, wt;
00934
00935 Array<int>& negGndClauses =
00936 hstate_->getNegOccurenceArray(gndPredIndices[g] + 1);
00937 Array<int>& posGndClauses =
00938 hstate_->getPosOccurenceArray(gndPredIndices[g] + 1);
00939
00940 int gndClauseIdx;
00941 bool sense;
00942 if (mcmcdebug)
00943 {
00944 cout << "Ground clauses in which pred " << g << " occurs neg.: "
00945 << negGndClauses.size() << endl;
00946 cout << "Ground clauses in which pred " << g << " occurs pos.: "
00947 << posGndClauses.size() << endl;
00948 }
00949
00950 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00951 {
00952 if (i < negGndClauses.size())
00953 {
00954 gndClauseIdx = negGndClauses[i];
00955 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00956 sense = false;
00957 }
00958 else
00959 {
00960 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00961 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00962 sense = true;
00963 }
00964
00965 GroundClause* gndClause = hstate_->getGndClause(gndClauseIdx);
00966 if (gndClause->isHardClause())
00967 wt = hstate_->getClauseCost(gndClauseIdx);
00968 else
00969 wt = gndClause->getWt();
00970
00971 int numSatLiterals;
00972 if (numChains_ > 1)
00973 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00974 else
00975 numSatLiterals = hstate_->getNumTrueLits(gndClauseIdx);
00976 if (numSatLiterals > 1)
00977 {
00978
00979
00980 if (wt > 0)
00981 {
00982 wtIfNoChange += wt;
00983 wtIfInverted += wt;
00984 }
00985 }
00986 else
00987 if (numSatLiterals == 1)
00988 {
00989 if (wt > 0) wtIfNoChange += wt;
00990
00991 bool truthValue;
00992 if (numChains_ > 1)
00993 truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00994 else
00995 truthValue = gndPreds[g]->getTruthValue();
00996
00997 if (truthValue == sense)
00998 {
00999
01000 if (wt < 0) wtIfInverted += abs(wt);
01001 }
01002 else
01003 {
01004
01005 if (wt > 0) wtIfInverted += wt;
01006 }
01007 }
01008 else
01009 if (numSatLiterals == 0)
01010 {
01011
01012 if (wt > 0) wtIfInverted += wt;
01013 else if (wt < 0) wtIfNoChange += abs(wt);
01014 }
01015 }
01016
01017 if (mcmcdebug)
01018 {
01019 cout << "wtIfNoChange of pred " << g << ": "
01020 << wtIfNoChange << endl;
01021 cout << "wtIfInverted of pred " << g << ": "
01022 << wtIfInverted << endl;
01023 }
01024
01025
01026 if (numChains_ > 1)
01027 {
01028 if (truthValues_[gndPredIndices[g]][chainIdx])
01029 {
01030 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01031 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01032 }
01033 else
01034 {
01035 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01036 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01037 }
01038 }
01039 else
01040 {
01041 if (gndPreds[g]->getTruthValue())
01042 {
01043 gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01044 gndPreds[g]->setWtWhenFalse(wtIfInverted);
01045 }
01046 else
01047 {
01048 gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01049 gndPreds[g]->setWtWhenTrue(wtIfInverted);
01050 }
01051 }
01052 }
01053 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01054 }
01055
01056
01057
01058
01067 void updateWtsForGndPreds(GroundPredicateHashArray& gndPreds,
01068 Array<int>& gndPredIndices,
01069 const int& chainIdx)
01070 {
01071 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
01072
01073 for (int g = 0; g < gndPreds.size(); g++)
01074 {
01075 double wtIfNoChange = 0, wtIfInverted = 0, wt;
01076
01077 Array<int>& negGndClauses =
01078 state_->getNegOccurenceArray(gndPredIndices[g] + 1);
01079 Array<int>& posGndClauses =
01080 state_->getPosOccurenceArray(gndPredIndices[g] + 1);
01081 int gndClauseIdx;
01082 bool sense;
01083
01084 if (mcmcdebug)
01085 {
01086 cout << "Ground clauses in which pred " << g << " occurs neg.: "
01087 << negGndClauses.size() << endl;
01088 cout << "Ground clauses in which pred " << g << " occurs pos.: "
01089 << posGndClauses.size() << endl;
01090 }
01091
01092 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01093 {
01094 if (i < negGndClauses.size())
01095 {
01096 gndClauseIdx = negGndClauses[i];
01097 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
01098 sense = false;
01099 }
01100 else
01101 {
01102 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01103 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
01104 sense = true;
01105 }
01106
01107 GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
01108 if (gndClause->isHardClause())
01109 wt = state_->getClauseCost(gndClauseIdx);
01110 else
01111 wt = gndClause->getWt();
01112
01113 int numSatLiterals;
01114 if (numChains_ > 1)
01115 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
01116 else
01117 numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
01118
01119 if (numSatLiterals > 1)
01120 {
01121
01122
01123 if (wt > 0)
01124 {
01125 wtIfNoChange += wt;
01126 wtIfInverted += wt;
01127 }
01128 }
01129 else
01130 if (numSatLiterals == 1)
01131 {
01132 if (wt > 0) wtIfNoChange += wt;
01133
01134 bool truthValue;
01135 if (numChains_ > 1)
01136 truthValue = truthValues_[gndPredIndices[g]][chainIdx];
01137 else
01138 truthValue = gndPreds[g]->getTruthValue();
01139
01140 if (truthValue == sense)
01141 {
01142
01143 if (wt < 0) wtIfInverted += abs(wt);
01144 }
01145 else
01146 {
01147
01148 if (wt > 0) wtIfInverted += wt;
01149 }
01150 }
01151 else
01152 if (numSatLiterals == 0)
01153 {
01154
01155 if (wt > 0) wtIfInverted += wt;
01156 else if (wt < 0) wtIfNoChange += abs(wt);
01157 }
01158 }
01159
01160 if (mcmcdebug)
01161 {
01162 cout << "wtIfNoChange of pred " << g << ": "
01163 << wtIfNoChange << endl;
01164 cout << "wtIfInverted of pred " << g << ": "
01165 << wtIfInverted << endl;
01166 }
01167
01168
01169 if (numChains_ > 1)
01170 {
01171 if (truthValues_[gndPredIndices[g]][chainIdx])
01172 {
01173 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01174 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01175 }
01176 else
01177 {
01178 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01179 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01180 }
01181 }
01182 else
01183 {
01184 if (gndPreds[g]->getTruthValue())
01185 {
01186 gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01187 gndPreds[g]->setWtWhenFalse(wtIfInverted);
01188 }
01189 else
01190 {
01191 gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01192 gndPreds[g]->setWtWhenTrue(wtIfInverted);
01193 }
01194 }
01195 }
01196 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01197 }
01198
01208 int gibbsSampleFromBlock(const int& chainIdx, const int& blockIndex,
01209 const double& invTemp)
01210 {
01211 Array<double> numerators;
01212 double denominator = 0;
01213
01214 int blockSize = state_->getDomain()->getBlockSize(blockIndex);
01215 for (int i = 0; i < blockSize; i++)
01216 {
01217 const Predicate* pred =
01218 state_->getDomain()->getPredInBlock(i, blockIndex);
01219 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
01220 int idx = state_->getIndexOfGroundPredicate(gndPred);
01221
01222 delete gndPred;
01223 delete pred;
01224
01225
01226 double prob = 0.0;
01227
01228 if (idx >= 0)
01229 prob = getProbabilityOfPred(idx, chainIdx, invTemp);
01230
01231 numerators.append(prob);
01232 denominator += prob;
01233 }
01234
01235 double r = random();
01236 double numSum = 0.0;
01237 for (int i = 0; i < blockSize; i++)
01238 {
01239 numSum += numerators[i];
01240 if (r < ((numSum / denominator) * RAND_MAX))
01241 {
01242 return i;
01243 }
01244 }
01245 return blockSize - 1;
01246 }
01247
01256 void gndPredFlippedUpdates(const int& gndPredIdx, const int& chainIdx,
01257 GroundPredicateHashArray& affectedGndPreds,
01258 Array<int>& affectedGndPredIndices)
01259 {
01260 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
01261 int numAtoms = state_->getNumAtoms();
01262 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
01263 affectedGndPreds.append(gndPred, numAtoms);
01264 affectedGndPredIndices.append(gndPredIdx);
01265 assert(affectedGndPreds.size() <= numAtoms);
01266
01267 Array<int>& negGndClauses =
01268 state_->getNegOccurenceArray(gndPredIdx + 1);
01269 Array<int>& posGndClauses =
01270 state_->getPosOccurenceArray(gndPredIdx + 1);
01271 int gndClauseIdx;
01272 GroundClause* gndClause;
01273 bool sense;
01274
01275
01276 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01277 {
01278 if (i < negGndClauses.size())
01279 {
01280 gndClauseIdx = negGndClauses[i];
01281 sense = false;
01282 }
01283 else
01284 {
01285 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01286 sense = true;
01287 }
01288 gndClause = state_->getGndClause(gndClauseIdx);
01289
01290
01291 if (numChains_ > 1)
01292 {
01293 if (truthValues_[gndPredIdx][chainIdx] == sense)
01294 numTrueLits_[gndClauseIdx][chainIdx]++;
01295 else
01296 numTrueLits_[gndClauseIdx][chainIdx]--;
01297 }
01298 else
01299 {
01300 if (gndPred->getTruthValue() == sense)
01301 state_->incrementNumTrueLits(gndClauseIdx);
01302 else
01303 state_->decrementNumTrueLits(gndClauseIdx);
01304 }
01305
01306 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01307 {
01308 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
01309 GroundPredicate* pred =
01310 (GroundPredicate*)gndClause->getGroundPredicate(j,
01311 (GroundPredicateHashArray*)gpha);
01312 affectedGndPreds.append(pred, numAtoms);
01313 affectedGndPredIndices.append(
01314 abs(gndClause->getGroundPredicateIndex(j)) - 1);
01315 assert(affectedGndPreds.size() <= numAtoms);
01316 }
01317 }
01318 if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
01319 }
01320
01321 double getProbTrue(const int& predIdx) const { return numTrue_[predIdx]; }
01322
01323 void setProbTrue(const int& predIdx, const double& p)
01324 {
01325 assert(p >= 0);
01326 numTrue_[predIdx] = p;
01327 }
01328
01335 void saveLowStateToChain(const int& chainIdx)
01336 {
01337 for (int i = 0; i < state_->getNumAtoms(); i++)
01338 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
01339 }
01340
01346 void setMCMCParameters(MCMCParams* params)
01347 {
01348
01349 numChains_ = params->numChains;
01350 burnMinSteps_ = params->burnMinSteps;
01351 burnMaxSteps_ = params->burnMaxSteps;
01352 minSteps_ = params->minSteps;
01353 maxSteps_ = params->maxSteps;
01354 maxSeconds_ = params->maxSeconds;
01355 }
01356
01357 void scaleSamples(double factor)
01358 {
01359 minSteps_ = (int)(minSteps_ * factor);
01360 maxSteps_ = (int)(maxSteps_ * factor);
01361 }
01362
01363 public:
01364
01365 Array<Array<double> > truthValuesCont_;
01366
01367 Array<Array<bool> > hybridClauseDisPartValueMCMC_;
01368
01369 Array<Array<double> > hybridClauseContPartValueMCMC_;
01370
01371 protected:
01372
01374
01375 int numChains_;
01376
01377 int burnMinSteps_;
01378
01379 int burnMaxSteps_;
01380
01381 int minSteps_;
01382
01383 int maxSteps_;
01384
01385 int maxSeconds_;
01387
01388
01389 Array<Array<bool> > truthValues_;
01390
01391 Array<Array<double> > wtsWhenFalse_;
01392
01393 Array<Array<double> > wtsWhenTrue_;
01394
01395
01396
01397 Array<double> numTrue_;
01398
01399
01400
01401 Array<Array<int> > numTrueLits_;
01402 };
01403
01404 #endif