#include <mcmc.h>
Inheritance diagram for MCMC:
Public Member Functions | |
MCMC (VariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params, Array< Array< Predicate * > * > *queryFormulas=NULL) | |
Constructor. | |
MCMC (HVariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params) | |
~MCMC () | |
Destructor. | |
double | computeHybridClauseValue (int clauseIdx, int c) |
double | HybridClauseContPartValue (int contClauseIdx, int c) |
double | HybridClauseDisPartValue (int contClauseIdx, int c) |
void | updateDisPredValue (int predIdx, int chainIdx, bool updateValue) |
void | updateProposalContValue (int contPredIdx, int chainIdx) |
double | getProbabilityOfPredH (const int &predIdx, const int &chainIdx, const double &invTemp) |
virtual void | printNetwork (ostream &out) |
Prints out the network. | |
void | printProbabilities (ostream &out) |
Prints the probabilities of each predicate to a stream. | |
void | getChangedPreds (vector< string > &changedPreds, vector< float > &probs, vector< float > &oldProbs, const float &probDelta) |
Puts the predicates whose probability has changed with respect to the reference vector oldProbs by more than probDelta in string form and the corresponding probabilities of each predicate in two vectors. | |
double | getProbabilityH (GroundPredicate *const &gndPred) |
double | getProbability (GroundPredicate *const &gndPred) |
Gets the probability of a ground predicate. | |
void | printTruePreds (ostream &out) |
Prints each predicate with a probability of 0.5 or greater to a stream. | |
void | printTruePredsH (ostream &out) |
Public Attributes | |
Array< Array< double > > | truthValuesCont_ |
Array< Array< bool > > | hybridClauseDisPartValueMCMC_ |
Array< Array< double > > | hybridClauseContPartValueMCMC_ |
Protected Member Functions | |
void | initTruthValuesAndWts (const int &numChains) |
Initializes truth values and weights in the ground preds. | |
void | initNumTrue () |
Initializes structure for holding number of times a predicate was set to true. | |
void | initNumTrueLits (const int &numChains) |
Initializes the number of true lits in each clause in each chain. | |
void | randomInitGndPredsTruthValues (const int &numChains) |
Randomly initializes the ground predicate truth values, taking blocks into account. | |
bool | genTruthValueForProb (const double &p) |
Generates a truth value based on a probability. | |
double | getProbabilityOfPred (const int &predIdx, const int &chainIdx, const double &invTemp) |
Computes the probability of a ground predicate in a chain. | |
void | setOthersInBlockToFalse (const int &chainIdx, const int &atomIdx, const int &blockIdx) |
Sets the truth values of all atoms in a block to false for a given chain, except for the one given. | |
void | performGibbsStep (const int &chainIdx, const bool &burningIn, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Performs one step of Gibbs sampling in one chain. | |
void | updateWtsForGndPredsH (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx) |
Updates the weights of affected ground predicates. | |
void | updateWtsForGndPreds (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx) |
Updates the weights of affected ground predicates. | |
int | gibbsSampleFromBlock (const int &chainIdx, const int &blockIndex, const double &invTemp) |
Chooses an atom from a block according to their probabilities. | |
void | gndPredFlippedUpdates (const int &gndPredIdx, const int &chainIdx, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices) |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate. | |
double | getProbTrue (const int &predIdx) const |
void | setProbTrue (const int &predIdx, const double &p) |
void | saveLowStateToChain (const int &chainIdx) |
The atom assignment in the best state is saved to a chain in the ground predicates. | |
void | setMCMCParameters (MCMCParams *params) |
Sets the user-set parameters for this MCMC algorithm. | |
void | scaleSamples (double factor) |
Increase or decrease the number of MCMC samples by a multiplicative factor. | |
Protected Attributes | |
int | numChains_ |
int | burnMinSteps_ |
int | burnMaxSteps_ |
int | minSteps_ |
int | maxSteps_ |
int | maxSeconds_ |
Array< Array< bool > > | truthValues_ |
Array< Array< double > > | wtsWhenFalse_ |
Array< Array< double > > | wtsWhenTrue_ |
Array< double > | numTrue_ |
Array< Array< int > > | numTrueLits_ |
This class does not implement all pure virtual functions of Inference and is thus an abstract class.
Definition at line 80 of file mcmc.h.
MCMC::MCMC | ( | VariableState * | state, | |
long int | seed, | |||
const bool & | trackClauseTrueCnts, | |||
MCMCParams * | params, | |||
Array< Array< Predicate * > * > * | queryFormulas = NULL | |||
) | [inline] |
Constructor.
Copies the user-set parameters from params into the corresponding member variables.
Definition at line 90 of file mcmc.h.
References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.
00092 : Inference(state, seed, trackClauseTrueCnts, queryFormulas) 00093 { 00094 // User-set parameters 00095 numChains_ = params->numChains; 00096 burnMinSteps_ = params->burnMinSteps; 00097 burnMaxSteps_ = params->burnMaxSteps; 00098 minSteps_ = params->minSteps; 00099 maxSteps_ = params->maxSteps; 00100 maxSeconds_ = params->maxSeconds; 00101 }
void MCMC::getChangedPreds | ( | vector< string > & | changedPreds, | |
vector< float > & | probs, | |||
vector< float > & | oldProbs, | |||
const float & | probDelta | |||
) | [inline, virtual] |
Puts the predicates whose probability has changed with respect to the reference vector oldProbs by more than probDelta in string form and the corresponding probabilities of each predicate in two vectors.
changedPreds | Predicates whose probability has changed by more than probDelta are put here. | |
probs | The probabilities corresponding to the predicates in changedPreds are put here. | |
oldProbs | Reference probabilities for checking for changes. | |
probDelta | If probability of an atom has changed more than this value, then it is considered to have changed. |
Implements Inference.
Definition at line 468 of file mcmc.h.
References VariableState::getNumAtoms(), getProbTrue(), VariableState::printGndPred(), and Inference::state_.
00470 { 00471 changedPreds.clear(); 00472 probs.clear(); 00473 int numAtoms = state_->getNumAtoms(); 00474 // Atoms may have been added to the state, previous prob. was 0 00475 oldProbs.resize(numAtoms, 0); 00476 for (int i = 0; i < numAtoms; i++) 00477 { 00478 double prob = getProbTrue(i); 00479 if (abs(prob - oldProbs[i]) > probDelta) 00480 { 00481 // Truth value has changed: Store new value (not smoothed) in oldProbs 00482 // and add to two return vectors 00483 oldProbs[i] = prob; 00484 // Uniform smoothing 00485 prob = (prob*10000 + 1/2.0)/(10000 + 1.0); 00486 ostringstream oss(ostringstream::out); 00487 state_->printGndPred(i, oss); 00488 changedPreds.push_back(oss.str()); 00489 probs.push_back(prob); 00490 } 00491 } 00492 }
double MCMC::getProbability | ( | GroundPredicate *const & | gndPred | ) | [inline, virtual] |
Gets the probability of a ground predicate.
gndPred | GroundPredicate whose probability is being retrieved. |
Implements Inference.
Definition at line 510 of file mcmc.h.
References VariableState::getGndPredIndex(), getProbTrue(), and Inference::state_.
00511 { 00512 int idx = state_->getGndPredIndex(gndPred); 00513 double prob = 0.0; 00514 if (idx >= 0) prob = getProbTrue(idx); 00515 // Uniform smoothing 00516 return (prob*10000 + 1/2.0)/(10000 + 1.0); 00517 }
void MCMC::initTruthValuesAndWts | ( | const int & | numChains | ) | [inline, protected] |
Initializes truth values and weights in the ground preds.
numChains | Number of chains for which the values should be initialized. | |
startIdx | Stale entry: this function has no startIdx parameter. The arrays are extended with growToSize, so presumably entries that already exist keep their values. |
Definition at line 557 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getNumClauses(), Array< Type >::growToSize(), numTrueLits_, Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00558 { 00559 int numPreds = state_->getNumAtoms(); 00560 truthValues_.growToSize(numPreds); 00561 wtsWhenFalse_.growToSize(numPreds); 00562 wtsWhenTrue_.growToSize(numPreds); 00563 for (int i = 0; i < numPreds; i++) 00564 { 00565 truthValues_[i].growToSize(numChains, false); 00566 wtsWhenFalse_[i].growToSize(numChains, 0); 00567 wtsWhenTrue_[i].growToSize(numChains, 0); 00568 } 00569 00570 int numClauses = state_->getNumClauses(); 00571 numTrueLits_.growToSize(numClauses); 00572 for (int i = 0; i < numClauses; i++) 00573 { 00574 numTrueLits_[i].growToSize(numChains, 0); 00575 } 00576 }
void MCMC::initNumTrueLits | ( | const int & | numChains | ) | [inline, protected] |
Initializes the number of true lits in each clause in each chain.
numChains | Number of chains for which the initialization should take place. |
Definition at line 596 of file mcmc.h.
References VariableState::getAtomInClause(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getNumAtoms(), VariableState::getNumClauses(), VariableState::getNumTrueLits(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numTrueLits_, VariableState::resetMakeBreakCostWatch(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
00597 { 00598 // Single chain 00599 if (numChains == 1) state_->resetMakeBreakCostWatch(); 00600 for (int i = 0; i < state_->getNumClauses(); i++) 00601 { 00602 GroundClause* gndClause = state_->getGndClause(i); 00603 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 00604 { 00605 const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1; 00606 const bool sense = gndClause->getGroundPredicateSense(j); 00607 if (numChains > 1) 00608 { 00609 for (int c = 0; c < numChains; c++) 00610 { 00611 if (truthValues_[atomIdx][c] == sense) 00612 { 00613 numTrueLits_[i][c]++; 00614 assert(numTrueLits_[i][c] <= state_->getNumAtoms()); 00615 } 00616 } 00617 } 00618 else 00619 { // Single chain 00620 GroundPredicate* gndPred = state_->getGndPred(atomIdx); 00621 if (gndPred->getTruthValue() == sense) 00622 state_->incrementNumTrueLits(i); 00623 assert(state_->getNumTrueLits(i) <= state_->getNumAtoms()); 00624 } 00625 } 00626 } 00627 }
void MCMC::randomInitGndPredsTruthValues | ( | const int & | numChains | ) | [inline, protected] |
Randomly initializes the ground predicate truth values, taking blocks into account.
numChains | Number of chains for which the initialization should take place. |
Definition at line 636 of file mcmc.h.
References genTruthValueForProb(), Domain::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getDomain(), VariableState::getGndPred(), VariableState::getIndexOfGroundPredicate(), Domain::getNumPredBlocks(), Domain::getRandomPredInBlock(), numChains_, setOthersInBlockToFalse(), GroundPredicate::setTruthValue(), Array< Type >::size(), Inference::state_, and truthValues_.
Referenced by GibbsSampler::init().
00637 { 00638 for (int c = 0; c < numChains; c++) 00639 { 00640 if (mcmcdebug) cout << "Chain " << c << ":" << endl; 00641 // For each block: select one to set to true 00642 for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++) 00643 { 00644 // If evidence atom exists, then all others are false 00645 if (state_->getDomain()->getBlockEvidence(i)) 00646 { 00647 // If 2nd argument is -1, then all are set to false 00648 setOthersInBlockToFalse(c, -1, i); 00649 continue; 00650 } 00651 00652 bool ok = false; 00653 while (!ok) 00654 { 00655 const Predicate* pred = state_->getDomain()->getRandomPredInBlock(i); 00656 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred); 00657 int idx = state_->getIndexOfGroundPredicate(gndPred); 00658 00659 delete gndPred; 00660 delete pred; 00661 00662 if (idx >= 0) 00663 { 00664 // Truth values are stored differently for multi-chain 00665 if (numChains_ > 1) 00666 truthValues_[idx][c] = true; 00667 else 00668 { 00669 GroundPredicate* gndPred = state_->getGndPred(i); 00670 gndPred->setTruthValue(true); 00671 } 00672 setOthersInBlockToFalse(c, idx, i); 00673 ok = true; 00674 } 00675 } 00676 } 00677 00678 // Random tv for all not in blocks 00679 for (int i = 0; i < truthValues_.size(); i++) 00680 { 00681 // Predicates in blocks have been handled above 00682 if (state_->getBlockIndex(i) == -1) 00683 { 00684 bool tv = genTruthValueForProb(0.5); 00685 // Truth values are stored differently for multi-chain 00686 if (numChains_ > 1) 00687 truthValues_[i][c] = tv; 00688 else 00689 { 00690 GroundPredicate* gndPred = state_->getGndPred(i); 00691 gndPred->setTruthValue(tv); 00692 } 00693 if (mcmcdebug) cout << "Pred " << i << " set to " << tv << endl; 00694 } 00695 } 00696 } 00697 }
bool MCMC::genTruthValueForProb | ( | const double & | p | ) | [inline, protected] |
Generates a truth value based on a probability.
p | Number between 0 and 1. With probability p, truth value will be true and with probability 1 - p, it will be false. |
Definition at line 705 of file mcmc.h.
Referenced by SimulatedTempering::infer(), performGibbsStep(), and randomInitGndPredsTruthValues().
/**
 * Draws a truth value that is true with probability p.
 *
 * p == 1.0 and p == 0.0 are decided without drawing a random number.
 * Otherwise random() is compared against p scaled to RAND_MAX.
 * NOTE(review): this assumes random() spans [0, RAND_MAX]. That holds on
 * glibc; POSIX only guarantees 0..2^31-1 for random().
 *
 * @param p Number between 0 and 1.
 */
bool genTruthValueForProb(const double& p)
{
  if (p == 1.0 || p == 0.0) return p == 1.0;
  return random() <= p * RAND_MAX;
}
double MCMC::getProbabilityOfPred | ( | const int & | predIdx, | |
const int & | chainIdx, | |||
const double & | invTemp | |||
) | [inline, protected] |
Computes the probability of a ground predicate in a chain.
predIdx | Index of predicate. | |
chainIdx | Index of chain. | |
invTemp | InvTemp used in simulated tempering. |
Definition at line 722 of file mcmc.h.
References VariableState::getGndPred(), GroundPredicate::getWtWhenFalse(), GroundPredicate::getWtWhenTrue(), numChains_, Inference::state_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by gibbsSampleFromBlock(), SimulatedTempering::infer(), and performGibbsStep().
00724 { 00725 // Different for multi-chain 00726 if (numChains_ > 1) 00727 { 00728 return 1.0 / 00729 ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 00730 wtsWhenTrue_[predIdx][chainIdx]) * 00731 invTemp)); 00732 } 00733 else 00734 { 00735 GroundPredicate* gndPred = state_->getGndPred(predIdx); 00736 return 1.0 / 00737 ( 1.0 + exp((gndPred->getWtWhenFalse() - 00738 gndPred->getWtWhenTrue()) * 00739 invTemp)); 00740 } 00741 }
void MCMC::setOthersInBlockToFalse | ( | const int & | chainIdx, | |
const int & | atomIdx, | |||
const int & | blockIdx | |||
) | [inline, protected] |
Sets the truth values of all atoms in a block to false for a given chain, except for the one given.
chainIdx | Index of chain for which atoms are being set. | |
atomIdx | Index of the atom in the block that is exempt from being set to false; pass -1 to set all atoms in the block to false. | |
blockIdx | Index of block whose atoms are set to false. |
Definition at line 751 of file mcmc.h.
References Domain::getBlockSize(), VariableState::getDomain(), VariableState::getIndexOfGroundPredicate(), Domain::getPredInBlock(), Inference::state_, and truthValues_.
Referenced by randomInitGndPredsTruthValues().
00753 { 00754 int blockSize = state_->getDomain()->getBlockSize(blockIdx); 00755 for (int i = 0; i < blockSize; i++) 00756 { 00757 const Predicate* pred = state_->getDomain()->getPredInBlock(i, blockIdx); 00758 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred); 00759 int idx = state_->getIndexOfGroundPredicate(gndPred); 00760 00761 delete gndPred; 00762 delete pred; 00763 00764 // Pred is in the state 00765 if (idx >= 0 && idx != atomIdx) 00766 truthValues_[idx][chainIdx] = false; 00767 } 00768 }
void MCMC::performGibbsStep | ( | const int & | chainIdx, | |
const bool & | burningIn, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Performs one step of Gibbs sampling in one chain.
chainIdx | Index of chain in which the Gibbs step is performed. | |
burningIn | If true, burning-in is occurring. Otherwise, false. | |
affectedGndPreds | Used to store GroundPredicates which are affected by changing truth values. | |
affectedGndPredIndices | Used to store indices of GroundPredicates which are affected by changing truth values. |
Definition at line 780 of file mcmc.h.
References Inference::clauseTrueCnts_, Array< Type >::clear(), HashArray< Type, HashFn, EqualFn >::clear(), genTruthValueForProb(), Domain::getBlockEvidence(), VariableState::getBlockIndex(), Domain::getBlockSize(), VariableState::getDomain(), VariableState::getGndPred(), VariableState::getIndexOfGroundPredicate(), VariableState::getNumAtoms(), VariableState::getNumClauseGndings(), Domain::getNumPredBlocks(), Domain::getPredInBlock(), getProbabilityOfPred(), GroundPredicate::getTruthValue(), gibbsSampleFromBlock(), gndPredFlippedUpdates(), numChains_, numTrue_, GroundPredicate::setTruthValue(), Inference::state_, Inference::trackClauseTrueCnts_, truthValues_, and updateWtsForGndPreds().
Referenced by GibbsSampler::infer().
00783 { 00784 if (mcmcdebug) cout << "Gibbs step" << endl; 00785 00786 // For each block: select one to set to true 00787 for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++) 00788 { 00789 // If evidence atom exists, then all others stay false 00790 if (state_->getDomain()->getBlockEvidence(i)) continue; 00791 00792 int chosen = gibbsSampleFromBlock(chainIdx, i, 1); 00793 // Truth values are stored differently for multi-chain 00794 bool truthValue; 00795 const Predicate* pred = state_->getDomain()->getPredInBlock(chosen, i); 00796 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred); 00797 int idx = state_->getIndexOfGroundPredicate(gndPred); 00798 00799 delete gndPred; 00800 delete pred; 00801 00802 // If gnd pred in state: 00803 if (idx >= 0) 00804 { 00805 gndPred = state_->getGndPred(idx); 00806 if (numChains_ > 1) truthValue = truthValues_[idx][chainIdx]; 00807 else truthValue = gndPred->getTruthValue(); 00808 // If chosen pred was false, then need to set previous true 00809 // one to false and update wts 00810 if (!truthValue) 00811 { 00812 int blockSize = state_->getDomain()->getBlockSize(i); 00813 for (int j = 0; j < blockSize; j++) 00814 { 00815 // Truth values are stored differently for multi-chain 00816 bool otherTruthValue; 00817 const Predicate* otherPred = 00818 state_->getDomain()->getPredInBlock(j, i); 00819 GroundPredicate* otherGndPred = 00820 new GroundPredicate((Predicate*)otherPred); 00821 int otherIdx = state_->getIndexOfGroundPredicate(gndPred); 00822 00823 delete otherGndPred; 00824 delete otherPred; 00825 00826 // If gnd pred in state: 00827 if (otherIdx >= 0) 00828 { 00829 otherGndPred = state_->getGndPred(otherIdx); 00830 if (numChains_ > 1) 00831 otherTruthValue = truthValues_[otherIdx][chainIdx]; 00832 else 00833 otherTruthValue = otherGndPred->getTruthValue(); 00834 if (otherTruthValue) 00835 { 00836 // Truth values are stored differently for multi-chain 00837 if (numChains_ > 1) 00838 
truthValues_[otherIdx][chainIdx] = false; 00839 else 00840 otherGndPred->setTruthValue(false); 00841 00842 affectedGndPreds.clear(); 00843 affectedGndPredIndices.clear(); 00844 gndPredFlippedUpdates(otherIdx, chainIdx, affectedGndPreds, 00845 affectedGndPredIndices); 00846 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00847 chainIdx); 00848 } 00849 } 00850 } 00851 // Set truth value and update wts for chosen atom 00852 // Truth values are stored differently for multi-chain 00853 if (numChains_ > 1) truthValues_[idx][chainIdx] = true; 00854 else gndPred->setTruthValue(true); 00855 affectedGndPreds.clear(); 00856 affectedGndPredIndices.clear(); 00857 gndPredFlippedUpdates(idx, chainIdx, affectedGndPreds, 00858 affectedGndPredIndices); 00859 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00860 chainIdx); 00861 } 00862 // If in actual gibbs sampling phase, track the num of times 00863 // the ground predicate is set to true 00864 if (!burningIn) numTrue_[idx]++; 00865 } 00866 } 00867 00868 // Now go through all preds not in blocks 00869 for (int i = 0; i < state_->getNumAtoms(); i++) 00870 { 00871 // Predicates in blocks have been handled above 00872 if (state_->getBlockIndex(i) >= 0) continue; 00873 00874 if (mcmcdebug) 00875 { 00876 cout << "Chain " << chainIdx << ": Probability of pred " 00877 << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl; 00878 } 00879 00880 bool newAssignment 00881 = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1)); 00882 00883 // Truth values are stored differently for multi-chain 00884 bool truthValue; 00885 GroundPredicate* gndPred = state_->getGndPred(i); 00886 if (numChains_ > 1) truthValue = truthValues_[i][chainIdx]; 00887 else truthValue = gndPred->getTruthValue(); 00888 // If gndPred is flipped, do updates & find all affected gndPreds 00889 if (newAssignment != truthValue) 00890 { 00891 if (mcmcdebug) 00892 { 00893 cout << "Chain " << chainIdx << ": Changing truth value of pred " 
00894 << i << " to " << newAssignment << endl; 00895 } 00896 00897 if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment; 00898 else gndPred->setTruthValue(newAssignment); 00899 affectedGndPreds.clear(); 00900 affectedGndPredIndices.clear(); 00901 gndPredFlippedUpdates(i, chainIdx, affectedGndPreds, 00902 affectedGndPredIndices); 00903 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 00904 chainIdx); 00905 } 00906 00907 // If in actual gibbs sampling phase, track the num of times 00908 // the ground predicate is set to true 00909 if (!burningIn && newAssignment) numTrue_[i]++; 00910 } 00911 // If keeping track of true clause groundings 00912 if (!burningIn && trackClauseTrueCnts_) 00913 state_->getNumClauseGndings(clauseTrueCnts_, true); 00914 00915 if (mcmcdebug) cout << "End of Gibbs step" << endl; 00916 }
void MCMC::updateWtsForGndPredsH | ( | GroundPredicateHashArray & | gndPreds, | |
Array< int > & | gndPredIndices, | |||
const int & | chainIdx | |||
) | [inline, protected] |
Updates the weights of affected ground predicates.
These are the ground predicates which are in clauses of predicates which have had their truth value changed.
gndPreds | Ground predicates whose weights should be updated. | |
gndPredIndices | Index in the state of each predicate in gndPreds (same order). | |
chainIdx | Index of chain where updating occurs. |
Definition at line 926 of file mcmc.h.
References HVariableState::getClauseCost(), HVariableState::getGndClause(), HVariableState::getNegOccurenceArray(), HVariableState::getNumTrueLits(), HVariableState::getPosOccurenceArray(), GroundClause::getWt(), Inference::hstate_, GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by HMCSAT::infer().
00928 { 00929 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl; 00930 // for each ground predicate whose MB has changed 00931 for (int g = 0; g < gndPreds.size(); g++) 00932 { 00933 double wtIfNoChange = 0, wtIfInverted = 0, wt; 00934 // Ground clauses in which this pred occurs 00935 Array<int>& negGndClauses = 00936 hstate_->getNegOccurenceArray(gndPredIndices[g] + 1); 00937 Array<int>& posGndClauses = 00938 hstate_->getPosOccurenceArray(gndPredIndices[g] + 1); 00939 00940 int gndClauseIdx; 00941 bool sense; 00942 if (mcmcdebug) 00943 { 00944 cout << "Ground clauses in which pred " << g << " occurs neg.: " 00945 << negGndClauses.size() << endl; 00946 cout << "Ground clauses in which pred " << g << " occurs pos.: " 00947 << posGndClauses.size() << endl; 00948 } 00949 00950 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 00951 { 00952 if (i < negGndClauses.size()) 00953 { 00954 gndClauseIdx = negGndClauses[i]; 00955 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl; 00956 sense = false; 00957 } 00958 else 00959 { 00960 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 00961 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl; 00962 sense = true; 00963 } 00964 00965 GroundClause* gndClause = hstate_->getGndClause(gndClauseIdx); 00966 if (gndClause->isHardClause()) 00967 wt = hstate_->getClauseCost(gndClauseIdx); 00968 else 00969 wt = gndClause->getWt(); 00970 // NumTrueLits are stored differently for multi-chain 00971 int numSatLiterals; 00972 if (numChains_ > 1) 00973 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx]; 00974 else 00975 numSatLiterals = hstate_->getNumTrueLits(gndClauseIdx); 00976 if (numSatLiterals > 1) 00977 { 00978 // Some other literal is making it sat, so it doesn't matter 00979 // if pos. clause. If neg., nothing can be done to unsatisfy it. 
00980 if (wt > 0) 00981 { 00982 wtIfNoChange += wt; 00983 wtIfInverted += wt; 00984 } 00985 } 00986 else 00987 if (numSatLiterals == 1) 00988 { 00989 if (wt > 0) wtIfNoChange += wt; 00990 // Truth values are stored differently for multi-chain 00991 bool truthValue; 00992 if (numChains_ > 1) 00993 truthValue = truthValues_[gndPredIndices[g]][chainIdx]; 00994 else 00995 truthValue = gndPreds[g]->getTruthValue(); 00996 // If the current truth value is the same as its sense in gndClause 00997 if (truthValue == sense) 00998 { 00999 // This gndPred is the only one making this function satisfied 01000 if (wt < 0) wtIfInverted += abs(wt); 01001 } 01002 else 01003 { 01004 // Some other literal is making it satisfied 01005 if (wt > 0) wtIfInverted += wt; 01006 } 01007 } 01008 else 01009 if (numSatLiterals == 0) 01010 { 01011 // None satisfy, so when gndPred switch to its negative, it'll satisfy 01012 if (wt > 0) wtIfInverted += wt; 01013 else if (wt < 0) wtIfNoChange += abs(wt); 01014 } 01015 } // for each ground clause that gndPred appears in 01016 01017 if (mcmcdebug) 01018 { 01019 cout << "wtIfNoChange of pred " << g << ": " 01020 << wtIfNoChange << endl; 01021 cout << "wtIfInverted of pred " << g << ": " 01022 << wtIfInverted << endl; 01023 } 01024 01025 // Clause info is stored differently for multi-chain 01026 if (numChains_ > 1) 01027 { 01028 if (truthValues_[gndPredIndices[g]][chainIdx]) 01029 { 01030 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 01031 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted; 01032 } 01033 else 01034 { 01035 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 01036 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted; 01037 } 01038 } 01039 else 01040 { // Single chain 01041 if (gndPreds[g]->getTruthValue()) 01042 { 01043 gndPreds[g]->setWtWhenTrue(wtIfNoChange); 01044 gndPreds[g]->setWtWhenFalse(wtIfInverted); 01045 } 01046 else 01047 { 01048 gndPreds[g]->setWtWhenFalse(wtIfNoChange); 01049 
gndPreds[g]->setWtWhenTrue(wtIfInverted); 01050 } 01051 } 01052 } // for each ground predicate whose MB has changed 01053 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl; 01054 }
void MCMC::updateWtsForGndPreds | ( | GroundPredicateHashArray & | gndPreds, | |
Array< int > & | gndPredIndices, | |||
const int & | chainIdx | |||
) | [inline, protected] |
Updates the weights of affected ground predicates.
These are the ground predicates which are in clauses of predicates which have had their truth value changed.
gndPreds | Ground predicates whose weights should be updated. | |
gndPredIndices | Index in the state of each predicate in gndPreds (same order). | |
chainIdx | Index of chain where updating occurs. |
Definition at line 1067 of file mcmc.h.
References VariableState::getClauseCost(), VariableState::getGndClause(), VariableState::getNegOccurenceArray(), VariableState::getNumTrueLits(), VariableState::getPosOccurenceArray(), GroundClause::getWt(), GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.
Referenced by SimulatedTempering::infer(), MCSAT::infer(), GibbsSampler::infer(), and performGibbsStep().
01070 { 01071 if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl; 01072 // for each ground predicate whose MB has changed 01073 for (int g = 0; g < gndPreds.size(); g++) 01074 { 01075 double wtIfNoChange = 0, wtIfInverted = 0, wt; 01076 // Ground clauses in which this pred occurs 01077 Array<int>& negGndClauses = 01078 state_->getNegOccurenceArray(gndPredIndices[g] + 1); 01079 Array<int>& posGndClauses = 01080 state_->getPosOccurenceArray(gndPredIndices[g] + 1); 01081 int gndClauseIdx; 01082 bool sense; 01083 01084 if (mcmcdebug) 01085 { 01086 cout << "Ground clauses in which pred " << g << " occurs neg.: " 01087 << negGndClauses.size() << endl; 01088 cout << "Ground clauses in which pred " << g << " occurs pos.: " 01089 << posGndClauses.size() << endl; 01090 } 01091 01092 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 01093 { 01094 if (i < negGndClauses.size()) 01095 { 01096 gndClauseIdx = negGndClauses[i]; 01097 if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl; 01098 sense = false; 01099 } 01100 else 01101 { 01102 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 01103 if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl; 01104 sense = true; 01105 } 01106 01107 GroundClause* gndClause = state_->getGndClause(gndClauseIdx); 01108 if (gndClause->isHardClause()) 01109 wt = state_->getClauseCost(gndClauseIdx); 01110 else 01111 wt = gndClause->getWt(); 01112 // NumTrueLits are stored differently for multi-chain 01113 int numSatLiterals; 01114 if (numChains_ > 1) 01115 numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx]; 01116 else 01117 numSatLiterals = state_->getNumTrueLits(gndClauseIdx); 01118 01119 if (numSatLiterals > 1) 01120 { 01121 // Some other literal is making it sat, so it doesn't matter 01122 // if pos. clause. If neg., nothing can be done to unsatisfy it. 
01123 if (wt > 0) 01124 { 01125 wtIfNoChange += wt; 01126 wtIfInverted += wt; 01127 } 01128 } 01129 else 01130 if (numSatLiterals == 1) 01131 { 01132 if (wt > 0) wtIfNoChange += wt; 01133 // Truth values are stored differently for multi-chain 01134 bool truthValue; 01135 if (numChains_ > 1) 01136 truthValue = truthValues_[gndPredIndices[g]][chainIdx]; 01137 else 01138 truthValue = gndPreds[g]->getTruthValue(); 01139 // If the current truth value is the same as its sense in gndClause 01140 if (truthValue == sense) 01141 { 01142 // This gndPred is the only one making this function satisfied 01143 if (wt < 0) wtIfInverted += abs(wt); 01144 } 01145 else 01146 { 01147 // Some other literal is making it satisfied 01148 if (wt > 0) wtIfInverted += wt; 01149 } 01150 } 01151 else 01152 if (numSatLiterals == 0) 01153 { 01154 // None satisfy, so when gndPred switch to its negative, it'll satisfy 01155 if (wt > 0) wtIfInverted += wt; 01156 else if (wt < 0) wtIfNoChange += abs(wt); 01157 } 01158 } // for each ground clause that gndPred appears in 01159 01160 if (mcmcdebug) 01161 { 01162 cout << "wtIfNoChange of pred " << g << ": " 01163 << wtIfNoChange << endl; 01164 cout << "wtIfInverted of pred " << g << ": " 01165 << wtIfInverted << endl; 01166 } 01167 01168 // Clause info is stored differently for multi-chain 01169 if (numChains_ > 1) 01170 { 01171 if (truthValues_[gndPredIndices[g]][chainIdx]) 01172 { 01173 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 01174 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted; 01175 } 01176 else 01177 { 01178 wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange; 01179 wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted; 01180 } 01181 } 01182 else 01183 { // Single chain 01184 if (gndPreds[g]->getTruthValue()) 01185 { 01186 gndPreds[g]->setWtWhenTrue(wtIfNoChange); 01187 gndPreds[g]->setWtWhenFalse(wtIfInverted); 01188 } 01189 else 01190 { 01191 gndPreds[g]->setWtWhenFalse(wtIfNoChange); 01192 
gndPreds[g]->setWtWhenTrue(wtIfInverted); 01193 } 01194 } 01195 } // for each ground predicate whose MB has changed 01196 if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl; 01197 }
int MCMC::gibbsSampleFromBlock | ( | const int & | chainIdx, | |
const int & | blockIndex, | |||
const double & | invTemp | |||
) | [inline, protected] |
Chooses an atom from a block according to the atoms' probabilities and returns its position within the block.
chainIdx | Index of chain. | |
blockIndex | Index of the block of predicates from which one is chosen. | |
invTemp | Inverse temperature used in simulated tempering. |
Definition at line 1208 of file mcmc.h.
References Array< Type >::append(), Domain::getBlockSize(), VariableState::getDomain(), VariableState::getIndexOfGroundPredicate(), Domain::getPredInBlock(), getProbabilityOfPred(), and Inference::state_.
Referenced by SimulatedTempering::infer(), and performGibbsStep().
01210 { 01211 Array<double> numerators; 01212 double denominator = 0; 01213 01214 int blockSize = state_->getDomain()->getBlockSize(blockIndex); 01215 for (int i = 0; i < blockSize; i++) 01216 { 01217 const Predicate* pred = 01218 state_->getDomain()->getPredInBlock(i, blockIndex); 01219 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred); 01220 int idx = state_->getIndexOfGroundPredicate(gndPred); 01221 01222 delete gndPred; 01223 delete pred; 01224 01225 // Prob is 0 if atom not in state 01226 double prob = 0.0; 01227 // Pred is in the state; otherwise, prob is zero 01228 if (idx >= 0) 01229 prob = getProbabilityOfPred(idx, chainIdx, invTemp); 01230 01231 numerators.append(prob); 01232 denominator += prob; 01233 } 01234 01235 double r = random(); 01236 double numSum = 0.0; 01237 for (int i = 0; i < blockSize; i++) 01238 { 01239 numSum += numerators[i]; 01240 if (r < ((numSum / denominator) * RAND_MAX)) 01241 { 01242 return i; 01243 } 01244 } 01245 return blockSize - 1; 01246 }
void MCMC::gndPredFlippedUpdates | ( | const int & | gndPredIdx, | |
const int & | chainIdx, | |||
GroundPredicateHashArray & | affectedGndPreds, | |||
Array< int > & | affectedGndPredIndices | |||
) | [inline, protected] |
Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate.
gndPredIdx | Index of ground pred which was flipped. | |
chainIdx | Index of chain in which the flipping occurred. | |
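affectedGndPreds | Holds the Markov blanket of the ground predicate, including the flipped predicate itself. | |
affectedGndPredIndices | Receives the index of the flipped predicate and of each ground predicate in the clauses that contain it. |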
Definition at line 1256 of file mcmc.h.
References Array< Type >::append(), HashArray< Type, HashFn, EqualFn >::append(), VariableState::decrementNumTrueLits(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getGndPredHashArrayPtr(), GroundClause::getGroundPredicate(), GroundClause::getGroundPredicateIndex(), VariableState::getNegOccurenceArray(), VariableState::getNumAtoms(), GroundClause::getNumGroundPredicates(), VariableState::getPosOccurenceArray(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::infer(), and performGibbsStep().
01259 { 01260 if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl; 01261 int numAtoms = state_->getNumAtoms(); 01262 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx); 01263 affectedGndPreds.append(gndPred, numAtoms); 01264 affectedGndPredIndices.append(gndPredIdx); 01265 assert(affectedGndPreds.size() <= numAtoms); 01266 01267 Array<int>& negGndClauses = 01268 state_->getNegOccurenceArray(gndPredIdx + 1); 01269 Array<int>& posGndClauses = 01270 state_->getPosOccurenceArray(gndPredIdx + 1); 01271 int gndClauseIdx; 01272 GroundClause* gndClause; 01273 bool sense; 01274 01275 // Find the Markov blanket of this ground predicate 01276 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++) 01277 { 01278 if (i < negGndClauses.size()) 01279 { 01280 gndClauseIdx = negGndClauses[i]; 01281 sense = false; 01282 } 01283 else 01284 { 01285 gndClauseIdx = posGndClauses[i - negGndClauses.size()]; 01286 sense = true; 01287 } 01288 gndClause = state_->getGndClause(gndClauseIdx); 01289 01290 // Different for multi-chain 01291 if (numChains_ > 1) 01292 { 01293 if (truthValues_[gndPredIdx][chainIdx] == sense) 01294 numTrueLits_[gndClauseIdx][chainIdx]++; 01295 else 01296 numTrueLits_[gndClauseIdx][chainIdx]--; 01297 } 01298 else 01299 { // Single chain 01300 if (gndPred->getTruthValue() == sense) 01301 state_->incrementNumTrueLits(gndClauseIdx); 01302 else 01303 state_->decrementNumTrueLits(gndClauseIdx); 01304 } 01305 01306 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++) 01307 { 01308 const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr(); 01309 GroundPredicate* pred = 01310 (GroundPredicate*)gndClause->getGroundPredicate(j, 01311 (GroundPredicateHashArray*)gpha); 01312 affectedGndPreds.append(pred, numAtoms); 01313 affectedGndPredIndices.append( 01314 abs(gndClause->getGroundPredicateIndex(j)) - 1); 01315 assert(affectedGndPreds.size() <= numAtoms); 01316 } 01317 } 01318 if (mcmcdebug) cout << "Leaving 
gndPredFlippedUpdates" << endl; 01319 }
void MCMC::saveLowStateToChain | ( | const int & | chainIdx | ) | [inline, protected] |
Saves the atom assignment of the best (low) state to the truth values of the given chain.
chainIdx | Index of chain to which the atom assignment is saved. |
Definition at line 1335 of file mcmc.h.
References VariableState::getNumAtoms(), VariableState::getValueOfLowAtom(), Inference::state_, and truthValues_.
Referenced by SimulatedTempering::init(), and GibbsSampler::init().
01336 { 01337 for (int i = 0; i < state_->getNumAtoms(); i++) 01338 truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1); 01339 }
void MCMC::setMCMCParameters | ( | MCMCParams * | params | ) | [inline, protected] |
Sets the user-set parameters for this MCMC algorithm.
params | MCMC parameters for this algorithm. |
Definition at line 1346 of file mcmc.h.
References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.
01347 { 01348 // User-set parameters 01349 numChains_ = params->numChains; 01350 burnMinSteps_ = params->burnMinSteps; 01351 burnMaxSteps_ = params->burnMaxSteps; 01352 minSteps_ = params->minSteps; 01353 maxSteps_ = params->maxSteps; 01354 maxSeconds_ = params->maxSeconds; 01355 }
void MCMC::scaleSamples | ( | double | factor | ) | [inline, protected, virtual] |
Increases or decreases the number of MCMC samples by a multiplicative factor.
(For MaxWalkSAT, this could instead change the number of flips, or another setting that trades off speed against accuracy.)
Reimplemented from Inference.
Definition at line 1357 of file mcmc.h.
References maxSteps_, and minSteps_.
01358 { 01359 minSteps_ = (int)(minSteps_ * factor); 01360 maxSteps_ = (int)(maxSteps_ * factor); 01361 }