#include <votedperceptron.h>
Public Member Functions
  VotedPerceptron (const Array< Inference * > &inferences, const StringHashArray &nonEvidPredNames, IndexTranslator *const &idxTrans, const bool &lazyInference, const bool &rescaleGradient, const bool &withEM)
      Constructor.
  ~VotedPerceptron ()
  void setMeansStdDevs (const int &arrSize, const double *const &priorMeans, const double *const &priorStdDevs)
  void learnWeights (double *const &weights, const int &numWeights, const int &maxIter, const double &learningRate, const double &momentum, bool initWithLogOdds)
Definition at line 82 of file votedperceptron.h.
VotedPerceptron::VotedPerceptron (const Array< Inference * > & inferences,
                                  const StringHashArray & nonEvidPredNames,
                                  IndexTranslator *const & idxTrans,
                                  const bool & lazyInference,
                                  const bool & rescaleGradient,
                                  const bool & withEM)   [inline]
Constructor.
Initializes the member variables, determines the relevant clauses (and their formulas), initializes the clause weights, and initializes the inference procedure of each domain.
Parameters:
  inferences        Array of inference procedures to be used for inference in each domain.
  nonEvidPredNames  Names of the non-evidence predicates. These are used to determine the relevant clauses.
  idxTrans          IndexTranslator needed when multiple databases are used and they do not line up.
  lazyInference     If true, lazy inference is used. In that case, the clause weights are initialized from the per-domain log-odds; otherwise initializeWts() is called.
  rescaleGradient   If true, the gradient is rescaled at each iteration.
  withEM            If true, EM is used to fill in missing values.
Definition at line 100 of file votedperceptron.h.
References Array< Type >::append(), Array< Type >::growToSize(), and Array< Type >::size().
00104 : domainCnt_(inferences.size()), idxTrans_(idxTrans), 00105 lazyInference_(lazyInference), rescaleGradient_(rescaleGradient), 00106 withEM_(withEM) 00107 { 00108 cout << endl << "Constructing voted perceptron..." << endl << endl; 00109 00110 inferences_.append(inferences); 00111 logOddsPerDomain_.growToSize(domainCnt_); 00112 clauseCntPerDomain_.growToSize(domainCnt_); 00113 00114 for (int i = 0; i < domainCnt_; i++) 00115 { 00116 clauseCntPerDomain_[i] = 00117 inferences_[i]->getState()->getMLN()->getNumClauses(); 00118 logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0); 00119 } 00120 00121 totalTrueCnts_.growToSize(domainCnt_); 00122 defaultTrueCnts_.growToSize(domainCnt_); 00123 relevantClausesPerDomain_.growToSize(domainCnt_); 00124 //relevantClausesFormulas_ is set in findRelevantClausesFormulas() 00125 00126 findRelevantClauses(nonEvidPredNames); 00127 findRelevantClausesFormulas(); 00128 00129 // Initialize the clause wts for lazy version 00130 if (lazyInference_) 00131 { 00132 findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(nonEvidPredNames); 00133 00134 for (int i = 0; i < domainCnt_; i++) 00135 { 00136 const MLN* mln = inferences_[i]->getState()->getMLN(); 00137 Array<double>& logOdds = logOddsPerDomain_[i]; 00138 assert(mln->getNumClauses() == logOdds.size()); 00139 for (int j = 0; j < mln->getNumClauses(); j++) 00140 ((Clause*) mln->getClause(j))->setWt(logOdds[j]); 00141 } 00142 } 00143 // Initialize the clause wts for eager version 00144 else 00145 { 00146 initializeWts(); 00147 } 00148 00149 // Initialize the inference / state 00150 for (int i = 0; i < inferences_.size(); i++) 00151 inferences_[i]->init(); 00152 }