#include <mcmc.h>
Inheritance diagram for MCMC:
Public Member Functions | |
MCMC (VariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params) | |
Constructor. | |
~MCMC () | |
Destructor. | |
void | printProbabilities (ostream &out) |
Prints the probabilities of each predicate to a stream. | |
void | getPredsWithNonZeroProb (vector< string > &nonZeroPreds, vector< float > &probs) |
Puts the predicates with non-zero probability in string form and the corresponding probabilities of each predicate in two vectors. | |
double | getProbability (GroundPredicate *const &gndPred) |
Gets the probability of a ground predicate. | |
void | printTruePreds (ostream &out) |
Prints each predicate with a probability of 0.5 or greater to a stream. | |
Protected Member Functions | |
void | initTruthValuesAndWts (const int &numChains) |
Initializes truth values and weights in the ground preds. | |
void | initNumTrue () |
Initializes structure for holding number of times a predicate was set to true. | |
void | initNumTrueLits (const int &numChains) |
Initializes the number of true lits in each clause in each chain. | |
void | randomInitGndPredsTruthValues (const int &numChains) |
Randomly initializes the ground predicate truth values, taking blocks into account. | |
bool | genTruthValueForProb (const double &p) |
Generates a truth value based on a probability. | |
double | getProbabilityOfPred (const int &predIdx, const int &chainIdx, const double &invTemp) |
Computes the probability of a ground predicate in a chain. | |
void | setOthersInBlockToFalse (const int &chainIdx, const int &atomIdx, const int &blockIdx) |
Sets the truth values of all atoms in a block to false for a given chain, except for the one given. | |
void | performGibbsStep (const int &chainIdx, const bool &burningIn, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Performs one step of Gibbs sampling in one chain. | |
void | updateWtsForGndPreds (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx) |
Updates the weights of affected ground predicates. | |
int | gibbsSampleFromBlock (const int &chainIdx, const Array< int > &block, const double &invTemp) |
Chooses an atom from a block according to their probabilities. | |
void | gndPredFlippedUpdates (const int &gndPredIdx, const int &chainIdx, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate. | |
double | getProbTrue (const int &predIdx) const |
void | setProbTrue (const int &predIdx, const double &p) |
void | saveLowStateToChain (const int &chainIdx) |
Saves the atom assignment of the best (low) state into the truth values of the given chain. | |
void | setMCMCParameters (MCMCParams *params) |
Sets the user-set parameters for this MCMC algorithm. | |
Protected Attributes | |
int | numChains_ |
int | burnMinSteps_ |
int | burnMaxSteps_ |
int | minSteps_ |
int | maxSteps_ |
int | maxSeconds_ |
Array< Array< bool > > | truthValues_ |
Array< Array< double > > | wtsWhenFalse_ |
Array< Array< double > > | wtsWhenTrue_ |
Array< double > | numTrue_ |
Array< Array< int > > | numTrueLits_ |
This class does not implement all pure virtual functions of Inference and is thus an abstract class.
Definition at line 79 of file mcmc.h.
/**
 * Constructor. User-set parameters are set.
 *
 * @param state               Forwarded unchanged to the Inference base constructor.
 * @param seed                Forwarded unchanged to the Inference base constructor.
 * @param trackClauseTrueCnts Forwarded unchanged to the Inference base constructor.
 * @param params              User-set MCMC parameters (numChains, burnMinSteps,
 *                            burnMaxSteps, minSteps, maxSteps, maxSeconds).
 *                            Dereferenced unconditionally, so it must be non-null.
 *
 * [inline] Definition at line 89 of file mcmc.h.
 * References Inference::Inference() and setMCMCParameters().
 */
MCMC::MCMC(VariableState* state, long int seed,
           const bool& trackClauseTrueCnts, MCMCParams* params)
  : Inference(state, seed, trackClauseTrueCnts)
{
  // User-set parameters. setMCMCParameters() performs exactly the field
  // copies this constructor used to duplicate; delegating keeps them in sync.
  setMCMCParameters(params);
}
void MCMC::getPredsWithNonZeroProb | ( | vector< string > & | nonZeroPreds, | |
vector< float > & | probs | |||
) | [inline, virtual] |
Puts the predicates with non-zero probability in string form and the corresponding probabilities of each predicate in two vectors.
nonZeroPreds | Predicates with non-zero probability are put here. | |
probs | The probabilities corresponding to the predicates in nonZeroPreds are put here. |
Implements Inference.
Definition at line 131 of file mcmc.h.
References VariableState::getNumAtoms(), getProbTrue(), VariableState::printGndPred(), and Inference::state_.
00133 { 00134 nonZeroPreds.clear(); 00135 probs.clear(); 00136 for (int i = 0; i < state_->getNumAtoms(); i++) 00137 { 00138 double prob = getProbTrue(i); 00139 if (prob > 0) 00140 { 00141 // Uniform smoothing 00142 prob = (prob*10000 + 1/2.0)/(10000 + 1.0); 00143 ostringstream oss(ostringstream::out); 00144 state_->printGndPred(i, oss); 00145 nonZeroPreds.push_back(oss.str()); 00146 probs.push_back(prob); 00147 } 00148 } 00149 }
double MCMC::getProbability | ( | GroundPredicate *const & | gndPred | ) | [inline, virtual] |
Gets the probability of a ground predicate.
gndPred | GroundPredicate whose probability is being retrieved. |
Implements Inference.
Definition at line 157 of file mcmc.h.
References VariableState::getGndPredIndex(), getProbTrue(), and Inference::state_.
00158 { 00159 int idx = state_->getGndPredIndex(gndPred); 00160 double prob = 0.0; 00161 if (idx >= 0) prob = getProbTrue(idx); 00162 // Uniform smoothing 00163 return (prob*10000 + 1/2.0)/(10000 + 1.0); 00164 }
void MCMC::initTruthValuesAndWts | ( | const int & | numChains | ) | [inline, protected] |
Initializes truth values and weights in the ground preds.
numChains | Number of chains for which the values should be initialized. | |
startIdx | All predicates with index greater than or equal to this will be initialized. |
Definition at line 191 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getNumClauses(), Array< Type >::growToSize(), numTrueLits_, Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00192 { 00193 int numPreds = state_->getNumAtoms(); 00194 truthValues_.growToSize(numPreds); 00195 wtsWhenFalse_.growToSize(numPreds); 00196 wtsWhenTrue_.growToSize(numPreds); 00197 for (int i = 0; i < numPreds; i++) 00198 { 00199 truthValues_[i].growToSize(numChains, false); 00200 wtsWhenFalse_[i].growToSize(numChains, 0); 00201 wtsWhenTrue_[i].growToSize(numChains, 0); 00202 } 00203 00204 int numClauses = state_->getNumClauses(); 00205 numTrueLits_.growToSize(numClauses); 00206 for (int i = 0; i < numClauses; i++) 00207 { 00208 numTrueLits_[i].growToSize(numChains, 0); 00209 } 00210 }
void MCMC::initNumTrueLits | ( | const int & | numChains | ) | [inline, protected] |
Initializes the number of true lits in each clause in each chain.
numChains | Number of chains for which the initialization should take place. |
Definition at line 230 of file mcmc.h.
References VariableState::getAtomInClause(), VariableState::getGndClause(), VariableState::getNumAtoms(), VariableState::getNumClauses(), numTrueLits_, Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00231 { 00232 for (int i = 0; i < state_->getNumClauses(); i++) 00233 { 00234 GroundClause* gndClause = state_->getGndClause(i); 00235 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 00236 { 00237 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1; 00238 for (int c = 0; c < numChains; c++) 00239 { 00240 if (truthValues_[atomIdx][c] == gndClause->getGroundPredicateSense(j)) 00241 { 00242 numTrueLits_[i][c]++; 00243 assert(numTrueLits_[i][c] <= state_->getNumAtoms()); 00244 } 00245 } 00246 } 00247 } 00248 }
void MCMC::randomInitGndPredsTruthValues | ( | const int & | numChains | ) | [inline, protected] |
Randomly initializes the ground predicate truth values, taking blocks into account.
numChains | Number of chains for which the initialization should take place. |
Definition at line 257 of file mcmc.h.
References genTruthValueForProb(), VariableState::getBlockArray(), VariableState::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getNumBlocks(), setOthersInBlockToFalse(), Array< Type >::size(), Inference::state_, and truthValues_.
Referenced by GibbsSampler::init().
00258 { 00259 for (int c = 0; c < numChains; c++) 00260 { 00261 // For each block: select one to set to true 00262 for (int i = 0; i < state_->getNumBlocks(); i++) 00263 { 00264 // If evidence atom exists, then all others are false 00265 if (state_->getBlockEvidence(i)) 00266 { 00267 // If 2nd argument is -1, then all are set to false 00268 setOthersInBlockToFalse(c, -1, i); 00269 continue; 00270 } 00271 00272 Array<int>& block = state_->getBlockArray(i); 00273 int chosen = random() % block.size(); 00274 truthValues_[block[chosen]][c] = true; 00275 setOthersInBlockToFalse(c, chosen, i); 00276 } 00277 00278 // Random tv for all not in blocks 00279 for (int i = 0; i < truthValues_.size(); i++) 00280 { 00281 // Predicates in blocks have been handled above 00282 if (state_->getBlockIndex(i) == -1) 00283 { 00284 bool tv = genTruthValueForProb(0.5); 00285 truthValues_[i][c] = tv; 00286 } 00287 } 00288 } 00289 }
/**
 * Generates a truth value based on a probability.
 *
 * @param p Number between 0 and 1. With probability p, truth value will be
 *          true and with probability 1 - p, it will be false.
 * @return The sampled truth value. p == 1 always yields true and p == 0
 *         always yields false.
 *
 * [inline, protected] Definition at line 297 of file mcmc.h.
 * Referenced by SimulatedTempering::infer(), performGibbsStep(), and
 * randomInitGndPredsTruthValues().
 */
bool MCMC::genTruthValueForProb(const double& p)
{
  if (p == 1.0) return true;
  if (p == 0.0) return false;
  // POSIX random() returns values in [0, 2^31 - 1] on every platform.
  // RAND_MAX is the upper bound of rand(), not random(), and may be as low
  // as 32767, so the threshold must be scaled by random()'s own maximum.
  const double randomMax = 2147483647.0;
  return random() <= p * randomMax;
}
double MCMC::getProbabilityOfPred | ( | const int & | predIdx, | |
const int & | chainIdx, | |||
const double & | invTemp | |||
) | [inline, protected] |
Computes the probability of a ground predicate in a chain.
predIdx | Index of predicate. | |
chainIdx | Index of chain. | |
invTemp | InvTemp used in simulated tempering. |
Definition at line 314 of file mcmc.h.
References VariableState::getGndPred(), GroundPredicate::getWtWhenFalse(), GroundPredicate::getWtWhenTrue(), numChains_, Inference::state_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by gibbsSampleFromBlock(), SimulatedTempering::infer(), and performGibbsStep().
00316 { 00317 // Different for multi-chain 00318 if (numChains_ > 1) 00319 { 00320 return 1.0 / 00321 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 00322 wtsWhenTrue_[predIdx][chainIdx]) * 00323 invTemp)); 00324 } 00325 else 00326 { 00327 GroundPredicate* gndPred = state_->getGndPred(predIdx); 00328 return 1.0 / 00329 ( 1.0 + exp((gndPred->getWtWhenFalse() - 00330 gndPred->getWtWhenTrue()) * 00331 invTemp)); 00332 } 00333 }
void MCMC::setOthersInBlockToFalse | ( | const int & | chainIdx, | |
const int & | atomIdx, | |||
const int & | blockIdx | |||
) | [inline, protected] |
Sets the truth values of all atoms for a given chain in a block except for the one given.
chainIdx | Index of chain for which atoms are being set. | |
atomIdx | Index of atom in block exempt from being set to false. | |
blockIdx | Index of block whose atoms are set to false. |
Definition at line 343 of file mcmc.h.
References VariableState::getBlockArray(), Array< Type >::size(), Inference::state_, and truthValues_.
Referenced by randomInitGndPredsTruthValues().
00345 { 00346 Array<int>& block = state_->getBlockArray(blockIdx); 00347 for (int i = 0; i < block.size(); i++) 00348 { 00349 if (i != atomIdx) 00350 truthValues_[block[i]][chainIdx] = false; 00351 } 00352 }
void MCMC::performGibbsStep | ( | const int & | chainIdx, | |
const bool & | burningIn, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Performs one step of Gibbs sampling in one chain.
chainIdx | Index of chain in which the Gibbs step is performed. | |
burningIn | If true, burning-in is occuring. Otherwise, false. | |
affectedGndPreds | Used to store GroundPredicates which are affected by changing truth values. | |
affectedGndPredIndices | Used to store indices of GroundPredicates which are affected by changing truth values. |
Definition at line 364 of file mcmc.h.
References Inference::clauseTrueCnts_, Array< Type >::clear(), HashArray< Type, HashFn, EqualFn >::clear(), genTruthValueForProb(), VariableState::getBlockArray(), VariableState::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getGndPred(), VariableState::getNumAtoms(), VariableState::getNumBlocks(), VariableState::getNumClauseGndings(), getProbabilityOfPred(), GroundPredicate::getTruthValue(), gibbsSampleFromBlock(), gndPredFlippedUpdates(), numChains_, numTrue_, GroundPredicate::setTruthValue(), Inference::state_, Inference::trackClauseTrueCnts_, truthValues_, and updateWtsForGndPreds().
Referenced by MCSAT::infer(), and GibbsSampler::infer().
00367 { 00368 if (mcmcdebug) cout << "Gibbs step" << endl; 00369 00370 // For each block: select one to set to true 00371 for (int i = 0; i < state_->getNumBlocks(); i++) 00372 { 00373 // If evidence atom exists, then all others stay false 00374 if (state_->getBlockEvidence(i)) continue; 00375 00376 Array<int>& block = state_->getBlockArray(i); 00377 // chosen is index in the block, block[chosen] is index in gndPreds_ 00378 int chosen = gibbsSampleFromBlock(chainIdx, block, 1); 00379 // Truth values are stored differently for multi-chain 00380 bool truthValue; 00381 GroundPredicate* gndPred = state_->getGndPred(block[chosen]); 00382 if (numChains_ > 1) truthValue = truthValues_[block[chosen]][chainIdx]; 00383 else truthValue = gndPred->getTruthValue(); 00384 // If chosen pred was false, then need to set previous true 00385 // one to false and update wts 00386 if (!truthValue) 00387 { 00388 for (int j = 0; j < block.size(); j++) 00389 { 00390 // Truth values are stored differently for multi-chain 00391 bool otherTruthValue; 00392 GroundPredicate* otherGndPred = state_->getGndPred(block[j]); 00393 if (numChains_ > 1) 00394 otherTruthValue = truthValues_[block[j]][chainIdx]; 00395 else 00396 otherTruthValue = otherGndPred->getTruthValue(); 00397 if (otherTruthValue) 00398 { 00399 // Truth values are stored differently for multi-chain 00400 if (numChains_ > 1) 00401 truthValues_[block[j]][chainIdx] = false; 00402 else 00403 otherGndPred->setTruthValue(false); 00404 00405 affectedGndPreds.clear(); 00406 affectedGndPredIndices.clear(); 00407 gndPredFlippedUpdates(block[j], chainIdx, affectedGndPreds, 00408 affectedGndPredIndices); 00409 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00410 chainIdx); 00411 } 00412 } 00413 // Set truth value and update wts for chosen atom 00414 // Truth values are stored differently for multi-chain 00415 if (numChains_ > 1) truthValues_[block[chosen]][chainIdx] = true; 00416 else gndPred->setTruthValue(true); 00417 
affectedGndPreds.clear(); 00418 affectedGndPredIndices.clear(); 00419 gndPredFlippedUpdates(block[chosen], chainIdx, affectedGndPreds, 00420 affectedGndPredIndices); 00421 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00422 chainIdx); 00423 } 00424 00425 // If in actual gibbs sampling phase, track the num of times 00426 // the ground predicate is set to true 00427 if (!burningIn) numTrue_[block[chosen]]++; 00428 } 00429 00430 // Now go through all preds not in blocks 00431 for (int i = 0; i < state_->getNumAtoms(); i++) 00432 { 00433 // Predicates in blocks have been handled above 00434 if (state_->getBlockIndex(i) >= 0) continue; 00435 00436 if (mcmcdebug) 00437 { 00438 cout << "Chain " << chainIdx << ": Probability of pred " 00439 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl; 00440 } 00441 00442 bool newAssignment 00443 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1)); 00444 00445 // Truth values are stored differently for multi-chain 00446 bool truthValue; 00447 GroundPredicate* gndPred = state_->getGndPred(i); 00448 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx]; 00449 else truthValue = gndPred->getTruthValue(); 00450 // If gndPred is flipped, do updates & find all affected gndPreds 00451 if (newAssignment != truthValue) 00452 { 00453 if (mcmcdebug) 00454 { 00455 cout << "Chain " << chainIdx << ": Changing truth value of pred " 00456 << i << " to " << newAssignment << endl; 00457 } 00458 00459 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment; 00460 else gndPred->setTruthValue(newAssignment); 00461 affectedGndPreds.clear(); 00462 affectedGndPredIndices.clear(); 00463 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds, 00464 affectedGndPredIndices); 00465 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00466 chainIdx); 00467 } 00468 00469 // If in actual gibbs sampling phase, track the num of times 00470 // the ground predicate is set to true 00471 if (!burningIn && 
newAssignment) numTrue_[i]++; 00472 } 00473 // If keeping track of true clause groundings 00474 if (!burningIn && trackClauseTrueCnts_) 00475 state_->getNumClauseGndings(clauseTrueCnts_, true); 00476 00477 if (mcmcdebug) cout << "End of Gibbs step" << endl; 00478 }
void MCMC::updateWtsForGndPreds | ( | GroundPredicateHashArray & | gndPreds, | |
Array< int > & | gndPredIndices, | |||
const int & | chainIdx | |||
) | [inline, protected] |
Updates the weights of affected ground predicates.
These are the ground predicates which are in clauses of predicates which have had their truth value changed.
gndPreds | Ground predicates whose weights should be updated. | |
chainIdx | Index of chain where updating occurs. |
Definition at line 488 of file mcmc.h.
References VariableState::getClauseCost(), VariableState::getGndClause(), VariableState::getNegOccurenceArray(), VariableState::getNumTrueLits(), VariableState::getPosOccurenceArray(), GroundClause::getWt(), GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::infer(), MCSAT::infer(), GibbsSampler::infer(), and performGibbsStep().
00491 { 00492 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl; 00493 // for each ground predicate whose MB has changed 00494 for (int g = 0; g < gndPreds.size(); g++) 00495 { 00496 double wtIfNoChange = 0, wtIfInverted = 0, wt; 00497 // Ground clauses in which this pred occurs 00498 Array<int>& negGndClauses = 00499 state_->getNegOccurenceArray(gndPredIndices[g] + 1); 00500 Array<int>& posGndClauses = 00501 state_->getPosOccurenceArray(gndPredIndices[g] + 1); 00502 int gndClauseIdx; 00503 bool sense; 00504 00505 if (mcmcdebug) 00506 { 00507 cout << "Ground clauses in which pred " << g << " occurs neg.: " 00508 << negGndClauses.size() << endl; 00509 cout << "Ground clauses in which pred " << g << " occurs pos.: " 00510 << posGndClauses.size() << endl; 00511 } 00512 00513 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 00514 { 00515 if (i < negGndClauses.size()) 00516 { 00517 gndClauseIdx = negGndClauses[i]; 00518 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl; 00519 sense = false; 00520 } 00521 else 00522 { 00523 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 00524 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl; 00525 sense = true; 00526 } 00527 00528 GroundClause* gndClause = state_->getGndClause(gndClauseIdx); 00529 if (gndClause->isHardClause()) 00530 wt = state_->getClauseCost(gndClauseIdx); 00531 else 00532 wt = gndClause->getWt(); 00533 // NumTrueLits are stored differently for multi-chain 00534 int numSatLiterals; 00535 if (numChains_ > 1) 00536 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx]; 00537 else 00538 numSatLiterals = state_->getNumTrueLits(gndClauseIdx); 00539 if (numSatLiterals > 1) 00540 { 00541 // Some other literal is making it sat, so it doesn't matter 00542 // if pos. clause. If neg., nothing can be done to unsatisfy it. 
00543 if (wt > 0) 00544 { 00545 wtIfNoChange += wt; 00546 wtIfInverted += wt; 00547 } 00548 } 00549 else 00550 if (numSatLiterals == 1) 00551 { 00552 if (wt > 0) wtIfNoChange += wt; 00553 // Truth values are stored differently for multi-chain 00554 bool truthValue; 00555 if (numChains_ > 1) 00556 truthValue = truthValues_[gndPredIndices[g]][chainIdx]; 00557 else 00558 truthValue = gndPreds[g]->getTruthValue(); 00559 // If the current truth value is the same as its sense in gndClause 00560 if (truthValue == sense) 00561 { 00562 // This gndPred is the only one making this function satisfied 00563 if (wt < 0) wtIfInverted += abs(wt); 00564 } 00565 else 00566 { 00567 // Some other literal is making it satisfied 00568 if (wt > 0) wtIfInverted += wt; 00569 } 00570 } 00571 else 00572 if (numSatLiterals == 0) 00573 { 00574 // None satisfy, so when gndPred switch to its negative, it'll satisfy 00575 if (wt > 0) wtIfInverted += wt; 00576 else if (wt < 0) wtIfNoChange += abs(wt); 00577 } 00578 } // for each ground clause that gndPred appears in 00579 00580 if (mcmcdebug) 00581 { 00582 cout << "wtIfNoChange of pred " << g << ": " 00583 << wtIfNoChange << endl; 00584 cout << "wtIfInverted of pred " << g << ": " 00585 << wtIfInverted << endl; 00586 } 00587 00588 // Clause info is stored differently for multi-chain 00589 if (numChains_ > 1) 00590 { 00591 if (truthValues_[gndPredIndices[g]][chainIdx]) 00592 { 00593 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 00594 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted; 00595 } 00596 else 00597 { 00598 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 00599 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted; 00600 } 00601 } 00602 else 00603 { // Single chain 00604 if (gndPreds[g]->getTruthValue()) 00605 { 00606 gndPreds[g]->setWtWhenTrue(wtIfNoChange); 00607 gndPreds[g]->setWtWhenFalse(wtIfInverted); 00608 } 00609 else 00610 { 00611 gndPreds[g]->setWtWhenFalse(wtIfNoChange); 00612 
gndPreds[g]->setWtWhenTrue(wtIfInverted); 00613 } 00614 } 00615 } // for each ground predicate whose MB has changed 00616 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl; 00617 }
int MCMC::gibbsSampleFromBlock | ( | const int & | chainIdx, | |
const Array< int > & | block, | |||
const double & | invTemp | |||
) | [inline, protected] |
Chooses an atom from a block according to their probabilities.
chainIdx | Index of chain. | |
block | Block of predicate indices from which one is chosen. | |
invTemp | InvTemp used in simulated tempering. |
Definition at line 628 of file mcmc.h.
References Array< Type >::append(), getProbabilityOfPred(), and Array< Type >::size().
Referenced by SimulatedTempering::infer(), and performGibbsStep().
00630 { 00631 Array<double> numerators; 00632 double denominator = 0; 00633 00634 for (int i = 0; i < block.size(); i++) 00635 { 00636 double prob = getProbabilityOfPred(block[i], chainIdx, invTemp); 00637 numerators.append(prob); 00638 denominator += prob; 00639 } 00640 double r = random(); 00641 double numSum = 0.0; 00642 for (int i = 0; i < block.size(); i++) 00643 { 00644 numSum += numerators[i]; 00645 if (r < ((numSum / denominator) * RAND_MAX)) 00646 { 00647 return i; 00648 } 00649 } 00650 return block.size() - 1; 00651 }
void MCMC::gndPredFlippedUpdates | ( | const int & | gndPredIdx, | |
const int & | chainIdx, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate.
gndPredIdx | Index of ground pred which was flipped. | |
chainIdx | Index of chain in which the flipping occured. | |
affectedGndPreds | Holds the Markov blanket of the ground predicate. |
Definition at line 661 of file mcmc.h.
References Array< Type >::append(), HashArray< Type, HashFn, EqualFn >::append(), VariableState::decrementNumTrueLits(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getGndPredHashArrayPtr(), GroundClause::getGroundPredicate(), GroundClause::getGroundPredicateIndex(), VariableState::getNegOccurenceArray(), VariableState::getNumAtoms(), GroundClause::getNumGroundPredicates(), VariableState::getPosOccurenceArray(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::infer(), and performGibbsStep().
00664 { 00665 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl; 00666 int numAtoms = state_->getNumAtoms(); 00667 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx); 00668 affectedGndPreds.append(gndPred, numAtoms); 00669 affectedGndPredIndices.append(gndPredIdx); 00670 assert(affectedGndPreds.size() <= numAtoms); 00671 00672 Array<int>& negGndClauses = 00673 state_->getNegOccurenceArray(gndPredIdx + 1); 00674 Array<int>& posGndClauses = 00675 state_->getPosOccurenceArray(gndPredIdx + 1); 00676 int gndClauseIdx; 00677 GroundClause* gndClause; 00678 bool sense; 00679 00680 // Find the Markov blanket of this ground predicate 00681 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 00682 { 00683 if (i < negGndClauses.size()) 00684 { 00685 gndClauseIdx = negGndClauses[i]; 00686 sense = false; 00687 } 00688 else 00689 { 00690 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 00691 sense = true; 00692 } 00693 gndClause = state_->getGndClause(gndClauseIdx); 00694 00695 // Different for multi-chain 00696 if (numChains_ > 1) 00697 { 00698 if (truthValues_[gndPredIdx][chainIdx] == sense) 00699 numTrueLits_[gndClauseIdx][chainIdx]++; 00700 else 00701 numTrueLits_[gndClauseIdx][chainIdx]--; 00702 } 00703 else 00704 { // Single chain 00705 if (gndPred->getTruthValue() == sense) 00706 state_->incrementNumTrueLits(gndClauseIdx); 00707 else 00708 state_->decrementNumTrueLits(gndClauseIdx); 00709 } 00710 00711 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 00712 { 00713 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr(); 00714 GroundPredicate* pred = 00715 (GroundPredicate*)gndClause->getGroundPredicate(j, 00716 (GroundPredicateHashArray*)gpha); 00717 affectedGndPreds.append(pred, numAtoms); 00718 affectedGndPredIndices.append( 00719 abs(gndClause->getGroundPredicateIndex(j)) - 1); 00720 assert(affectedGndPreds.size() <= numAtoms); 00721 } 00722 } 00723 if (mcmcdebug) cout << "Leaving 
gndPredFlippedUpdates" << endl; 00724 }
void MCMC::saveLowStateToChain | ( | const int & | chainIdx | ) | [inline, protected] |
The atom assignment in the best state is saved to a chain in the ground predicates.
chainIdx | Index of chain to which the atom assigment is saved |
Definition at line 740 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getValueOfLowAtom(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00741 { 00742 for (int i = 0; i < state_->getNumAtoms(); i++) 00743 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1); 00744 }
void MCMC::setMCMCParameters | ( | MCMCParams * | params | ) | [inline, protected] |
Sets the user-set parameters for this MCMC algorithm.
params | MCMC parameters for this algorithm. |
Definition at line 751 of file mcmc.h.
References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.
00752 { 00753 // User-set parameters 00754 numChains_ = params->numChains; 00755 burnMinSteps_ = params->burnMinSteps; 00756 burnMaxSteps_ = params->burnMaxSteps; 00757 minSteps_ = params->minSteps; 00758 maxSteps_ = params->maxSteps; 00759 maxSeconds_ = params->maxSeconds; 00760 }