#include <inference.h>
Public Member Functions
  Inference (VariableState *state, long int seed, const bool &trackClauseTrueCnts, Array<Array<Predicate *> *> *queryFormulas=NULL)
      Constructor: Every inference algorithm is required to have a VariableState representing the state of variables and clauses and a seed for any randomization in the algorithm.
  Inference (HVariableState *state, long int seed, const bool &trackClauseTrueCnts)
  virtual ~Inference ()
      Virtual destructor.
  void saveAllCounts (bool saveCounts=true)
  virtual void init ()=0
      Initializes the inference algorithm.
  virtual void infer ()=0
      Performs the inference algorithm.
  virtual void printNetwork (ostream &out)=0
      Prints out the network.
  virtual void printProbabilities (ostream &out)=0
      Prints the probabilities of each predicate to a stream.
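  virtual void getChangedPreds (vector<string> &changedPreds, vector<float> &probs, vector<float> &oldProbs, const float &probDelta)=0
      Puts the predicates whose probability has changed with respect to the reference vector oldProbs (by more than probDelta, in the case of probabilistic inference) into changedPreds in string form, and their probabilities into probs.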
  virtual void printTruePreds (ostream &out)=0
      Prints the predicates inferred to be true to a stream.
  virtual void printTruePredsH (ostream &out)=0
  virtual double getProbability (GroundPredicate *const &gndPred)=0
      Gets the probability of a ground predicate.
  virtual double getProbabilityH (GroundPredicate *const &gndPred)=0
  void printQFProbs (ostream &out, Domain *domain)
      Prints the probabilities of the query formulas.
  long int getSeed ()
  void setSeed (long int s)
  VariableState *getState ()
  void setState (VariableState *s)
  HVariableState *getHState ()
  void setHState (HVariableState *s)
  virtual void scaleSamples (double factor)
      Increase or decrease the number of MCMC samples by a multiplicative factor.
  const Array<double> *getClauseTrueCnts ()
  const Array<double> *getClauseTrueSqCnts ()
  int getNumSamples () const
  const Array<Array<double> > *getHessian ()
  const Array<double> *getHessianVectorProduct2 (Array<double> &v)
  const Array<double> *getHessianVectorProduct (const Array<double> &v)
  void resetCnts ()
  void saveCnts ()
  void restoreCnts ()
  void tallyCntsFromState ()
Protected Attributes
  long int seed_
  VariableState *state_
  HVariableState *hstate_
  bool saveAllCounts_
  Array<double> *clauseTrueCnts_
  Array<double> *clauseTrueCntsCont_
  Array<double> *clauseTrueSqCnts_
  int numSamples_
  bool trackClauseTrueCnts_
  Array<Array<double> > *allClauseTrueCnts_
  Array<double> *oldClauseTrueCnts_
  Array<Array<double> > *oldAllClauseTrueCnts_
  Array<Array<Predicate *> *> *queryFormulas_
  Array<double> *qfProbs_
At least one member function is pure virtual, making this an abstract class (it cannot be instantiated).
Definition at line 81 of file inference.h.
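Because Inference is abstract, each concrete algorithm must override every pure virtual member before it can be instantiated. The following is a minimal, self-contained sketch of that pattern under stated assumptions: InferenceSketch, ToyInference, and kDefaultSeedPlaceholder are hypothetical names, only init() and infer() are mirrored, the placeholder value does not reproduce the real DEFAULT_SEED, and std::srand stands in for the POSIX srandom() call used by the real constructor.

```cpp
#include <cassert>
#include <cstdlib>

// Hypothetical placeholder for DEFAULT_SEED; the real value is defined
// elsewhere in the library and is not shown in this reference.
const long int kDefaultSeedPlaceholder = 42;

// Hypothetical, stripped-down mirror of the Inference interface: only the
// seed handling and two pure virtual methods are reproduced.
class InferenceSketch {
 public:
  explicit InferenceSketch(long int seed) : seed_(seed) {
    // As in the documented constructor, -1 means "no seed given": fall back
    // to a fixed default so that such runs are reproducible.
    if (seed_ == -1) seed_ = kDefaultSeedPlaceholder;
    std::srand(static_cast<unsigned>(seed_));  // stand-in for srandom()
  }
  virtual ~InferenceSketch() {}

  virtual void init() = 0;   // pure virtual: subclasses must override
  virtual void infer() = 0;  // pure virtual: subclasses must override

  long int getSeed() const { return seed_; }

 protected:
  long int seed_;
};

// Hypothetical concrete subclass. It can be instantiated only because it
// overrides every pure virtual member of the base class.
class ToyInference : public InferenceSketch {
 public:
  explicit ToyInference(long int seed) : InferenceSketch(seed), steps_(0) {}
  void init() override { steps_ = 0; }
  void infer() override { ++steps_; }  // counts calls instead of inferring
  int steps() const { return steps_; }

 private:
  int steps_;
};
```

Callers would normally hold such an object through a base-class pointer (InferenceSketch *), which is why the destructor is virtual, mirroring ~Inference().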
Inference::Inference (VariableState *state,
                      long int seed,
                      const bool &trackClauseTrueCnts,
                      Array<Array<Predicate *> *> *queryFormulas = NULL)  [inline]
Constructor: Every inference algorithm is required to have a VariableState representing the state of variables and clauses and a seed for any randomization in the algorithm.
If there is no randomization, seed is not used.
Parameters:
  state                State of the variables and clauses of the inference.
  seed                 Seed used to initialize randomization in the algorithm. A value of -1 selects DEFAULT_SEED.
  trackClauseTrueCnts  Indicates whether true counts are kept for each first-order clause.
Definition at line 96 of file inference.h.
References clauseTrueCnts_, clauseTrueSqCnts_, VariableState::getMLN(), MLN::getNumClauses(), qfProbs_, queryFormulas_, seed_, Array< Type >::size(), state_, and trackClauseTrueCnts_.
  : seed_(seed), state_(state), saveAllCounts_(false),
    clauseTrueCnts_(NULL), clauseTrueSqCnts_(NULL),
    numSamples_(0),
    allClauseTrueCnts_(NULL), oldClauseTrueCnts_(NULL),
    oldAllClauseTrueCnts_(NULL), queryFormulas_(queryFormulas)
  {
    // If seed not specified, then init always to same random number
    if (seed_ == -1) seed_ = DEFAULT_SEED;
    srandom(seed_);

    trackClauseTrueCnts_ = trackClauseTrueCnts;
    if (trackClauseTrueCnts_ && state_)
    {
      int numClauses = state_->getMLN()->getNumClauses();

      // clauseTrueCnts_ and clauseTrueSqCnts_ will hold the true
      // counts (and squared true counts) for each first-order clause
      clauseTrueCnts_ = new Array<double>(numClauses, 0);
      clauseTrueSqCnts_ = new Array<double>(numClauses, 0);
    }

    if (queryFormulas_)
      qfProbs_ = new Array<double>(queryFormulas_->size(), 0);
    else
      qfProbs_ = NULL;
  }
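The constructor only allocates clauseTrueCnts_ and clauseTrueSqCnts_ (one zeroed slot per first-order clause); the code that fills them is not shown in this reference. As a hedged illustration of why both running sums are kept, the sketch below accumulates a count and a squared count per clause over samples and recovers a per-clause mean and variance from them. ClauseCountSketch, tally(), mean(), and variance() are hypothetical names, std::vector replaces the library's Array, and reading "true count" as the number of true groundings of a clause in one sample is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical accumulator mirroring clauseTrueCnts_, clauseTrueSqCnts_ and
// numSamples_: one zero-initialized slot per first-order clause.
struct ClauseCountSketch {
  std::vector<double> trueCnts;    // sum over samples of n_i
  std::vector<double> trueSqCnts;  // sum over samples of n_i * n_i
  int numSamples = 0;

  explicit ClauseCountSketch(std::size_t numClauses)
      : trueCnts(numClauses, 0.0), trueSqCnts(numClauses, 0.0) {}

  // Record one sample; n[i] is assumed to be the number of true
  // groundings of clause i in that sample.
  void tally(const std::vector<double>& n) {
    assert(n.size() == trueCnts.size());
    for (std::size_t i = 0; i < n.size(); ++i) {
      trueCnts[i] += n[i];
      trueSqCnts[i] += n[i] * n[i];
    }
    ++numSamples;
  }

  // Sample mean of n_i (requires at least one tallied sample).
  double mean(std::size_t i) const { return trueCnts[i] / numSamples; }

  // Sample variance of n_i computed as E[n_i^2] - (E[n_i])^2.
  double variance(std::size_t i) const {
    double m = mean(i);
    return trueSqCnts[i] / numSamples - m * m;
  }
};
```

Keeping only the two sums keeps memory at two numbers per clause regardless of how many samples are drawn, at the cost of the usual numerical sensitivity of the E[x^2] - E[x]^2 formula.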
virtual void Inference::printNetwork (ostream &out)  [pure virtual]
virtual void Inference::getChangedPreds (vector<string> &changedPreds,
                                         vector<float> &probs,
                                         vector<float> &oldProbs,
                                         const float &probDelta)  [pure virtual]
Puts the predicates whose probability has changed with respect to the reference vector oldProbs (by more than probDelta, in the case of probabilistic inference) into changedPreds in string form, and the corresponding probability of each such predicate into probs.
Implemented in BP, MCMC, SAT, and UnitPropagation.
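For probabilistic inference, the selection rule described above can be sketched as a single pass that compares each current probability with its reference value. This is a hypothetical free function, not the library's method: it assumes predNames, currentProbs, and oldProbs are aligned by index, takes names and current probabilities as explicit inputs (the real method reads them from the inference state), and clears its outputs first. oldProbs is non-const in the real signature, so the real implementations may also update it; this sketch leaves it untouched.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical sketch: keep predicates whose probability moved by more
// than probDelta relative to oldProbs, together with their new values.
void changedPredsSketch(const std::vector<std::string>& predNames,
                        const std::vector<float>& currentProbs,
                        const std::vector<float>& oldProbs,
                        const float& probDelta,
                        std::vector<std::string>& changedPreds,
                        std::vector<float>& probs) {
  assert(predNames.size() == currentProbs.size());
  assert(currentProbs.size() == oldProbs.size());
  changedPreds.clear();
  probs.clear();
  for (std::size_t i = 0; i < currentProbs.size(); ++i) {
    // Strictly greater than the threshold counts as "changed".
    if (std::fabs(currentProbs[i] - oldProbs[i]) > probDelta) {
      changedPreds.push_back(predNames[i]);
      probs.push_back(currentProbs[i]);
    }
  }
}
```

For example, with current probabilities {0.5, 0.9, 0.1}, reference values {0.5, 0.6, 0.12}, and probDelta = 0.05, only the second predicate is reported.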
virtual void Inference::scaleSamples (double factor)  [inline, virtual]
Increase or decrease the number of MCMC samples by a multiplicative factor.
(For MaxWalkSAT, this could perhaps change the number of flips, or some other setting that trades speed for accuracy.)
Reimplemented in MCMC.
Definition at line 265 of file inference.h.
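The base-class body of scaleSamples() is not shown here, and MCMC's reimplementation is documented elsewhere. As a hedged sketch of what scaling a sample budget by a multiplicative factor involves, the hypothetical SamplerSketch below multiplies its count, rounds to the nearest integer, and never drops below one sample. The rounding rule and the floor of one are assumptions made for this sketch, not documented behavior.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical MCMC-like sampler that only tracks its sample budget.
class SamplerSketch {
 public:
  explicit SamplerSketch(int numSamples) : numSamples_(numSamples) {}

  // Multiply the budget by factor (> 1 for accuracy, < 1 for speed),
  // rounding to the nearest integer and keeping at least one sample.
  void scaleSamples(double factor) {
    assert(factor > 0.0);
    long scaled = std::lround(numSamples_ * factor);
    numSamples_ = scaled < 1 ? 1 : static_cast<int>(scaled);
  }

  int numSamples() const { return numSamples_; }

 private:
  int numSamples_;
};
```

Exposing a single multiplicative knob lets a caller such as a weight learner trade speed against accuracy without knowing which algorithm-specific parameter is being adjusted.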