#include <discriminativelearner.h>
Public Member Functions | |
DiscriminativeLearner (const Array< Inference * > &inferences, const StringHashArray &nonEvidPredNames, IndexTranslator *const &idxTrans, const bool &lazyInference, const bool &withEM, const bool &rescaleGradient, const int &method, const double &lambda, const bool &preconditionCG, const double &maxLambda) | |
Constructor. | |
~DiscriminativeLearner () | |
void | setMeansStdDevs (const int &arrSize, const double *const &priorMeans, const double *const &priorStdDevs) |
void | setMLNWeights (double *const &weights) |
void | setLogOddsWeights (double *weights, int numWeights) |
void | getVariance (Array< double > &variance, int numWeights) |
const Array< double > * | getHessianVectorProduct (const Array< double > &d) |
double | computeQuadraticStepLength (double *gradient, const Array< double > &d, const Array< double > *Hd, double lambda) |
void | learnWeights (double *const &weights, const int &numWeights, const int &maxIter, const double &maxSec, const double &learningRate, const double &momentum, bool initWithLogOdds, const int &mwsMaxSubsequentSteps, bool periodicMLNs) |
Static Public Attributes | |
static const int | SIMPLE = 0 |
static const int | DN = 2 |
static const int | CG = 3 |
Definition at line 131 of file discriminativelearner.h.
DiscriminativeLearner::DiscriminativeLearner | ( | const Array< Inference * > & | inferences, | |
const StringHashArray & | nonEvidPredNames, | |||
IndexTranslator *const & | idxTrans, | |||
const bool & | lazyInference, | |||
const bool & | withEM, | |||
const bool & | rescaleGradient, | |||
const int & | method, | |||
const double & | lambda, | |||
const bool & | preconditionCG, | |||
const double & | maxLambda | |||
) | [inline] |
Constructor.
Initializes the member variables, sizes the per-domain arrays, determines the relevant clauses and formulas, initializes the clause weights, and finally initializes each inference procedure.
inferences | Array of inference procedures to be used for inference in each domain. | |
nonEvidPredNames | Names of non-evidence predicates. This is used to determine the relevant clauses. | |
idxTrans | IndexTranslator needed when multiple databases are used and they do not line up. | |
lazyInference | If true, lazy inference is used. | |
withEM | If true, EM is used to fill in missing values. | |
rescaleGradient | If true, use per-weight learning rates. | |
method | Determines how direction and step size are chosen (see the static constants SIMPLE, DN, and CG). | |
lambda | Initial value of lambda for SMD or CG. | |
preconditionCG | Whether or not to use a preconditioner with scaled conjugate gradient. | |
maxLambda | Maximum value of lambda for CG. |
Definition at line 154 of file discriminativelearner.h.
References Array< Type >::append(), Array< Type >::growToSize(), and Array< Type >::size().
00161 : domainCnt_(inferences.size()), idxTrans_(idxTrans), 00162 lazyInference_(lazyInference), rescaleGradient_(rescaleGradient), 00163 method_(method), 00164 // HACK: for now, we'll use the SMD lambda value for CG, even 00165 // though the two represent *very* different things! 00166 cg_lambda_(lambda), preconditionCG_(preconditionCG), 00167 maxBacktracks_(1000), backtrackCount_(0), 00168 cg_max_lambda_(maxLambda), withEM_(withEM) 00169 { 00170 cout << endl << "Constructing discriminative learner..." << endl << endl; 00171 00172 inferences_.append(inferences); 00173 logOddsPerDomain_.growToSize(domainCnt_); 00174 clauseCntPerDomain_.growToSize(domainCnt_); 00175 00176 for (int i = 0; i < domainCnt_; i++) 00177 { 00178 clauseCntPerDomain_[i] = 00179 inferences_[i]->getState()->getMLN()->getNumClauses(); 00180 logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0); 00181 } 00182 00183 totalTrueCnts_.growToSize(domainCnt_); 00184 totalFalseCnts_.growToSize(domainCnt_); 00185 defaultTrueCnts_.growToSize(domainCnt_); 00186 defaultFalseCnts_.growToSize(domainCnt_); 00187 relevantClausesPerDomain_.growToSize(domainCnt_); 00188 //relevantClausesFormulas_ is set in findRelevantClausesFormulas() 00189 00190 findRelevantClauses(nonEvidPredNames); 00191 findRelevantClausesFormulas(); 00192 00193 // Initialize the clause wts 00194 initializeWts(nonEvidPredNames); 00195 00196 // Initialize the inference / state 00197 for (int i = 0; i < inferences_.size(); i++) 00198 inferences_[i]->init(); 00199 }