MCMC Class Reference

Superclass of all MCMC inference algorithms.

#include <mcmc.h>

Inherits Inference.

Inherited by GibbsSampler, HMCSAT, MCSAT, and SimulatedTempering.

Public Member Functions

 MCMC (VariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params, Array< Array< Predicate * > * > *queryFormulas=NULL)
 Constructor.
 MCMC (HVariableState *state, long int seed, const bool &trackClauseTrueCnts, MCMCParams *params)
 ~MCMC ()
 Destructor.
double computeHybridClauseValue (int clauseIdx, int c)
double HybridClauseContPartValue (int contClauseIdx, int c)
double HybridClauseDisPartValue (int contClauseIdx, int c)
void updateDisPredValue (int predIdx, int chainIdx, bool updateValue)
void updateProposalContValue (int contPredIdx, int chainIdx)
double getProbabilityOfPredH (const int &predIdx, const int &chainIdx, const double &invTemp)
virtual void printNetwork (ostream &out)
 Prints out the network.
void printProbabilities (ostream &out)
 Prints the probabilities of each predicate to a stream.
void getChangedPreds (vector< string > &changedPreds, vector< float > &probs, vector< float > &oldProbs, const float &probDelta)
 Collects the predicates whose probability differs from the reference vector oldProbs by more than probDelta, returning their string forms in changedPreds and the corresponding probabilities in probs.
double getProbabilityH (GroundPredicate *const &gndPred)
double getProbability (GroundPredicate *const &gndPred)
 Gets the probability of a ground predicate.
void printTruePreds (ostream &out)
 Prints each predicate with a probability of 0.5 or greater to a stream.
void printTruePredsH (ostream &out)

Public Attributes

Array< Array< double > > truthValuesCont_
Array< Array< bool > > hybridClauseDisPartValueMCMC_
Array< Array< double > > hybridClauseContPartValueMCMC_

Protected Member Functions

void initTruthValuesAndWts (const int &numChains)
 Initializes truth values and weights in the ground preds.
void initNumTrue ()
 Initializes structure for holding number of times a predicate was set to true.
void initNumTrueLits (const int &numChains)
 Initializes the number of true lits in each clause in each chain.
void randomInitGndPredsTruthValues (const int &numChains)
 Randomly initializes the ground predicate truth values, taking blocks into account.
bool genTruthValueForProb (const double &p)
 Generates a truth value based on a probability.
double getProbabilityOfPred (const int &predIdx, const int &chainIdx, const double &invTemp)
 Computes the probability of a ground predicate in a chain.
void setOthersInBlockToFalse (const int &chainIdx, const int &atomIdx, const int &blockIdx)
 Sets the truth values of all atoms in a block to false for a given chain, except for the one given.
void performGibbsStep (const int &chainIdx, const bool &burningIn, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices)
 Performs one step of Gibbs sampling in one chain.
void updateWtsForGndPredsH (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx)
 Updates the weights of affected ground predicates.
void updateWtsForGndPreds (GroundPredicateHashArray &gndPreds, Array< int > &gndPredIndices, const int &chainIdx)
 Updates the weights of affected ground predicates.
int gibbsSampleFromBlock (const int &chainIdx, const int &blockIndex, const double &invTemp)
 Chooses an atom from a block according to the atoms' probabilities.
void gndPredFlippedUpdates (const int &gndPredIdx, const int &chainIdx, GroundPredicateHashArray &affectedGndPreds, Array< int > &affectedGndPredIndices)
 Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate.
double getProbTrue (const int &predIdx) const
void setProbTrue (const int &predIdx, const double &p)
void saveLowStateToChain (const int &chainIdx)
 Saves the atom assignment of the best state to a chain in the ground predicates.
void setMCMCParameters (MCMCParams *params)
 Sets the user-set parameters for this MCMC algorithm.
void scaleSamples (double factor)
 Increases or decreases the number of MCMC samples by a multiplicative factor.

Protected Attributes

int numChains_
int burnMinSteps_
int burnMaxSteps_
int minSteps_
int maxSteps_
int maxSeconds_
Array< Array< bool > > truthValues_
Array< Array< double > > wtsWhenFalse_
Array< Array< double > > wtsWhenTrue_
Array< double > numTrue_
Array< Array< int > > numTrueLits_

Detailed Description

Superclass of all MCMC inference algorithms.

This class does not implement all pure virtual functions of Inference and is thus an abstract class.

Definition at line 80 of file mcmc.h.


Constructor & Destructor Documentation

MCMC::MCMC ( VariableState *  state,
long int  seed,
const bool &  trackClauseTrueCnts,
MCMCParams *  params,
Array< Array< Predicate * > * > *  queryFormulas = NULL 
) [inline]

Constructor.

Copies the user-set parameters from params into the corresponding member variables.

See also:
Inference::Constructor(VariableState*, long int, const bool&, GibbsParams*)

Definition at line 90 of file mcmc.h.

References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.

00092     : Inference(state, seed, trackClauseTrueCnts, queryFormulas)
00093   {
00094       // User-set parameters
00095     numChains_ = params->numChains;
00096     burnMinSteps_ = params->burnMinSteps;
00097     burnMaxSteps_ = params->burnMaxSteps;
00098     minSteps_ = params->minSteps;
00099     maxSteps_ = params->maxSteps;
00100     maxSeconds_ = params->maxSeconds;
00101   }


Member Function Documentation

void MCMC::getChangedPreds ( vector< string > &  changedPreds,
vector< float > &  probs,
vector< float > &  oldProbs,
const float &  probDelta 
) [inline, virtual]

Collects the predicates whose probability differs from the reference vector oldProbs by more than probDelta. Their string forms are put in changedPreds and their smoothed probabilities in probs.

Parameters:
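changedPreds Predicates whose probability has changed by more than probDelta are put here, in string form.
probs The smoothed probabilities corresponding to the predicates in changedPreds are put here.
oldProbs Reference probabilities for checking for changes. Resized to the current number of atoms (new entries start at 0) and updated in place with the new, unsmoothed probability of each changed predicate.
probDelta If the probability of an atom has changed by more than this value, the atom is considered to have changed.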

Implements Inference.

Definition at line 468 of file mcmc.h.

References VariableState::getNumAtoms(), getProbTrue(), VariableState::printGndPred(), and Inference::state_.

00470   {
00471     changedPreds.clear();
00472     probs.clear();
00473     int numAtoms = state_->getNumAtoms();
00474       // Atoms may have been added to the state, previous prob. was 0
00475     oldProbs.resize(numAtoms, 0);
00476     for (int i = 0; i < numAtoms; i++)
00477     {
00478       double prob = getProbTrue(i);
00479       if (fabs(prob - oldProbs[i]) > probDelta)
00480       {
00481           // Probability has changed: Store new value (not smoothed) in oldProbs
00482           // and add to two return vectors
00483         oldProbs[i] = prob;
00484           // Uniform smoothing
00485         prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00486         ostringstream oss(ostringstream::out);
00487         state_->printGndPred(i, oss);
00488         changedPreds.push_back(oss.str());
00489         probs.push_back(prob);
00490       }
00491     }    
00492   }
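
The change-detection loop above can be sketched outside the class as follows. This is only an illustration under stated assumptions: collectChanged, current, and names are hypothetical stand-ins (current plays the role of getProbTrue(i), names the role of printGndPred(i, ...)), and std::vector replaces the library's own containers.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the loop in getChangedPreds(); not part of mcmc.h.
inline void collectChanged(const std::vector<double>& current,
                           const std::vector<std::string>& names,
                           std::vector<float>& oldProbs, float probDelta,
                           std::vector<std::string>& changedPreds,
                           std::vector<float>& probs)
{
  assert(names.size() == current.size());
  changedPreds.clear();
  probs.clear();
  // Atoms added since the last call are treated as having had probability 0.
  oldProbs.resize(current.size(), 0.0f);
  for (std::size_t i = 0; i < current.size(); ++i)
  {
    double prob = current[i];
    if (std::fabs(prob - oldProbs[i]) > probDelta)
    {
      // The reference vector keeps the raw value ...
      oldProbs[i] = static_cast<float>(prob);
      // ... while the reported value is smoothed, as in the listing.
      prob = (prob * 10000 + 1 / 2.0) / (10000 + 1.0);
      changedPreds.push_back(names[i]);
      probs.push_back(static_cast<float>(prob));
    }
  }
}
```

Because oldProbs is updated in place, calling the function twice with the same current values reports nothing the second time.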

double MCMC::getProbability ( GroundPredicate *const &  gndPred  )  [inline, virtual]

Gets the probability of a ground predicate.

Parameters:
gndPred GroundPredicate whose probability is being retrieved.
Returns:
Smoothed probability of gndPred if present in state. Otherwise the smoothed value of probability 0, which is 0.5/10001 (about 0.00005).

Implements Inference.

Definition at line 510 of file mcmc.h.

References VariableState::getGndPredIndex(), getProbTrue(), and Inference::state_.

00511   {
00512     int idx = state_->getGndPredIndex(gndPred);
00513     double prob = 0.0;
00514     if (idx >= 0) prob = getProbTrue(idx);
00515       // Uniform smoothing
00516     return (prob*10000 + 1/2.0)/(10000 + 1.0);
00517   }
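
The uniform smoothing in the return statement can be read as a weighted average of 10000 parts of the estimated probability and one part of 0.5, which keeps the result strictly between 0 and 1. A minimal sketch, where smoothProbability is a hypothetical helper name that does not appear in mcmc.h:

```cpp
#include <cassert>

// Hypothetical helper, not part of mcmc.h: the same expression as in the
// listing, pulling prob slightly towards 0.5.
inline double smoothProbability(double prob)
{
  assert(prob >= 0.0 && prob <= 1.0);
  return (prob * 10000 + 1 / 2.0) / (10000 + 1.0);
}
```

With this expression a raw probability of 0 maps to 0.5/10001 and a raw probability of 1 maps to 10000.5/10001, so a predicate missing from the state is never reported with probability exactly 0.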

void MCMC::initTruthValuesAndWts ( const int &  numChains  )  [inline, protected]

Initializes truth values and weights in the ground preds.

Parameters:
numChains Number of chains for which the values should be initialized.

Definition at line 557 of file mcmc.h.

References VariableState::getNumAtoms(), VariableState::getNumClauses(), Array< Type >::growToSize(), numTrueLits_, Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.

Referenced by SimulatedTempering::init(), and GibbsSampler::init().

00558   {
00559     int numPreds = state_->getNumAtoms();
00560     truthValues_.growToSize(numPreds);
00561     wtsWhenFalse_.growToSize(numPreds);
00562     wtsWhenTrue_.growToSize(numPreds);
00563     for (int i = 0; i < numPreds; i++)
00564     {
00565       truthValues_[i].growToSize(numChains, false);
00566       wtsWhenFalse_[i].growToSize(numChains, 0);
00567       wtsWhenTrue_[i].growToSize(numChains, 0);
00568     }
00569     
00570     int numClauses = state_->getNumClauses();
00571     numTrueLits_.growToSize(numClauses);
00572     for (int i = 0; i < numClauses; i++)
00573     {
00574       numTrueLits_[i].growToSize(numChains, 0);
00575     }
00576   }

void MCMC::initNumTrueLits ( const int &  numChains  )  [inline, protected]

Initializes the number of true lits in each clause in each chain.

Parameters:
numChains Number of chains for which the initialization should take place.

Definition at line 596 of file mcmc.h.

References VariableState::getAtomInClause(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getNumAtoms(), VariableState::getNumClauses(), VariableState::getNumTrueLits(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numTrueLits_, VariableState::resetMakeBreakCostWatch(), Inference::state_, and truthValues_.

Referenced by SimulatedTempering::init(), and GibbsSampler::init().

00597   {
00598       // Single chain
00599     if (numChains == 1) state_->resetMakeBreakCostWatch();
00600     for (int i = 0; i < state_->getNumClauses(); i++)
00601     {
00602       GroundClause* gndClause = state_->getGndClause(i);
00603       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00604       {
00605         const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00606         const bool sense = gndClause->getGroundPredicateSense(j);
00607         if (numChains > 1)
00608         {
00609           for (int c = 0; c < numChains; c++)
00610           {
00611             if (truthValues_[atomIdx][c] == sense)
00612             {
00613               numTrueLits_[i][c]++;
00614               assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00615             }
00616           }
00617         }
00618         else
00619         { // Single chain
00620           GroundPredicate* gndPred = state_->getGndPred(atomIdx);
00621           if (gndPred->getTruthValue() == sense)
00622             state_->incrementNumTrueLits(i);
00623           assert(state_->getNumTrueLits(i) <= state_->getNumAtoms());
00624         }        
00625       }
00626     }
00627   }
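
The multi-chain branch of the loop above can be sketched with standard containers. The encoding is an assumption made for illustration: each literal is a non-zero int whose magnitude is the 1-based atom index (matching the abs(...) - 1 in the listing) and whose sign is taken as the sense, whereas the listing reads the sense from the GroundClause. Clause and countTrueLits are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <vector>

// Assumed encoding (hypothetical): magnitude = 1-based atom index,
// sign = sense of the literal in the clause.
typedef std::vector<int> Clause;

// Returns counts[i][c]: number of true literals of clause i in chain c.
// truthValues[a][c] is the truth value of atom a in chain c.
inline std::vector<std::vector<int> > countTrueLits(
    const std::vector<Clause>& clauses,
    const std::vector<std::vector<bool> >& truthValues, int numChains)
{
  std::vector<std::vector<int> > counts(clauses.size(),
                                        std::vector<int>(numChains, 0));
  for (std::size_t i = 0; i < clauses.size(); ++i)
  {
    for (std::size_t j = 0; j < clauses[i].size(); ++j)
    {
      const int lit = clauses[i][j];
      assert(lit != 0);
      const int atomIdx = std::abs(lit) - 1;
      const bool sense = lit > 0;
      for (int c = 0; c < numChains; ++c)
      {
        // A literal is true when the atom's value equals its sense.
        if (truthValues[atomIdx][c] == sense) ++counts[i][c];
      }
    }
  }
  return counts;
}
```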

void MCMC::randomInitGndPredsTruthValues ( const int &  numChains  )  [inline, protected]

Randomly initializes the ground predicate truth values, taking blocks into account.

Parameters:
numChains Number of chains for which the initialization should take place.

Definition at line 636 of file mcmc.h.

References genTruthValueForProb(), Domain::getBlockEvidence(), VariableState::getBlockIndex(), VariableState::getDomain(), VariableState::getGndPred(), VariableState::getIndexOfGroundPredicate(), Domain::getNumPredBlocks(), Domain::getRandomPredInBlock(), numChains_, setOthersInBlockToFalse(), GroundPredicate::setTruthValue(), Array< Type >::size(), Inference::state_, and truthValues_.

Referenced by GibbsSampler::init().

00637   {
00638     for (int c = 0; c < numChains; c++)
00639     {
00640       if (mcmcdebug) cout << "Chain " << c << ":" << endl;
00641         // For each block: select one to set to true
00642       for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00643       {
00644           // If evidence atom exists, then all others are false
00645         if (state_->getDomain()->getBlockEvidence(i))
00646         {
00647             // If 2nd argument is -1, then all are set to false
00648           setOthersInBlockToFalse(c, -1, i);
00649           continue;
00650         }
00651 
00652         bool ok = false;
00653         while (!ok)
00654         {
00655           const Predicate* pred = state_->getDomain()->getRandomPredInBlock(i);
00656           GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00657           int idx = state_->getIndexOfGroundPredicate(gndPred);
00658             
00659           delete gndPred;
00660           delete pred;
00661 
00662           if (idx >= 0)
00663           {
00664               // Truth values are stored differently for multi-chain
00665             if (numChains_ > 1)
00666               truthValues_[idx][c] = true;
00667             else
00668             {
00669               GroundPredicate* gndPred = state_->getGndPred(idx);
00670               gndPred->setTruthValue(true);
00671             }
00672             setOthersInBlockToFalse(c, idx, i);
00673             ok = true;
00674           }
00675         }
00676       }
00677       
00678         // Random tv for all not in blocks
00679       for (int i = 0; i < truthValues_.size(); i++)
00680       {
00681           // Predicates in blocks have been handled above
00682         if (state_->getBlockIndex(i) == -1)
00683         {
00684           bool tv = genTruthValueForProb(0.5);
00685             // Truth values are stored differently for multi-chain
00686           if (numChains_ > 1)
00687             truthValues_[i][c] = tv;
00688           else
00689           {
00690             GroundPredicate* gndPred = state_->getGndPred(i);
00691             gndPred->setTruthValue(tv);
00692           }
00693           if (mcmcdebug) cout << "Pred " << i << " set to " << tv << endl;
00694         }
00695       }
00696     }
00697   }

bool MCMC::genTruthValueForProb ( const double &  p  )  [inline, protected]

Generates a truth value based on a probability.

Parameters:
p Number between 0 and 1. With probability p, truth value will be true and with probability 1 - p, it will be false.

Definition at line 705 of file mcmc.h.

Referenced by SimulatedTempering::infer(), performGibbsStep(), and randomInitGndPredsTruthValues().

00706   {
00707     if (p == 1.0) return true;
00708     if (p == 0.0) return false;
00709     bool r = random() <= p*RAND_MAX;
00710     return r;
00711   }
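
The listing draws from random() and compares against p*RAND_MAX. random() is a POSIX function returning values up to 2^31 - 1, which coincides with RAND_MAX on glibc but is not guaranteed to elsewhere. A portable sketch of the same Bernoulli draw using the standard <random> header follows; genTruthValue and the explicit generator argument are illustrative choices, not the class's interface.

```cpp
#include <cassert>
#include <random>

// Sketch only: returns true with probability p, false with probability 1 - p.
inline bool genTruthValue(double p, std::mt19937& rng)
{
  assert(p >= 0.0 && p <= 1.0);
  // Same shortcuts as the listing for the deterministic cases.
  if (p == 1.0) return true;
  if (p == 0.0) return false;
  std::uniform_real_distribution<double> unif(0.0, 1.0);  // samples [0, 1)
  return unif(rng) < p;
}
```

Passing the generator explicitly makes runs reproducible from a seed, which is the role the seed argument plays in the MCMC constructor.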

double MCMC::getProbabilityOfPred ( const int &  predIdx,
const int &  chainIdx,
const double &  invTemp 
) [inline, protected]

Computes the probability of a ground predicate in a chain.

Parameters:
predIdx Index of predicate.
chainIdx Index of chain.
invTemp Inverse temperature used in simulated tempering. Gibbs sampling passes 1.
Returns:
Probability of predicate.

Definition at line 722 of file mcmc.h.

References VariableState::getGndPred(), GroundPredicate::getWtWhenFalse(), GroundPredicate::getWtWhenTrue(), numChains_, Inference::state_, wtsWhenFalse_, and wtsWhenTrue_.

Referenced by gibbsSampleFromBlock(), SimulatedTempering::infer(), and performGibbsStep().

00724   {
00725       // Different for multi-chain
00726     if (numChains_ > 1)
00727     {
00728       return 1.0 /
00729              ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 
00730                           wtsWhenTrue_[predIdx][chainIdx]) *
00731                           invTemp));      
00732     }
00733     else
00734     {
00735       GroundPredicate* gndPred = state_->getGndPred(predIdx);
00736       return 1.0 /
00737              ( 1.0 + exp((gndPred->getWtWhenFalse() - 
00738                           gndPred->getWtWhenTrue()) *
00739                           invTemp));
00740     }
00741   }
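
Both branches evaluate the same expression: a logistic function of the weight difference scaled by the inverse temperature. A standalone sketch, where probTrue is a hypothetical name and the weights are passed directly instead of being read from the chain arrays or the GroundPredicate:

```cpp
#include <cassert>
#include <cmath>

// P(true) = 1 / (1 + exp((wtWhenFalse - wtWhenTrue) * invTemp)),
// the expression used in both branches of the listing.
inline double probTrue(double wtWhenFalse, double wtWhenTrue, double invTemp)
{
  assert(invTemp >= 0.0);
  return 1.0 / (1.0 + std::exp((wtWhenFalse - wtWhenTrue) * invTemp));
}
```

Equal weights give 0.5, a larger wtWhenTrue pushes the result towards 1, and an invTemp below 1 flattens the distribution towards 0.5.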

void MCMC::setOthersInBlockToFalse ( const int &  chainIdx,
const int &  atomIdx,
const int &  blockIdx 
) [inline, protected]

Sets the truth values of all atoms in a block to false for a given chain, except for the one given.

Parameters:
chainIdx Index of chain for which atoms are being set.
atomIdx Index of atom in block exempt from being set to false. If -1, all atoms in the block are set to false.
blockIdx Index of block whose atoms are set to false.

Definition at line 751 of file mcmc.h.

References Domain::getBlockSize(), VariableState::getDomain(), VariableState::getIndexOfGroundPredicate(), Domain::getPredInBlock(), Inference::state_, and truthValues_.

Referenced by randomInitGndPredsTruthValues().

00753   {
00754     int blockSize = state_->getDomain()->getBlockSize(blockIdx);
00755     for (int i = 0; i < blockSize; i++)
00756     {
00757       const Predicate* pred = state_->getDomain()->getPredInBlock(i, blockIdx);
00758       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00759       int idx = state_->getIndexOfGroundPredicate(gndPred);
00760 
00761       delete gndPred;
00762       delete pred;
00763 
00764         // Pred is in the state
00765       if (idx >= 0 && idx != atomIdx)
00766         truthValues_[idx][chainIdx] = false;
00767     }
00768   }
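
The effect of the loop can be sketched on plain vectors. The layout here is an assumption for illustration: blockAtoms holds the state indices of the atoms in one block, with a negative entry standing for an atom that is not in the state, whereas the listing resolves each index through Domain and VariableState lookups. setOthersToFalse is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// truthValues[a][c] is the truth value of atom a in chain c.
inline void setOthersToFalse(std::vector<std::vector<bool> >& truthValues,
                             const std::vector<int>& blockAtoms,
                             int chainIdx, int keepAtomIdx)
{
  assert(chainIdx >= 0);
  for (std::size_t i = 0; i < blockAtoms.size(); ++i)
  {
    const int idx = blockAtoms[i];
    // Skip atoms outside the state and the one exempt atom.
    if (idx >= 0 && idx != keepAtomIdx)
      truthValues[idx][chainIdx] = false;
  }
}
```

Passing -1 as keepAtomIdx matches no valid index, so every atom of the block in the state is set to false, which is how the evidence case in randomInitGndPredsTruthValues() uses the function.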

void MCMC::performGibbsStep ( const int &  chainIdx,
const bool &  burningIn,
GroundPredicateHashArray &  affectedGndPreds,
Array< int > &  affectedGndPredIndices 
) [inline, protected]

Performs one step of Gibbs sampling in one chain.

Parameters:
chainIdx Index of chain in which the Gibbs step is performed.
burningIn True while burn-in is in progress, false during actual sampling. The counts in numTrue_ are only updated when this is false.
affectedGndPreds Used to store GroundPredicates which are affected by changing truth values.
affectedGndPredIndices Used to store indices of GroundPredicates which are affected by changing truth values.

Definition at line 780 of file mcmc.h.

References Inference::clauseTrueCnts_, Array< Type >::clear(), HashArray< Type, HashFn, EqualFn >::clear(), genTruthValueForProb(), Domain::getBlockEvidence(), VariableState::getBlockIndex(), Domain::getBlockSize(), VariableState::getDomain(), VariableState::getGndPred(), VariableState::getIndexOfGroundPredicate(), VariableState::getNumAtoms(), VariableState::getNumClauseGndings(), Domain::getNumPredBlocks(), Domain::getPredInBlock(), getProbabilityOfPred(), GroundPredicate::getTruthValue(), gibbsSampleFromBlock(), gndPredFlippedUpdates(), numChains_, numTrue_, GroundPredicate::setTruthValue(), Inference::state_, Inference::trackClauseTrueCnts_, truthValues_, and updateWtsForGndPreds().

Referenced by GibbsSampler::infer().

00783   {
00784     if (mcmcdebug) cout << "Gibbs step" << endl;
00785 
00786       // For each block: select one to set to true
00787     for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00788     {
00789         // If evidence atom exists, then all others stay false
00790       if (state_->getDomain()->getBlockEvidence(i)) continue;
00791 
00792       int chosen = gibbsSampleFromBlock(chainIdx, i, 1);
00793         // Truth values are stored differently for multi-chain
00794       bool truthValue;
00795       const Predicate* pred = state_->getDomain()->getPredInBlock(chosen, i);
00796       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00797       int idx = state_->getIndexOfGroundPredicate(gndPred);
00798 
00799       delete gndPred;
00800       delete pred;
00801       
00802         // If gnd pred in state:
00803       if (idx >= 0)
00804       {
00805         gndPred = state_->getGndPred(idx);
00806         if (numChains_ > 1) truthValue = truthValues_[idx][chainIdx];
00807         else truthValue = gndPred->getTruthValue();
00808           // If chosen pred was false, then need to set previous true
00809           // one to false and update wts
00810         if (!truthValue)
00811         {
00812           int blockSize = state_->getDomain()->getBlockSize(i);
00813           for (int j = 0; j < blockSize; j++)
00814           {
00815               // Truth values are stored differently for multi-chain
00816             bool otherTruthValue;
00817             const Predicate* otherPred = 
00818               state_->getDomain()->getPredInBlock(j, i);
00819             GroundPredicate* otherGndPred =
00820               new GroundPredicate((Predicate*)otherPred);
00821             int otherIdx = state_->getIndexOfGroundPredicate(otherGndPred);
00822 
00823             delete otherGndPred;
00824             delete otherPred;
00825       
00826               // If gnd pred in state:
00827             if (otherIdx >= 0)
00828             {
00829               otherGndPred = state_->getGndPred(otherIdx);
00830               if (numChains_ > 1)
00831                 otherTruthValue = truthValues_[otherIdx][chainIdx];
00832               else
00833                 otherTruthValue = otherGndPred->getTruthValue();
00834               if (otherTruthValue)
00835               {
00836                   // Truth values are stored differently for multi-chain
00837                 if (numChains_ > 1)
00838                   truthValues_[otherIdx][chainIdx] = false;
00839                 else
00840                   otherGndPred->setTruthValue(false);
00841               
00842                 affectedGndPreds.clear();
00843                 affectedGndPredIndices.clear();
00844                 gndPredFlippedUpdates(otherIdx, chainIdx, affectedGndPreds,
00845                                       affectedGndPredIndices);
00846                 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00847                                      chainIdx);
00848               }
00849             }
00850           }
00851             // Set truth value and update wts for chosen atom
00852             // Truth values are stored differently for multi-chain
00853           if (numChains_ > 1) truthValues_[idx][chainIdx] = true;
00854           else gndPred->setTruthValue(true);
00855           affectedGndPreds.clear();
00856           affectedGndPredIndices.clear();
00857           gndPredFlippedUpdates(idx, chainIdx, affectedGndPreds,
00858                                 affectedGndPredIndices);
00859           updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00860                                chainIdx);
00861         }
00862           // If in actual gibbs sampling phase, track the num of times
00863           // the ground predicate is set to true
00864         if (!burningIn) numTrue_[idx]++;
00865       }
00866     }
00867 
00868       // Now go through all preds not in blocks
00869     for (int i = 0; i < state_->getNumAtoms(); i++)
00870     {
00871         // Predicates in blocks have been handled above
00872       if (state_->getBlockIndex(i) >= 0) continue;
00873 
00874       if (mcmcdebug)
00875       {
00876         cout << "Chain " << chainIdx << ": Probability of pred "
00877              << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00878       }
00879       
00880       bool newAssignment
00881         = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00882 
00883         // Truth values are stored differently for multi-chain
00884       bool truthValue;
00885       GroundPredicate* gndPred = state_->getGndPred(i);
00886       if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00887       else truthValue = gndPred->getTruthValue();
00888         // If gndPred is flipped, do updates & find all affected gndPreds
00889       if (newAssignment != truthValue)
00890       {
00891         if (mcmcdebug)
00892         {
00893           cout << "Chain " << chainIdx << ": Changing truth value of pred "
00894                << i << " to " << newAssignment << endl;
00895         }
00896         
00897         if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00898         else gndPred->setTruthValue(newAssignment);
00899         affectedGndPreds.clear();
00900         affectedGndPredIndices.clear();
00901         gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00902                               affectedGndPredIndices);
00903         updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00904                              chainIdx);
00905       }
00906 
00907         // If in actual gibbs sampling phase, track the num of times
00908         // the ground predicate is set to true
00909       if (!burningIn && newAssignment) numTrue_[i]++;
00910     }
00911       // If keeping track of true clause groundings
00912     if (!burningIn && trackClauseTrueCnts_)
00913       state_->getNumClauseGndings(clauseTrueCnts_, true);
00914 
00915     if (mcmcdebug) cout << "End of Gibbs step" << endl;
00916   }
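
The second half of the step, the pass over atoms that are not in blocks, can be illustrated with a toy model. This sketch makes strong simplifying assumptions: a single chain, an inverse temperature of 1, and fixed per-atom weights, so the Markov blanket updates done by gndPredFlippedUpdates() and updateWtsForGndPreds() are left out. ToyAtom and toyGibbsSweep are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical atom with fixed weights; in mcmc.h the weights change
// whenever a neighbouring atom is flipped.
struct ToyAtom
{
  double wtWhenTrue;
  double wtWhenFalse;
  bool value;
};

// One sweep over all atoms. numTrue[i] counts how often atom i was sampled
// true outside of burn-in, mirroring numTrue_ in the listing.
inline void toyGibbsSweep(std::vector<ToyAtom>& atoms, bool burningIn,
                          std::vector<double>& numTrue, std::mt19937& rng)
{
  assert(numTrue.size() == atoms.size());
  std::uniform_real_distribution<double> unif(0.0, 1.0);
  for (std::size_t i = 0; i < atoms.size(); ++i)
  {
    const double p =
        1.0 / (1.0 + std::exp(atoms[i].wtWhenFalse - atoms[i].wtWhenTrue));
    const bool newValue = unif(rng) < p;
    // The real step would update affected weights here if the value flipped.
    atoms[i].value = newValue;
    if (!burningIn && newValue) numTrue[i] += 1.0;
  }
}
```

Dividing numTrue[i] by the number of non-burn-in sweeps gives the kind of estimate that getProbTrue() is expected to report, although how mcmc.h normalizes the counts is not shown in this section.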

void MCMC::updateWtsForGndPredsH ( GroundPredicateHashArray &  gndPreds,
Array< int > &  gndPredIndices,
const int &  chainIdx 
) [inline, protected]

Updates the weights of affected ground predicates.

These are the ground predicates that share a clause with a predicate whose truth value has changed.

Parameters:
gndPreds Ground predicates whose weights should be updated.
gndPredIndices Indices in the state of the ground predicates in gndPreds, in the same order.
chainIdx Index of chain where updating occurs.

Definition at line 926 of file mcmc.h.

References HVariableState::getClauseCost(), HVariableState::getGndClause(), HVariableState::getNegOccurenceArray(), HVariableState::getNumTrueLits(), HVariableState::getPosOccurenceArray(), GroundClause::getWt(), Inference::hstate_, GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), truthValues_, wtsWhenFalse_, and wtsWhenTrue_.

Referenced by HMCSAT::infer().

00928   {
00929           if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00930           // for each ground predicate whose MB has changed
00931           for (int g = 0; g < gndPreds.size(); g++)
00932           {
00933                   double wtIfNoChange = 0, wtIfInverted = 0, wt;
00934                   // Ground clauses in which this pred occurs
00935                   Array<int>& negGndClauses =
00936                           hstate_->getNegOccurenceArray(gndPredIndices[g] + 1);
00937                   Array<int>& posGndClauses =
00938                           hstate_->getPosOccurenceArray(gndPredIndices[g] + 1);
00939 
00940                   int gndClauseIdx;
00941                   bool sense;
00942                   if (mcmcdebug)
00943                   {
00944                           cout << "Ground clauses in which pred " << g << " occurs neg.: "
00945                                   << negGndClauses.size() << endl;
00946                           cout << "Ground clauses in which pred " << g << " occurs pos.: "
00947                                   << posGndClauses.size() << endl;
00948                   }
00949 
00950                   for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00951                   {
00952                           if (i < negGndClauses.size())
00953                           {
00954                                   gndClauseIdx = negGndClauses[i];
00955                                   if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00956                                   sense = false;
00957                           }
00958                           else
00959                           {
00960                                   gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00961                                   if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00962                                   sense = true;
00963                           }
00964 
00965                           GroundClause* gndClause = hstate_->getGndClause(gndClauseIdx);
00966                           if (gndClause->isHardClause())
00967                                   wt = hstate_->getClauseCost(gndClauseIdx);
00968                           else
00969                                   wt = gndClause->getWt();
00970                           // NumTrueLits are stored differently for multi-chain
00971                           int numSatLiterals;
00972                           if (numChains_ > 1)
00973                                   numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00974                           else
00975                                   numSatLiterals = hstate_->getNumTrueLits(gndClauseIdx);
00976                           if (numSatLiterals > 1)
00977                           {
00978                                   // Some other literal is making it sat, so it doesn't matter
00979                                   // if pos. clause. If neg., nothing can be done to unsatisfy it.
00980                                   if (wt > 0)
00981                                   {
00982                                           wtIfNoChange += wt;
00983                                           wtIfInverted += wt;
00984                                   }
00985                           }
00986                           else 
00987                                   if (numSatLiterals == 1) 
00988                                   {
00989                                           if (wt > 0) wtIfNoChange += wt;
00990                                           // Truth values are stored differently for multi-chain
00991                                           bool truthValue;
00992                                           if (numChains_ > 1)
00993                                                   truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00994                                           else
00995                                                   truthValue = gndPreds[g]->getTruthValue();
00996                                           // If the current truth value is the same as its sense in gndClause
00997                                           if (truthValue == sense) 
00998                                           {
00999                                                   // This gndPred is the only one making this clause satisfied
01000                                                   if (wt < 0) wtIfInverted += fabs(wt);
01001                                           }
01002                                           else 
01003                                           {
01004                                                   // Some other literal is making it satisfied
01005                                                   if (wt > 0) wtIfInverted += wt;
01006                                           }
01007                                   }
01008                                   else
01009                                           if (numSatLiterals == 0) 
01010                                           {
01011                                                   // None satisfy, so when gndPred switch to its negative, it'll satisfy
01012                                                   if (wt > 0) wtIfInverted += wt;
01013                                                   else if (wt < 0) wtIfNoChange += fabs(wt);
01014                                           }
01015                   } // for each ground clause that gndPred appears in
01016 
01017                   if (mcmcdebug)
01018                   {
01019                           cout << "wtIfNoChange of pred " << g << ": "
01020                                   << wtIfNoChange << endl;
01021                           cout << "wtIfInverted of pred " << g << ": "
01022                                   << wtIfInverted << endl;
01023                   }
01024 
01025                   // Clause info is stored differently for multi-chain
01026                   if (numChains_ > 1)
01027                   {
01028                           if (truthValues_[gndPredIndices[g]][chainIdx]) 
01029                           {
01030                                   wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01031                                   wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01032                           }
01033                           else 
01034                           {
01035                                   wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01036                                   wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01037                           }
01038                   }
01039                   else
01040                   { // Single chain
01041                           if (gndPreds[g]->getTruthValue())
01042                           {
01043                                   gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01044                                   gndPreds[g]->setWtWhenFalse(wtIfInverted);
01045                           }
01046                           else
01047                           {
01048                                   gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01049                                   gndPreds[g]->setWtWhenTrue(wtIfInverted);            
01050                           }
01051                   }
01052           } // for each ground predicate whose MB has changed
01053           if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01054   }

void MCMC::updateWtsForGndPreds ( GroundPredicateHashArray &  gndPreds,
Array< int > &  gndPredIndices,
const int &  chainIdx 
) [inline, protected]

Updates the weights of affected ground predicates.

These are the ground predicates that share a clause with a predicate whose truth value has changed.

Parameters:
gndPreds Ground predicates whose weights should be updated.
gndPredIndices Indices of the ground predicates in gndPreds, in the same order.
chainIdx Index of the chain in which the update occurs.

Definition at line 1067 of file mcmc.h.

References VariableState::getClauseCost(), VariableState::getGndClause(), VariableState::getNegOccurenceArray(), VariableState::getNumTrueLits(), VariableState::getPosOccurenceArray(), GroundClause::getWt(), GroundClause::isHardClause(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, truthValues_, wtsWhenFalse_, and wtsWhenTrue_.

Referenced by SimulatedTempering::infer(), MCSAT::infer(), GibbsSampler::infer(), and performGibbsStep().

01070   {
01071     if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
01072       // for each ground predicate whose MB has changed
01073     for (int g = 0; g < gndPreds.size(); g++)
01074     {
01075       double wtIfNoChange = 0, wtIfInverted = 0, wt;
01076         // Ground clauses in which this pred occurs
01077       Array<int>& negGndClauses =
01078         state_->getNegOccurenceArray(gndPredIndices[g] + 1);
01079       Array<int>& posGndClauses =
01080         state_->getPosOccurenceArray(gndPredIndices[g] + 1);
01081       int gndClauseIdx;
01082       bool sense;
01083       
01084       if (mcmcdebug)
01085       {
01086         cout << "Ground clauses in which pred " << g << " occurs neg.: "
01087              << negGndClauses.size() << endl;
01088         cout << "Ground clauses in which pred " << g << " occurs pos.: "
01089              << posGndClauses.size() << endl;
01090       }
01091       
01092       for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01093       {
01094         if (i < negGndClauses.size())
01095         {
01096           gndClauseIdx = negGndClauses[i];
01097           if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
01098           sense = false;
01099         }
01100         else
01101         {
01102           gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01103           if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
01104           sense = true;
01105         }
01106         
01107         GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
01108         if (gndClause->isHardClause())
01109           wt = state_->getClauseCost(gndClauseIdx);
01110         else
01111           wt = gndClause->getWt();
01112           // NumTrueLits are stored differently for multi-chain
01113         int numSatLiterals;
01114         if (numChains_ > 1)
01115           numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
01116         else
01117           numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
01118 
01119         if (numSatLiterals > 1)
01120         {
01121             // Some other literal is making it sat, so it doesn't matter
01122             // if pos. clause. If neg., nothing can be done to unsatisfy it.
01123           if (wt > 0)
01124           {
01125             wtIfNoChange += wt;
01126             wtIfInverted += wt;
01127           }
01128         }
01129         else 
01130         if (numSatLiterals == 1) 
01131         {
01132           if (wt > 0) wtIfNoChange += wt;
01133             // Truth values are stored differently for multi-chain
01134           bool truthValue;
01135           if (numChains_ > 1)
01136             truthValue = truthValues_[gndPredIndices[g]][chainIdx];
01137           else
01138             truthValue = gndPreds[g]->getTruthValue();
01139             // If the current truth value is the same as its sense in gndClause
01140           if (truthValue == sense) 
01141           {
01142             // This gndPred is the only one making this clause satisfied
01143             if (wt < 0) wtIfInverted += abs(wt);
01144           }
01145           else 
01146           {
01147               // Some other literal is making it satisfied
01148             if (wt > 0) wtIfInverted += wt;
01149           }
01150         }
01151         else
01152         if (numSatLiterals == 0) 
01153         {
01154           // No literal satisfies the clause, so flipping gndPred will satisfy it
01155           if (wt > 0) wtIfInverted += wt;
01156           else if (wt < 0) wtIfNoChange += abs(wt);
01157         }
01158       } // for each ground clause that gndPred appears in
01159 
01160       if (mcmcdebug)
01161       {
01162         cout << "wtIfNoChange of pred " << g << ": "
01163              << wtIfNoChange << endl;
01164         cout << "wtIfInverted of pred " << g << ": "
01165              << wtIfInverted << endl;
01166       }
01167 
01168         // Clause info is stored differently for multi-chain
01169       if (numChains_ > 1)
01170       {
01171         if (truthValues_[gndPredIndices[g]][chainIdx]) 
01172         {
01173           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01174           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01175         }
01176         else 
01177         {
01178           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01179           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01180         }
01181       }
01182       else
01183       { // Single chain
01184         if (gndPreds[g]->getTruthValue())
01185         {
01186           gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01187           gndPreds[g]->setWtWhenFalse(wtIfInverted);
01188         }
01189         else
01190         {
01191           gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01192           gndPreds[g]->setWtWhenTrue(wtIfInverted);            
01193         }
01194       }
01195     } // for each ground predicate whose MB has changed
01196     if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01197   }
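
The per-clause case analysis in the listing above can be isolated into a small function. The following is a minimal sketch, not Alchemy code: WtPair and clauseContribution are hypothetical names, and a ground clause is reduced to its weight, its number of satisfied literals, and the predicate's current truth value and sense.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical result type (not part of Alchemy): the weight a single ground
// clause adds to each of the two totals kept per ground predicate.
struct WtPair
{
  double ifNoChange;  // added to wtIfNoChange
  double ifInverted;  // added to wtIfInverted
};

// Hypothetical helper mirroring the case analysis in updateWtsForGndPreds.
//   wt             clause weight (or clause cost for a hard clause)
//   numSatLiterals number of literals currently satisfying the clause
//   truthValue     current truth value of the ground predicate
//   sense          true if the predicate occurs positively in the clause
WtPair clauseContribution(double wt, int numSatLiterals,
                          bool truthValue, bool sense)
{
  assert(numSatLiterals >= 0);
  WtPair c = {0.0, 0.0};
  if (numSatLiterals > 1)
  {
    // At least one other literal keeps the clause satisfied either way
    if (wt > 0) { c.ifNoChange += wt; c.ifInverted += wt; }
  }
  else if (numSatLiterals == 1)
  {
    if (wt > 0) c.ifNoChange += wt;
    if (truthValue == sense)
    {
      // This predicate is the only satisfying literal; flipping it
      // unsatisfies the clause, which is rewarded for a negative weight
      if (wt < 0) c.ifInverted += std::fabs(wt);
    }
    else
    {
      // Another literal satisfies the clause, so it stays satisfied
      if (wt > 0) c.ifInverted += wt;
    }
  }
  else
  {
    // Unsatisfied now; flipping the predicate satisfies the clause
    if (wt > 0) c.ifInverted += wt;
    else if (wt < 0) c.ifNoChange += std::fabs(wt);
  }
  return c;
}
```

Under these assumptions, a clause of weight 1.5 with no satisfied literal adds 1.5 to wtIfInverted and nothing to wtIfNoChange. Summing the contributions over every clause in which the predicate occurs gives the two totals that the listing stores as the weight when true and the weight when false.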

int MCMC::gibbsSampleFromBlock ( const int &  chainIdx,
const int &  blockIndex,
const double &  invTemp 
) [inline, protected]

Chooses an atom from a block, with each atom selected in proportion to its probability.

Parameters:
chainIdx Index of chain.
blockIndex Index of the block from which one atom is chosen.
invTemp Inverse temperature used in simulated tempering.
Returns:
Index of chosen atom in the block.

Definition at line 1208 of file mcmc.h.

References Array< Type >::append(), Domain::getBlockSize(), VariableState::getDomain(), VariableState::getIndexOfGroundPredicate(), Domain::getPredInBlock(), getProbabilityOfPred(), and Inference::state_.

Referenced by SimulatedTempering::infer(), and performGibbsStep().

01210   {
01211     Array<double> numerators;
01212     double denominator = 0;
01213 
01214     int blockSize = state_->getDomain()->getBlockSize(blockIndex);
01215     for (int i = 0; i < blockSize; i++)
01216     {
01217       const Predicate* pred =
01218         state_->getDomain()->getPredInBlock(i, blockIndex);
01219       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
01220       int idx = state_->getIndexOfGroundPredicate(gndPred);
01221       
01222       delete gndPred;
01223       delete pred;
01224 
01225         // Prob is 0 if atom not in state
01226       double prob = 0.0;
01227         // Pred is in the state; otherwise, prob is zero
01228       if (idx >= 0)
01229         prob = getProbabilityOfPred(idx, chainIdx, invTemp);
01230 
01231       numerators.append(prob);
01232       denominator += prob;
01233     }
01234 
01235     double r = random();
01236     double numSum = 0.0;
01237     for (int i = 0; i < blockSize; i++)
01238     {
01239       numSum += numerators[i];
01240       if (r < ((numSum / denominator) * RAND_MAX))
01241       {
01242         return i;
01243       }
01244     }
01245     return blockSize - 1;
01246   }
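
The selection step in the listing above is a roulette-wheel draw over unnormalized probabilities. The following is a minimal sketch of that step only, not Alchemy code: sampleIndex is a hypothetical name, and the uniform draw u in [0, 1) is passed in by the caller instead of being derived from random() and RAND_MAX as in the listing.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Alchemy): picks an index with probability
// proportional to weights[i], given a uniform draw u in [0, 1).
int sampleIndex(const std::vector<double>& weights, double u)
{
  assert(!weights.empty());
  assert(u >= 0.0 && u < 1.0);

  double total = 0.0;
  for (std::size_t i = 0; i < weights.size(); i++) total += weights[i];

  int last = static_cast<int>(weights.size()) - 1;
  // Degenerate block: no atom has positive weight, fall back to the last one
  if (total <= 0.0) return last;

  double cumulative = 0.0;
  for (int i = 0; i <= last; i++)
  {
    cumulative += weights[i];
    if (u < cumulative / total) return i;
  }
  // Guards against floating-point round-off in the cumulative sum
  return last;
}
```

Taking u as a parameter keeps the sketch deterministic and lets the caller choose the generator, for example a std::uniform_real_distribution<double> over [0, 1) from the standard <random> header.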

void MCMC::gndPredFlippedUpdates ( const int &  gndPredIdx,
const int &  chainIdx,
GroundPredicateHashArray &  affectedGndPreds,
Array< int > &  affectedGndPredIndices 
) [inline, protected]

Updates information when a ground predicate is flipped and retrieves the Markov blanket of the ground predicate.

Parameters:
gndPredIdx Index of ground pred which was flipped.
chainIdx Index of chain in which the flipping occurred.
affectedGndPreds Holds the Markov blanket of the ground predicate.
affectedGndPredIndices Holds the indices of the ground predicates in the Markov blanket.

Definition at line 1256 of file mcmc.h.

References Array< Type >::append(), HashArray< Type, HashFn, EqualFn >::append(), VariableState::decrementNumTrueLits(), VariableState::getGndClause(), VariableState::getGndPred(), VariableState::getGndPredHashArrayPtr(), GroundClause::getGroundPredicate(), GroundClause::getGroundPredicateIndex(), VariableState::getNegOccurenceArray(), VariableState::getNumAtoms(), GroundClause::getNumGroundPredicates(), VariableState::getPosOccurenceArray(), GroundPredicate::getTruthValue(), VariableState::incrementNumTrueLits(), numChains_, numTrueLits_, Array< Type >::size(), HashArray< Type, HashFn, EqualFn >::size(), Inference::state_, and truthValues_.

Referenced by SimulatedTempering::infer(), and performGibbsStep().

01259   {
01260     if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
01261     int numAtoms = state_->getNumAtoms();
01262     GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
01263     affectedGndPreds.append(gndPred, numAtoms);
01264     affectedGndPredIndices.append(gndPredIdx);
01265     assert(affectedGndPreds.size() <= numAtoms);
01266 
01267     Array<int>& negGndClauses =
01268       state_->getNegOccurenceArray(gndPredIdx + 1);
01269     Array<int>& posGndClauses =
01270       state_->getPosOccurenceArray(gndPredIdx + 1);
01271     int gndClauseIdx;
01272     GroundClause* gndClause; 
01273     bool sense;
01274 
01275       // Find the Markov blanket of this ground predicate
01276     for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01277     {
01278       if (i < negGndClauses.size())
01279       {
01280         gndClauseIdx = negGndClauses[i];
01281         sense = false;
01282       }
01283       else
01284       {
01285         gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01286         sense = true;
01287       }
01288       gndClause = state_->getGndClause(gndClauseIdx);
01289 
01290         // Different for multi-chain
01291       if (numChains_ > 1)
01292       {
01293         if (truthValues_[gndPredIdx][chainIdx] == sense)
01294           numTrueLits_[gndClauseIdx][chainIdx]++;
01295         else
01296           numTrueLits_[gndClauseIdx][chainIdx]--;
01297       }
01298       else
01299       { // Single chain
01300         if (gndPred->getTruthValue() == sense)
01301           state_->incrementNumTrueLits(gndClauseIdx);
01302         else
01303           state_->decrementNumTrueLits(gndClauseIdx);
01304       }
01305       
01306       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01307       {
01308         const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
01309         GroundPredicate* pred = 
01310           (GroundPredicate*)gndClause->getGroundPredicate(j,
01311             (GroundPredicateHashArray*)gpha);
01312         affectedGndPreds.append(pred, numAtoms);
01313         affectedGndPredIndices.append(
01314                                abs(gndClause->getGroundPredicateIndex(j)) - 1);
01315         assert(affectedGndPreds.size() <= numAtoms);
01316       }
01317     }
01318     if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
01319   }
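
The bookkeeping in the listing above can be illustrated on a simplified state. The following is a minimal sketch, not Alchemy code: Clause and flippedUpdates are hypothetical names, clauses are plain lists of signed literals, and the sketch scans every clause instead of using the positive and negative occurrence arrays.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <set>
#include <vector>

// Hypothetical simplified state (not Alchemy's VariableState). A clause is a
// list of signed literals: +k means atom k-1 occurs positively, -k negatively.
typedef std::vector<int> Clause;

// Call after truth[atom] has been flipped. Updates numTrueLits for every
// clause containing the atom and returns the atom's Markov blanket
// (the atom itself plus every atom sharing a clause with it).
std::set<int> flippedUpdates(const std::vector<Clause>& clauses,
                             const std::vector<bool>& truth,
                             std::vector<int>& numTrueLits,
                             int atom)
{
  assert(clauses.size() == numTrueLits.size());
  std::set<int> blanket;
  blanket.insert(atom);
  for (std::size_t c = 0; c < clauses.size(); c++)
  {
    for (std::size_t j = 0; j < clauses[c].size(); j++)
    {
      int lit = clauses[c][j];
      if (std::abs(lit) - 1 != atom) continue;
      bool sense = lit > 0;
      // The literal just became true if the new value matches its sense
      if (truth[atom] == sense) numTrueLits[c]++;
      else numTrueLits[c]--;
      // Every atom in this clause is affected by the flip
      for (std::size_t k = 0; k < clauses[c].size(); k++)
        blanket.insert(std::abs(clauses[c][k]) - 1);
    }
  }
  return blanket;
}
```

The std::set removes duplicates when an atom shares several clauses with the flipped atom. The returned indices are the predicates whose weights would then need to be recomputed.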

void MCMC::saveLowStateToChain ( const int &  chainIdx  )  [inline, protected]

Saves the atom assignment of the best state to one chain of the ground predicates' truth values.

Parameters:
chainIdx Index of chain to which the atom assignment is saved.

Definition at line 1335 of file mcmc.h.

References VariableState::getNumAtoms(), VariableState::getValueOfLowAtom(), Inference::state_, and truthValues_.

Referenced by SimulatedTempering::init(), and GibbsSampler::init().

01336   {
01337     for (int i = 0; i < state_->getNumAtoms(); i++)
01338       truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
01339   }
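
The listing above reads atom i + 1 from the state and writes it to position i of the chain, which suggests that the state numbers atoms from 1 while truthValues_ is indexed from 0. The following is a minimal sketch of that index shift under this assumption; saveLowState and the container layout are hypothetical and not part of Alchemy.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of Alchemy). lowState holds the best
// assignment with 1-based atom indices (element 0 is unused), while
// truthValues is indexed as truthValues[atom][chain] with 0-based atoms.
void saveLowState(const std::vector<bool>& lowState,
                  std::vector<std::vector<bool> >& truthValues,
                  std::size_t chainIdx)
{
  assert(lowState.size() == truthValues.size() + 1);
  for (std::size_t i = 0; i < truthValues.size(); i++)
  {
    assert(chainIdx < truthValues[i].size());
    // Shift by one to move from 1-based state indices to 0-based storage
    truthValues[i][chainIdx] = lowState[i + 1];
  }
}
```

Only the column for chainIdx is overwritten; the truth values of all other chains are left as they were.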

void MCMC::setMCMCParameters ( MCMCParams *  params  )  [inline, protected]

Sets the user-specified parameters for this MCMC algorithm.

Parameters:
params MCMC parameters for this algorithm.

Definition at line 1346 of file mcmc.h.

References MCMCParams::burnMaxSteps, burnMaxSteps_, MCMCParams::burnMinSteps, burnMinSteps_, MCMCParams::maxSeconds, maxSeconds_, MCMCParams::maxSteps, maxSteps_, MCMCParams::minSteps, minSteps_, MCMCParams::numChains, and numChains_.

01347   {
01348       // User-set parameters
01349     numChains_ = params->numChains;
01350     burnMinSteps_ = params->burnMinSteps;
01351     burnMaxSteps_ = params->burnMaxSteps;
01352     minSteps_ = params->minSteps;
01353     maxSteps_ = params->maxSteps;
01354     maxSeconds_ = params->maxSeconds;    
01355   }

void MCMC::scaleSamples ( double  factor  )  [inline, protected, virtual]

Increase or decrease the number of MCMC samples by a multiplicative factor.

(For MaxWalkSAT, this could instead change the number of flips, or another setting that trades off speed and accuracy.)

Reimplemented from Inference.

Definition at line 1357 of file mcmc.h.

References maxSteps_, and minSteps_.

01358   {
01359     minSteps_ = (int)(minSteps_ * factor);
01360     maxSteps_ = (int)(maxSteps_ * factor);
01361   }
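
The cast to int in the listing above truncates the scaled values toward zero. The following is a minimal sketch of that arithmetic; StepLimits and scaleSteps are hypothetical stand-ins for the minSteps_ and maxSteps_ members and are not part of Alchemy.

```cpp
#include <cassert>

// Hypothetical stand-in for the two step limits touched by scaleSamples.
struct StepLimits
{
  int minSteps;
  int maxSteps;
};

// Same arithmetic as the listing: multiply, then truncate toward zero.
void scaleSteps(StepLimits& limits, double factor)
{
  assert(factor >= 0.0);
  limits.minSteps = (int)(limits.minSteps * factor);
  limits.maxSteps = (int)(limits.maxSteps * factor);
}
```

Because of the truncation, scaling by a factor and then by its reciprocal need not restore the original limits, and a small minSteps can reach 0 after repeated down-scaling.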


The documentation for this class was generated from the following file:
mcmc.h
Generated on Sun Jun 7 11:55:26 2009 for Alchemy by  doxygen 1.5.1