00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef INFERENCE_H_
00068 #define INFERENCE_H_
00069
00070 #include "variablestate.h"
00071 #include "hvariablestate.h"
00072
00073
// Seed handed to srandom() whenever an Inference object is constructed with
// seed == -1 (i.e. "no seed requested").
static const long DEFAULT_SEED = 2350877L;
00075
00081 class Inference
00082 {
00083 public:
00084
00096 Inference(VariableState* state, long int seed,
00097 const bool& trackClauseTrueCnts,
00098 Array<Array<Predicate* >* >* queryFormulas = NULL)
00099 : seed_(seed), state_(state), saveAllCounts_(false),
00100 clauseTrueCnts_(NULL), clauseTrueSqCnts_(NULL),
00101 numSamples_(0),
00102 allClauseTrueCnts_(NULL), oldClauseTrueCnts_(NULL),
00103 oldAllClauseTrueCnts_(NULL), queryFormulas_(queryFormulas)
00104 {
00105
00106 if (seed_ == -1) seed_ = DEFAULT_SEED;
00107 srandom(seed_);
00108
00109 trackClauseTrueCnts_ = trackClauseTrueCnts;
00110 if (trackClauseTrueCnts_ && state_)
00111 {
00112 int numClauses = state_->getMLN()->getNumClauses();
00113
00114
00115
00116 clauseTrueCnts_ = new Array<double>(numClauses, 0);
00117 clauseTrueSqCnts_ = new Array<double>(numClauses, 0);
00118 }
00119
00120 if (queryFormulas_)
00121 qfProbs_ = new Array<double>(queryFormulas_->size(), 0);
00122 else
00123 qfProbs_ = NULL;
00124 }
00125
00126 Inference(HVariableState* state, long int seed,
00127 const bool& trackClauseTrueCnts)
00128 : seed_(seed), hstate_(state), saveAllCounts_(false),
00129 clauseTrueCnts_(NULL), clauseTrueCntsCont_(NULL),
00130 clauseTrueSqCnts_(NULL), numSamples_(0),
00131 allClauseTrueCnts_(NULL), oldClauseTrueCnts_(NULL),
00132 oldAllClauseTrueCnts_(NULL)
00133 {
00134
00135 if (seed_ == -1) seed_ = DEFAULT_SEED;
00136 srandom(seed_);
00137
00138 trackClauseTrueCnts_ = trackClauseTrueCnts;
00139 if (trackClauseTrueCnts_ && hstate_)
00140 {
00141
00142
00143 clauseTrueCnts_ = new Array<double>;
00144 clauseTrueCnts_->growToSize(hstate_->getMLN()->getNumClauses(), 0);
00145 clauseTrueCntsCont_ = new Array<double>;
00146 clauseTrueCntsCont_->growToSize(hstate_->getNumContFormulas(), 0);
00147 }
00148 }
00149
00153 virtual ~Inference()
00154 {
00155 delete clauseTrueCnts_;
00156 delete clauseTrueSqCnts_;
00157 delete allClauseTrueCnts_;
00158
00159 delete oldAllClauseTrueCnts_;
00160 delete oldClauseTrueCnts_;
00161
00162 if (qfProbs_) delete qfProbs_;
00163 }
00164
00165
00166 void saveAllCounts(bool saveCounts=true)
00167 {
00168 if (saveAllCounts_ == saveCounts)
00169 return;
00170
00171 saveAllCounts_ = saveCounts;
00172 if (saveCounts)
00173 {
00174 allClauseTrueCnts_ = new Array<Array<double> >;
00175 oldAllClauseTrueCnts_ = new Array<Array<double> >;
00176 }
00177 else
00178 {
00179 delete allClauseTrueCnts_;
00180 delete oldAllClauseTrueCnts_;
00181 allClauseTrueCnts_ = NULL;
00182 oldAllClauseTrueCnts_ = NULL;
00183 }
00184 }
00185
00186
00190 virtual void init() = 0;
00191
00195 virtual void infer() = 0;
00196
00200 virtual void printNetwork(ostream& out) = 0;
00201
00205 virtual void printProbabilities(ostream& out) = 0;
00206
00213 virtual void getChangedPreds(vector<string>& changedPreds,
00214 vector<float>& probs,
00215 vector<float>& oldProbs,
00216 const float& probDelta) = 0;
00217
00218
00222 virtual void printTruePreds(ostream& out) = 0;
00223 virtual void printTruePredsH(ostream& out) = 0;
00224
00228 virtual double getProbability(GroundPredicate* const& gndPred) = 0;
00229 virtual double getProbabilityH(GroundPredicate* const& gndPred) = 0;
00230
00234 void printQFProbs(ostream& out, Domain* domain)
00235 {
00236 if (qfProbs_)
00237 {
00238 for (int i = 0; i < queryFormulas_->size(); i++)
00239 {
00240 Array<Predicate* >* formula = (*queryFormulas_)[i];
00241 for (int j = 0; j < formula->size(); j++)
00242 {
00243 (*formula)[j]->printWithStrVar(out, domain);
00244 if (j != formula->size() - 1) out << " ^ ";
00245 }
00246 out << " " << (*qfProbs_)[i] << endl;
00247 }
00248 }
00249 }
00250
00251 long int getSeed() { return seed_; }
00252 void setSeed(long int s) { seed_ = s; }
00253
00254 VariableState* getState() { return state_; }
00255 void setState(VariableState* s) { state_ = s; }
00256
00257 HVariableState* getHState() { return hstate_; }
00258 void setHState(HVariableState* s) { hstate_ = s; }
00259
00265 virtual void scaleSamples(double factor) { }
00266
00267
00268
00269 const Array<double>* getClauseTrueCnts() { return clauseTrueCnts_; }
00270 const Array<double>* getClauseTrueSqCnts() { return clauseTrueSqCnts_; }
00271 int getNumSamples() const { return numSamples_; }
00272
00273
00274
00275
00276
00277
00278
00279
00280 const Array<Array<double> >* getHessian()
00281 {
00282 int numClauses = state_->getMLN()->getNumClauses();
00283 int numSamples = allClauseTrueCnts_->size();
00284
00285
00286 Array<Array<double> >* hessian = new Array<Array<double> >(numClauses);
00287 for (int i = 0; i < numClauses; i++)
00288 (*hessian)[i].growToSize(numClauses);
00289
00290
00291
00292
00293
00294 for (int i = 0; i < numClauses; i++)
00295 {
00296 for (int j = 0; j < numClauses; j++)
00297 {
00298 double ni = 0.0;
00299 double nj = 0.0;
00300 double ninj = 0.0;
00301 for (int s = 0; s < numSamples; s++)
00302 {
00303 ni += (*allClauseTrueCnts_)[s][i];
00304 nj += (*allClauseTrueCnts_)[s][j];
00305 ninj += (*allClauseTrueCnts_)[s][i]
00306 * (*allClauseTrueCnts_)[s][j];
00307 }
00308 double n = numSamples;
00309 (*hessian)[i][j] = ni/n * nj/n - ninj/n;
00310 }
00311 }
00312
00313 return hessian;
00314 }
00315
00316
00317
00318
00319 const Array<double>* getHessianVectorProduct2(Array<double>& v)
00320 {
00321 int numClauses = state_->getMLN()->getNumClauses();
00322 const Array<Array<double> >* hessian = getHessian();
00323 Array<double>* product = new Array<double>(numClauses,0);
00324
00325 for (int clauseno = 0; clauseno < numClauses; clauseno++)
00326 {
00327 (*product)[clauseno] = 0.0;
00328 for (int i = 0; i < numClauses; i++)
00329 (*product)[clauseno] += (*hessian)[clauseno][i] * v[i];
00330 }
00331
00332 delete hessian;
00333 return product;
00334 }
00335
00336
00337 const Array<double>* getHessianVectorProduct(const Array<double>& v)
00338 {
00339 int numClauses = state_->getMLN()->getNumClauses();
00340 int numSamples = allClauseTrueCnts_->size();
00341
00342
00343
00344
00345
00346
00347
00348 double sumVN = 0;
00349 Array<double> sumN(numClauses, 0);
00350 Array<double> sumNiVN(numClauses, 0);
00351
00352
00353
00354 for (int s = 0; s < numSamples; s++)
00355 {
00356 Array<double>& n = (*allClauseTrueCnts_)[s];
00357
00358
00359 double vn = 0;
00360 for (int i = 0; i < numClauses; i++)
00361 vn += v[i] * n[i];
00362
00363
00364 sumVN += vn;
00365 for (int i = 0; i < numClauses; i++)
00366 {
00367 sumN[i] += n[i];
00368 sumNiVN[i] += n[i] * vn;
00369 }
00370 }
00371
00372
00373 Array<double>* product = new Array<double>(numClauses,0);
00374 for (int clauseno = 0; clauseno < numClauses; clauseno++)
00375 {
00376 double E_vn = sumVN/numSamples;
00377 double E_ni = sumN[clauseno]/numSamples;
00378 double E_nivn = sumNiVN[clauseno]/numSamples;
00379 (*product)[clauseno] = E_nivn - E_ni * E_vn;
00380 }
00381
00382 return product;
00383 }
00384
00385
00386 void resetCnts()
00387 {
00388 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00389 {
00390 (*clauseTrueCnts_)[clauseno] = 0;
00391 (*clauseTrueSqCnts_)[clauseno] = 0;
00392 }
00393 numSamples_ = 0;
00394
00395 if (saveAllCounts_)
00396 {
00397 delete allClauseTrueCnts_;
00398 allClauseTrueCnts_ = new Array<Array<double> >;
00399 }
00400 }
00401
00402
00403 void saveCnts()
00404 {
00405 if (!saveAllCounts_)
00406 return;
00407
00408
00409
00410
00411 delete oldAllClauseTrueCnts_;
00412 oldAllClauseTrueCnts_ = new Array<Array<double> > (*allClauseTrueCnts_);
00413
00414
00415
00416
00417
00418
00419 }
00420
00421
00422 void restoreCnts()
00423 {
00424 if (!saveAllCounts_)
00425 return;
00426
00427 resetCnts();
00428
00429 *allClauseTrueCnts_ = *oldAllClauseTrueCnts_;
00430 for (int i = 0; i < allClauseTrueCnts_->size(); i++)
00431 {
00432 int numcounts = (*allClauseTrueCnts_)[i].size();
00433 for (int j = 0; j < numcounts; j++)
00434 {
00435 double currcount = (*allClauseTrueCnts_)[i][j];
00436 (*clauseTrueCnts_)[j] += currcount;
00437 (*clauseTrueSqCnts_)[j] += currcount * currcount;
00438 }
00439 numSamples_++;
00440 }
00441
00442
00443
00444
00445
00446
00447 }
00448
00449
00450 void tallyCntsFromState()
00451 {
00452 int numcounts = clauseTrueCnts_->size();
00453 Array<double> currCounts(numcounts, 0.0);
00454 state_->getNumClauseGndings(&currCounts, true);
00455
00456 if (saveAllCounts_)
00457 {
00458 allClauseTrueCnts_->append(Array<double>());
00459 (*allClauseTrueCnts_)[numSamples_].growToSize(numcounts);
00460 }
00461
00462 for (int i = 0; i < numcounts; i++)
00463 {
00464 if (saveAllCounts_)
00465 (*allClauseTrueCnts_)[numSamples_][i] = currCounts[i];
00466
00467 (*clauseTrueCnts_)[i] += currCounts[i];
00468 (*clauseTrueSqCnts_)[i] += currCounts[i] * currCounts[i];
00469 }
00470 numSamples_++;
00471 }
00472
00473 protected:
00474
00475
00476 long int seed_;
00477
00478
00479 VariableState* state_;
00480 HVariableState* hstate_;
00481
00482
00483 bool saveAllCounts_;
00484
00485
00486 Array<double>* clauseTrueCnts_;
00487 Array<double>* clauseTrueCntsCont_;
00488
00489
00490
00491 Array<double>* clauseTrueSqCnts_;
00492
00493 int numSamples_;
00494
00495 bool trackClauseTrueCnts_;
00496
00497
00498 Array<Array<double> >* allClauseTrueCnts_;
00499
00500 Array<double>* oldClauseTrueCnts_;
00501 Array<Array<double> >* oldAllClauseTrueCnts_;
00502
00503 Array<Array<Predicate* >* >* queryFormulas_;
00504 Array<double>* qfProbs_;
00505 };
00506
00507 #endif