#include <mcmc.h>
Inheritance diagram for MCMC (the diagram image is not included in this text rendering; MCMC derives from Inference):
Public Member Functions | |
MCMC (VariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params) | |
Constructor. | |
~MCMC () | |
Destructor. | |
void | printProbabilities (ostream &out) |
Prints the probabilities of each predicate to a stream. | |
double | getProbability (GroundPredicate *const &gndPred) |
Gets the probability of a ground predicate. | |
void | printTruePreds (ostream &out) |
Prints each predicate with a probability of 0.5 or greater to a stream. | |
Protected Member Functions | |
void | initTruthValuesAndWts (const int &numChains) |
Initializes truth values and weights in the ground preds. | |
void | initNumTrue () |
Initializes structure for holding number of times a predicate was set to true. | |
void | initNumTrueLits (const int &numChains) |
Initializes the number of true lits in each clause in each chain. | |
void | randomInitGndPredsTruthValues (const int &numChains) |
Randomly initializes the ground predicate truth values, taking blocks into account. | |
bool | genTruthValueForProb (const double &p) |
Generates a truth value based on a probability. | |
double | getProbabilityOfPred (const int &predIdx, const int &chainIdx, const double &invTemp) |
Computes the probability of a ground predicate in a chain. | |
void | setOthersInBlockToFalse (const int &chainIdx, const int &atomIdx, const int &blockIdx) |
Sets the truth values of all atoms in a block to false for a given chain, except for the one given. | |
void | performGibbsStep (const int &chainIdx, const bool &burningIn, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Performs one step of Gibbs sampling in one chain. | |
void | updateWtsForGndPreds (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx) |
Updates the weights of affected ground predicates. | |
int | gibbsSampleFromBlock (const int &chainIdx, const Array< int > &block, const double &invTemp) |
Chooses an atom from a block according to their probabilities. | |
void | gndPredFlippedUpdates (const int &gndPredIdx, const int &chainIdx, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate. | |
double | getProbTrue (const int &predIdx) const |
void | setProbTrue (const int &predIdx, const double &p) |
void | saveLowStateToChain (const int &chainIdx) |
Saves the atom assignment of the best (low) state into the truth values of a given chain. | |
void | setMCMCParameters (MCMCParams *params) |
Sets the user-set parameters for this MCMC algorithm. | |
Protected Attributes | |
int | numChains_ |
int | burnMinSteps_ |
int | burnMaxSteps_ |
int | minSteps_ |
int | maxSteps_ |
int | maxSeconds_ |
Array< Array< bool > > | truthValues_ |
Array< Array< double > > | wtsWhenFalse_ |
Array< Array< double > > | wtsWhenTrue_ |
Array< double > | numTrue_ |
Array< Array< int > > | numTrueLits_ |
This class does not implement all pure virtual functions of Inference and is thus an abstract class.
Definition at line 79 of file mcmc.h.
MCMC::MCMC | ( | VariableState * | state, | |
long int | seed, | |||
const bool & | trackClauseTrueCnts, | |||
MCMCParams * | params | |||
) | [inline] |
Constructor.
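Copies the user-set parameters (numChains, burnMinSteps, burnMaxSteps, minSteps, maxSteps, maxSeconds) from params into the corresponding members.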
Definition at line 89 of file mcmc.h.
References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.
00091 : Inference(state, seed, trackClauseTrueCnts) 00092 { 00093 // User-set parameters 00094 numChains_ = params->numChains; 00095 burnMinSteps_ = params->burnMinSteps; 00096 burnMaxSteps_ = params->burnMaxSteps; 00097 minSteps_ = params->minSteps; 00098 maxSteps_ = params->maxSteps; 00099 maxSeconds_ = params->maxSeconds; 00100 }
double MCMC::getProbability | ( | GroundPredicate *const & | gndPred | ) | [inline, virtual] |
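Gets the smoothed probability that a ground predicate is true. The sampled estimate p is smoothed uniformly as (p*10000 + 0.5)/(10000 + 1); a predicate not found in the state is treated as p = 0.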
gndPred | GroundPredicate whose probability is being retrieved. |
Implements Inference.
Definition at line 129 of file mcmc.h.
References VariableState::getGndPredIndex(), getProbTrue(), and Inference::state_.
00130 { 00131 int idx = state_->getGndPredIndex(gndPred); 00132 double prob = 0.0; 00133 if (idx >= 0) prob = getProbTrue(idx); 00134 // Uniform smoothing 00135 return (prob*10000 + 1/2.0)/(10000 + 1.0); 00136 }
void MCMC::initTruthValuesAndWts | ( | const int & | numChains | ) | [inline, protected] |
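Initializes the per-chain truth values and weights of the ground preds (truthValues_, wtsWhenFalse_, wtsWhenTrue_) and the per-chain true-literal counts of the ground clauses (numTrueLits_).
numChains | Number of chains for which the values should be initialized. |
Note: this function has no startIdx parameter; the arrays are grown to cover all predicates (getNumAtoms()) and all clauses (getNumClauses()), and newly added entries are filled with false or 0.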
Definition at line 163 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getNumClauses(), Array< Type >::growToSize(), numTrueLits_, Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00164 { 00165 int numPreds = state_->getNumAtoms(); 00166 truthValues_.growToSize(numPreds); 00167 wtsWhenFalse_.growToSize(numPreds); 00168 wtsWhenTrue_.growToSize(numPreds); 00169 for (int i = 0; i < numPreds; i++) 00170 { 00171 truthValues_[i].growToSize(numChains, false); 00172 wtsWhenFalse_[i].growToSize(numChains, 0); 00173 wtsWhenTrue_[i].growToSize(numChains, 0); 00174 } 00175 00176 int numClauses = state_->getNumClauses(); 00177 numTrueLits_.growToSize(numClauses); 00178 for (int i = 0; i < numClauses; i++) 00179 { 00180 numTrueLits_[i].growToSize(numChains, 0); 00181 } 00182 }
void MCMC::initNumTrueLits | ( | const int & | numChains | ) | [inline, protected] |
Initializes the number of true lits in each clause in each chain.
numChains | Number of chains for which the initialization should take place. |
Definition at line 200 of file mcmc.h.
References VariableState::getAtomInClause(), VariableState::getGndClause(), VariableState::getNumAtoms(), VariableState::getNumClauses(), numTrueLits_, Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00201 { 00202 for (int i = 0; i < state_->getNumClauses(); i++) 00203 { 00204 GroundClause* gndClause = state_->getGndClause(i); 00205 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 00206 { 00207 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1; 00208 for (int c = 0; c < numChains; c++) 00209 { 00210 if (truthValues_[atomIdx][c] == gndClause->getGroundPredicateSense(j)) 00211 { 00212 numTrueLits_[i][c]++; 00213 assert(numTrueLits_[i][c] <= state_->getNumAtoms()); 00214 } 00215 } 00216 } 00217 } 00218 }
void MCMC::randomInitGndPredsTruthValues | ( | const int & | numChains | ) | [inline, protected] |
Randomly initializes the ground predicate truth values, taking blocks into account.
numChains | Number of chains for which the initialization should take place. |
Definition at line 227 of file mcmc.h.
References genTruthValueForProb(), VariableState::getBlockArray(), VariableState::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getNumBlocks(), setOthersInBlockToFalse(), Array< Type >::size(), Inference::state_, and truthValues_.
Referenced by GibbsSampler::init().
00228 { 00229 for (int c = 0; c < numChains; c++) 00230 { 00231 // For each block: select one to set to true 00232 for (int i = 0; i < state_->getNumBlocks(); i++) 00233 { 00234 // If evidence atom exists, then all others are false 00235 if (state_->getBlockEvidence(i)) 00236 { 00237 // If 2nd argument is -1, then all are set to false 00238 setOthersInBlockToFalse(c, -1, i); 00239 continue; 00240 } 00241 00242 Array<int>& block = state_->getBlockArray(i); 00243 int chosen = random() % block.size(); 00244 truthValues_[block[chosen]][c] = true; 00245 setOthersInBlockToFalse(c, chosen, i); 00246 } 00247 00248 // Random tv for all not in blocks 00249 for (int i = 0; i < truthValues_.size(); i++) 00250 { 00251 // Predicates in blocks have been handled above 00252 if (state_->getBlockIndex(i) == -1) 00253 { 00254 bool tv = genTruthValueForProb(0.5); 00255 truthValues_[i][c] = tv; 00256 } 00257 } 00258 } 00259 }
bool MCMC::genTruthValueForProb | ( | const double & | p | ) | [inline, protected] |
Generates a truth value based on a probability.
p | Number between 0 and 1. With probability p, truth value will be true and with probability 1 - p, it will be false. |
Definition at line 267 of file mcmc.h.
Referenced by SimulatedTempering::infer(), performGibbsStep(), and randomInitGndPredsTruthValues().
00268 { 00269 if (p == 1.0) return true; 00270 if (p == 0.0) return false; 00271 bool r = random() <= p*RAND_MAX; 00272 return r; 00273 }
double MCMC::getProbabilityOfPred | ( | const int & | predIdx, | |
const int & | chainIdx, | |||
const double & | invTemp | |||
) | [inline, protected] |
Computes the probability of a ground predicate in a chain.
predIdx | Index of predicate. | |
chainIdx | Index of chain. | |
invTemp | Inverse temperature used in simulated tempering (performGibbsStep passes 1). |
Definition at line 284 of file mcmc.h.
References VariableState::getGndPred(), GroundPredicate::getWtWhenFalse(), GroundPredicate::getWtWhenTrue(), numChains_, Inference::state_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by gibbsSampleFromBlock(), SimulatedTempering::infer(), and performGibbsStep().
00286 { 00287 // Different for multi-chain 00288 if (numChains_ > 1) 00289 { 00290 return 1.0 / 00291 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 00292 wtsWhenTrue_[predIdx][chainIdx]) * 00293 invTemp)); 00294 } 00295 else 00296 { 00297 GroundPredicate* gndPred = state_->getGndPred(predIdx); 00298 return 1.0 / 00299 ( 1.0 + exp((gndPred->getWtWhenFalse() - 00300 gndPred->getWtWhenTrue()) * 00301 invTemp)); 00302 } 00303 }
void MCMC::setOthersInBlockToFalse | ( | const int & | chainIdx, | |
const int & | atomIdx, | |||
const int & | blockIdx | |||
) | [inline, protected] |
Sets the truth values of all atoms in a block to false for a given chain, except for the one given.
chainIdx | Index of chain for which atoms are being set. | |
atomIdx | Position within the block (not the global atom index) of the atom exempt from being set to false; pass -1 to set all atoms in the block to false. | |
blockIdx | Index of block whose atoms are set to false. |
Definition at line 313 of file mcmc.h.
References VariableState::getBlockArray(), Array< Type >::size(), Inference::state_, and truthValues_.
Referenced by randomInitGndPredsTruthValues().
00315 { 00316 Array<int>& block = state_->getBlockArray(blockIdx); 00317 for (int i = 0; i < block.size(); i++) 00318 { 00319 if (i != atomIdx) 00320 truthValues_[block[i]][chainIdx] = false; 00321 } 00322 }
void MCMC::performGibbsStep | ( | const int & | chainIdx, | |
const bool & | burningIn, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Performs one step of Gibbs sampling in one chain.
chainIdx | Index of chain in which the Gibbs step is performed. | |
burningIn | True during burn-in; false during the sampling phase, when the number of times each predicate is true is counted. | |
affectedGndPreds | Used to store GroundPredicates which are affected by changing truth values. | |
affectedGndPredIndices | Used to store indices of GroundPredicates which are affected by changing truth values. |
Definition at line 334 of file mcmc.h.
References Inference::clauseTrueCnts_, Array< Type >::clear(), HashArray< Type, HashFn, EqualFn >::clear(), genTruthValueForProb(), VariableState::getBlockArray(), VariableState::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getGndPred(), VariableState::getNumAtoms(), VariableState::getNumBlocks(), VariableState::getNumClauseGndings(), getProbabilityOfPred(), GroundPredicate::getTruthValue(), gibbsSampleFromBlock(), gndPredFlippedUpdates(), numChains_, numTrue_, GroundPredicate::setTruthValue(), Inference::state_, Inference::trackClauseTrueCnts_, truthValues_, and updateWtsForGndPreds().
Referenced by MCSAT::infer(), and GibbsSampler::infer().
00337 { 00338 if (mcmcdebug) cout << "Gibbs step" << endl; 00339 00340 // For each block: select one to set to true 00341 for (int i = 0; i < state_->getNumBlocks(); i++) 00342 { 00343 // If evidence atom exists, then all others stay false 00344 if (state_->getBlockEvidence(i)) continue; 00345 00346 Array<int>& block = state_->getBlockArray(i); 00347 // chosen is index in the block, block[chosen] is index in gndPreds_ 00348 int chosen = gibbsSampleFromBlock(chainIdx, block, 1); 00349 // Truth values are stored differently for multi-chain 00350 bool truthValue; 00351 GroundPredicate* gndPred = state_->getGndPred(block[chosen]); 00352 if (numChains_ > 1) truthValue = truthValues_[block[chosen]][chainIdx]; 00353 else truthValue = gndPred->getTruthValue(); 00354 // If chosen pred was false, then need to set previous true 00355 // one to false and update wts 00356 if (!truthValue) 00357 { 00358 for (int j = 0; j < block.size(); j++) 00359 { 00360 // Truth values are stored differently for multi-chain 00361 bool otherTruthValue; 00362 GroundPredicate* otherGndPred = state_->getGndPred(block[j]); 00363 if (numChains_ > 1) 00364 otherTruthValue = truthValues_[block[j]][chainIdx]; 00365 else 00366 otherTruthValue = otherGndPred->getTruthValue(); 00367 if (otherTruthValue) 00368 { 00369 // Truth values are stored differently for multi-chain 00370 if (numChains_ > 1) 00371 truthValues_[block[j]][chainIdx] = false; 00372 else 00373 otherGndPred->setTruthValue(false); 00374 00375 affectedGndPreds.clear(); 00376 affectedGndPredIndices.clear(); 00377 gndPredFlippedUpdates(block[j], chainIdx, affectedGndPreds, 00378 affectedGndPredIndices); 00379 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00380 chainIdx); 00381 } 00382 } 00383 // Set truth value and update wts for chosen atom 00384 // Truth values are stored differently for multi-chain 00385 if (numChains_ > 1) truthValues_[block[chosen]][chainIdx] = true; 00386 else gndPred->setTruthValue(true); 00387 
affectedGndPreds.clear(); 00388 affectedGndPredIndices.clear(); 00389 gndPredFlippedUpdates(block[chosen], chainIdx, affectedGndPreds, 00390 affectedGndPredIndices); 00391 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00392 chainIdx); 00393 } 00394 00395 // If in actual gibbs sampling phase, track the num of times 00396 // the ground predicate is set to true 00397 if (!burningIn) numTrue_[block[chosen]]++; 00398 } 00399 00400 // Now go through all preds not in blocks 00401 for (int i = 0; i < state_->getNumAtoms(); i++) 00402 { 00403 // Predicates in blocks have been handled above 00404 if (state_->getBlockIndex(i) >= 0) continue; 00405 00406 if (mcmcdebug) 00407 { 00408 cout << "Chain " << chainIdx << ": Probability of pred " 00409 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl; 00410 } 00411 00412 bool newAssignment 00413 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1)); 00414 00415 // Truth values are stored differently for multi-chain 00416 bool truthValue; 00417 GroundPredicate* gndPred = state_->getGndPred(i); 00418 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx]; 00419 else truthValue = gndPred->getTruthValue(); 00420 // If gndPred is flipped, do updates & find all affected gndPreds 00421 if (newAssignment != truthValue) 00422 { 00423 if (mcmcdebug) 00424 { 00425 cout << "Chain " << chainIdx << ": Changing truth value of pred " 00426 << i << " to " << newAssignment << endl; 00427 } 00428 00429 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment; 00430 else gndPred->setTruthValue(newAssignment); 00431 affectedGndPreds.clear(); 00432 affectedGndPredIndices.clear(); 00433 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds, 00434 affectedGndPredIndices); 00435 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00436 chainIdx); 00437 } 00438 00439 // If in actual gibbs sampling phase, track the num of times 00440 // the ground predicate is set to true 00441 if (!burningIn && 
newAssignment) numTrue_[i]++; 00442 } 00443 // If keeping track of true clause groundings 00444 if (!burningIn && trackClauseTrueCnts_) 00445 state_->getNumClauseGndings(clauseTrueCnts_, true); 00446 00447 if (mcmcdebug) cout << "End of Gibbs step" << endl; 00448 }
void MCMC::updateWtsForGndPreds | ( | GroundPredicateHashArray & | gndPreds, | |
Array< int > & | gndPredIndices, | |||
const int & | chainIdx | |||
) | [inline, protected] |
Updates the weights of affected ground predicates.
These are the ground predicates that appear in clauses containing a predicate whose truth value has changed.
gndPreds | Ground predicates whose weights should be updated. | |
gndPredIndices | Indices of the ground predicates to update; element g is used together with gndPreds[g]. | |
chainIdx | Index of chain where updating occurs. |
Definition at line 458 of file mcmc.h.
References VariableState::getClauseCost(), VariableState::getGndClause(), VariableState::getNegOccurenceArray(), VariableState::getNumTrueLits(), VariableState::getPosOccurenceArray(), GroundClause::getWt(), GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::infer(), MCSAT::infer(), GibbsSampler::infer(), and performGibbsStep().
00461 { 00462 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl; 00463 // for each ground predicate whose MB has changed 00464 for (int g = 0; g < gndPreds.size(); g++) 00465 { 00466 double wtIfNoChange = 0, wtIfInverted = 0, wt; 00467 // Ground clauses in which this pred occurs 00468 Array<int>& negGndClauses = 00469 state_->getNegOccurenceArray(gndPredIndices[g] + 1); 00470 Array<int>& posGndClauses = 00471 state_->getPosOccurenceArray(gndPredIndices[g] + 1); 00472 int gndClauseIdx; 00473 bool sense; 00474 00475 if (mcmcdebug) 00476 { 00477 cout << "Ground clauses in which pred " << g << " occurs neg.: " 00478 << negGndClauses.size() << endl; 00479 cout << "Ground clauses in which pred " << g << " occurs pos.: " 00480 << posGndClauses.size() << endl; 00481 } 00482 00483 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 00484 { 00485 if (i < negGndClauses.size()) 00486 { 00487 gndClauseIdx = negGndClauses[i]; 00488 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl; 00489 sense = false; 00490 } 00491 else 00492 { 00493 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 00494 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl; 00495 sense = true; 00496 } 00497 00498 GroundClause* gndClause = state_->getGndClause(gndClauseIdx); 00499 if (gndClause->isHardClause()) 00500 wt = state_->getClauseCost(gndClauseIdx); 00501 else 00502 wt = gndClause->getWt(); 00503 // NumTrueLits are stored differently for multi-chain 00504 int numSatLiterals; 00505 if (numChains_ > 1) 00506 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx]; 00507 else 00508 numSatLiterals = state_->getNumTrueLits(gndClauseIdx); 00509 if (numSatLiterals > 1) 00510 { 00511 // Some other literal is making it sat, so it doesn't matter 00512 // if pos. clause. If neg., nothing can be done to unsatisfy it. 
00513 if (wt > 0) 00514 { 00515 wtIfNoChange += wt; 00516 wtIfInverted += wt; 00517 } 00518 } 00519 else 00520 if (numSatLiterals == 1) 00521 { 00522 if (wt > 0) wtIfNoChange += wt; 00523 // Truth values are stored differently for multi-chain 00524 bool truthValue; 00525 if (numChains_ > 1) 00526 truthValue = truthValues_[gndPredIndices[g]][chainIdx]; 00527 else 00528 truthValue = gndPreds[g]->getTruthValue(); 00529 // If the current truth value is the same as its sense in gndClause 00530 if (truthValue == sense) 00531 { 00532 // This gndPred is the only one making this function satisfied 00533 if (wt < 0) wtIfInverted += abs(wt); 00534 } 00535 else 00536 { 00537 // Some other literal is making it satisfied 00538 if (wt > 0) wtIfInverted += wt; 00539 } 00540 } 00541 else 00542 if (numSatLiterals == 0) 00543 { 00544 // None satisfy, so when gndPred switch to its negative, it'll satisfy 00545 if (wt > 0) wtIfInverted += wt; 00546 else if (wt < 0) wtIfNoChange += abs(wt); 00547 } 00548 } // for each ground clause that gndPred appears in 00549 00550 if (mcmcdebug) 00551 { 00552 cout << "wtIfNoChange of pred " << g << ": " 00553 << wtIfNoChange << endl; 00554 cout << "wtIfInverted of pred " << g << ": " 00555 << wtIfInverted << endl; 00556 } 00557 00558 // Clause info is stored differently for multi-chain 00559 if (numChains_ > 1) 00560 { 00561 if (truthValues_[gndPredIndices[g]][chainIdx]) 00562 { 00563 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 00564 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted; 00565 } 00566 else 00567 { 00568 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 00569 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted; 00570 } 00571 } 00572 else 00573 { // Single chain 00574 if (gndPreds[g]->getTruthValue()) 00575 { 00576 gndPreds[g]->setWtWhenTrue(wtIfNoChange); 00577 gndPreds[g]->setWtWhenFalse(wtIfInverted); 00578 } 00579 else 00580 { 00581 gndPreds[g]->setWtWhenFalse(wtIfNoChange); 00582 
gndPreds[g]->setWtWhenTrue(wtIfInverted); 00583 } 00584 } 00585 } // for each ground predicate whose MB has changed 00586 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl; 00587 }
int MCMC::gibbsSampleFromBlock | ( | const int & | chainIdx, | |
const Array< int > & | block, | |||
const double & | invTemp | |||
) | [inline, protected] |
Chooses an atom from a block with probability proportional to each atom's probability of being true, and returns the atom's position within the block.
chainIdx | Index of chain. | |
block | Block of predicate indices from which one is chosen. | |
invTemp | Inverse temperature used in simulated tempering (performGibbsStep passes 1). |
Definition at line 598 of file mcmc.h.
References Array< Type >::append(), getProbabilityOfPred(), and Array< Type >::size().
Referenced by SimulatedTempering::infer(), and performGibbsStep().
00600 { 00601 Array<double> numerators; 00602 double denominator = 0; 00603 00604 for (int i = 0; i < block.size(); i++) 00605 { 00606 double prob = getProbabilityOfPred(block[i], chainIdx, invTemp); 00607 numerators.append(prob); 00608 denominator += prob; 00609 } 00610 double r = random(); 00611 double numSum = 0.0; 00612 for (int i = 0; i < block.size(); i++) 00613 { 00614 numSum += numerators[i]; 00615 if (r < ((numSum / denominator) * RAND_MAX)) 00616 { 00617 return i; 00618 } 00619 } 00620 return block.size() - 1; 00621 }
void MCMC::gndPredFlippedUpdates | ( | const int & | gndPredIdx, | |
const int & | chainIdx, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate.
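gndPredIdx | Index of ground pred which was flipped. | |
chainIdx | Index of chain in which the flipping occurred. | |
affectedGndPreds | Receives the flipped ground predicate and its Markov blanket (appended to, not cleared). | |
affectedGndPredIndices | Receives the indices of the predicates appended to affectedGndPreds. This Array is appended to unconditionally, whereas affectedGndPreds is a HashArray; if HashArray::append skips duplicates, the two lists can fall out of step — verify. |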
Definition at line 631 of file mcmc.h.
References Array< Type >::append(), HashArray< Type, HashFn, EqualFn >::append(), VariableState::decrementNumTrueLits(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getGndPredHashArrayPtr(), GroundClause::getGroundPredicate(), GroundClause::getGroundPredicateIndex(), VariableState::getNegOccurenceArray(), VariableState::getNumAtoms(), GroundClause::getNumGroundPredicates(), VariableState::getPosOccurenceArray(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::infer(), and performGibbsStep().
00634 { 00635 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl; 00636 int numAtoms = state_->getNumAtoms(); 00637 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx); 00638 affectedGndPreds.append(gndPred, numAtoms); 00639 affectedGndPredIndices.append(gndPredIdx); 00640 assert(affectedGndPreds.size() <= numAtoms); 00641 00642 Array<int>& negGndClauses = 00643 state_->getNegOccurenceArray(gndPredIdx + 1); 00644 Array<int>& posGndClauses = 00645 state_->getPosOccurenceArray(gndPredIdx + 1); 00646 int gndClauseIdx; 00647 GroundClause* gndClause; 00648 bool sense; 00649 00650 // Find the Markov blanket of this ground predicate 00651 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 00652 { 00653 if (i < negGndClauses.size()) 00654 { 00655 gndClauseIdx = negGndClauses[i]; 00656 sense = false; 00657 } 00658 else 00659 { 00660 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 00661 sense = true; 00662 } 00663 gndClause = state_->getGndClause(gndClauseIdx); 00664 00665 // Different for multi-chain 00666 if (numChains_ > 1) 00667 { 00668 if (truthValues_[gndPredIdx][chainIdx] == sense) 00669 numTrueLits_[gndClauseIdx][chainIdx]++; 00670 else 00671 numTrueLits_[gndClauseIdx][chainIdx]--; 00672 } 00673 else 00674 { // Single chain 00675 if (gndPred->getTruthValue() == sense) 00676 state_->incrementNumTrueLits(gndClauseIdx); 00677 else 00678 state_->decrementNumTrueLits(gndClauseIdx); 00679 } 00680 00681 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 00682 { 00683 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr(); 00684 GroundPredicate* pred = 00685 (GroundPredicate*)gndClause->getGroundPredicate(j, 00686 (GroundPredicateHashArray*)gpha); 00687 affectedGndPreds.append(pred, numAtoms); 00688 affectedGndPredIndices.append( 00689 abs(gndClause->getGroundPredicateIndex(j)) - 1); 00690 assert(affectedGndPreds.size() <= numAtoms); 00691 } 00692 } 00693 if (mcmcdebug) cout << "Leaving 
gndPredFlippedUpdates" << endl; 00694 }
void MCMC::saveLowStateToChain | ( | const int & | chainIdx | ) | [inline, protected] |
Saves the atom assignment of the best (low) state into the truth values of a given chain.
chainIdx | Index of chain to which the atom assignment is saved. |
Definition at line 710 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getValueOfLowAtom(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00711 { 00712 for (int i = 0; i < state_->getNumAtoms(); i++) 00713 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1); 00714 }
void MCMC::setMCMCParameters | ( | MCMCParams * | params | ) | [inline, protected] |
Sets the user-set parameters for this MCMC algorithm.
params | MCMC parameters for this algorithm. |
Definition at line 721 of file mcmc.h.
References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.
00722 { 00723 // User-set parameters 00724 numChains_ = params->numChains; 00725 burnMinSteps_ = params->burnMinSteps; 00726 burnMaxSteps_ = params->burnMaxSteps; 00727 minSteps_ = params->minSteps; 00728 maxSteps_ = params->maxSteps; 00729 maxSeconds_ = params->maxSeconds; 00730 }