#include <discriminativelearner.h>
Public Member Functions | |
| DiscriminativeLearner (const Array< Inference * > &inferences, const StringHashArray &nonEvidPredNames, IndexTranslator *const &idxTrans, const bool &lazyInference, const bool &withEM, const bool &rescaleGradient, const int &method, const double &lambda, const bool &preconditionCG, const double &maxLambda) | |
| Constructor. | |
| ~DiscriminativeLearner () | |
| void | setMeansStdDevs (const int &arrSize, const double *const &priorMeans, const double *const &priorStdDevs) |
| void | setMLNWeights (double *const &weights) |
| void | setLogOddsWeights (double *weights, int numWeights) |
| void | getVariance (Array< double > &variance, int numWeights) |
| const Array< double > * | getHessianVectorProduct (const Array< double > &d) |
| double | computeQuadraticStepLength (double *gradient, const Array< double > &d, const Array< double > *Hd, double lambda) |
| void | learnWeights (double *const &weights, const int &numWeights, const int &maxIter, const double &maxSec, const double &learningRate, const double &momentum, bool initWithLogOdds, const int &mwsMaxSubsequentSteps, bool periodicMLNs) |
Static Public Attributes | |
| static const int | SIMPLE = 0 |
| static const int | DN = 2 |
| static const int | CG = 3 |
Definition at line 131 of file discriminativelearner.h.
| DiscriminativeLearner::DiscriminativeLearner | ( | const Array< Inference * > & | inferences, | |
| const StringHashArray & | nonEvidPredNames, | |||
| IndexTranslator *const & | idxTrans, | |||
| const bool & | lazyInference, | |||
| const bool & | withEM, | |||
| const bool & | rescaleGradient, | |||
| const int & | method, | |||
| const double & | lambda, | |||
| const bool & | preconditionCG, | |||
| const double & | maxLambda | |||
| ) | [inline] |
Constructor.
Initializes member variables, determines the relevant clauses and their formulas, initializes the clause weights, and initializes each inference procedure.
| inferences | Array of inference procedures to be used for inference in each domain. | |
| nonEvidPredNames | Names of non-evidence predicates. This is used to determine the relevant clauses. | |
| idxTrans | IndexTranslator needed when multiple databases are used and they do not line up. | |
| lazyInference | If true, lazy inference is used. | |
| withEM | If true, EM is used to fill in missing values. | |
| rescaleGradient | If true, use per-weight learning rates. | |
| method | Determines how the search direction and step size are chosen. | |
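| lambda | Initial value of lambda for SMD or CG. Note that CG currently reuses the SMD lambda value, even though the two represent very different quantities. | |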
| preconditionCG | Whether to use a preconditioner with scaled conjugate gradient. | |
| maxLambda | Maximum value of lambda for CG. | |
Definition at line 154 of file discriminativelearner.h.
References Array< Type >::append(), Array< Type >::growToSize(), and Array< Type >::size().
00161 : domainCnt_(inferences.size()), idxTrans_(idxTrans), 00162 lazyInference_(lazyInference), rescaleGradient_(rescaleGradient), 00163 method_(method), 00164 // HACK: for now, we'll use the SMD lambda value for CG, even 00165 // though the two represent *very* different things! 00166 cg_lambda_(lambda), preconditionCG_(preconditionCG), 00167 maxBacktracks_(1000), backtrackCount_(0), 00168 cg_max_lambda_(maxLambda), withEM_(withEM) 00169 { 00170 cout << endl << "Constructing discriminative learner..." << endl << endl; 00171 00172 inferences_.append(inferences); 00173 logOddsPerDomain_.growToSize(domainCnt_); 00174 clauseCntPerDomain_.growToSize(domainCnt_); 00175 00176 for (int i = 0; i < domainCnt_; i++) 00177 { 00178 clauseCntPerDomain_[i] = 00179 inferences_[i]->getState()->getMLN()->getNumClauses(); 00180 logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0); 00181 } 00182 00183 totalTrueCnts_.growToSize(domainCnt_); 00184 totalFalseCnts_.growToSize(domainCnt_); 00185 defaultTrueCnts_.growToSize(domainCnt_); 00186 defaultFalseCnts_.growToSize(domainCnt_); 00187 relevantClausesPerDomain_.growToSize(domainCnt_); 00188 //relevantClausesFormulas_ is set in findRelevantClausesFormulas() 00189 00190 findRelevantClauses(nonEvidPredNames); 00191 findRelevantClausesFormulas(); 00192 00193 // Initialize the clause wts 00194 initializeWts(nonEvidPredNames); 00195 00196 // Initialize the inference / state 00197 for (int i = 0; i < inferences_.size(); i++) 00198 inferences_[i]->init(); 00199 }
1.5.1