00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef GIBBSSAMPLER_H_
00067 #define GIBBSSAMPLER_H_
00068
#include <cstddef>

#include "mcmc.h"
#include "gibbsparams.h"
#include "maxwalksat.h"
#include "convergencetest.h"
#include "gelmanconvergencetest.h"
00074
00075
// Selects how GibbsSampler::init() seeds each chain.  The value is read as
// an int from GibbsParams::walksatType:
//   MAXWALKSAT - each chain starts from the state found by a MaxWalkSat run
//   NONE       - (or any value other than MAXWALKSAT) random initial values
enum WalksatType
{
  NONE       = 0,
  MAXWALKSAT = 1
};
00077
// Debug flag for the Gibbs sampler.  Not read anywhere in this header;
// presumably consulted elsewhere to enable extra diagnostics (confirm).
const bool gibbsdebug(false);
00079
00083 class GibbsSampler : public MCMC
00084 {
00085 public:
00086
00093 GibbsSampler(VariableState* state, long int seed,
00094 const bool& trackClauseTrueCnts, GibbsParams* gibbsParams)
00095 : MCMC(state, seed, trackClauseTrueCnts, gibbsParams)
00096 {
00097
00098 gamma_ = gibbsParams->gamma;
00099 epsilonError_ = gibbsParams->epsilonError;
00100 fracConverged_ = gibbsParams->fracConverged;
00101 walksatType_ = gibbsParams->walksatType;
00102 samplesPerTest_ = gibbsParams->samplesPerTest;
00103
00104
00105 mws_ = new MaxWalkSat(state_, seed, false, gibbsParams->mwsParams);
00106 }
00107
00111 ~GibbsSampler()
00112 {
00113 deleteConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00114 state_->getNumAtoms());
00115 delete mws_;
00116 }
00117
00121 void init()
00122 {
00123
00124 initTruthValuesAndWts(numChains_);
00125 initNumTrue();
00126
00127 cout << "Initializing Gibbs sampling " ;
00128
00129 if (walksatType_ == 1)
00130 {
00131 cout << "with MaxWalksat" << endl;
00132 for (int c = 0; c < numChains_; c++)
00133 {
00134 cout << "for chain " << c << "..." << endl;
00135 mws_->init();
00136 mws_->infer();
00137 saveLowStateToChain(c);
00138 }
00139 }
00140
00141 else
00142 {
00143 cout << "randomly" << endl;
00144 randomInitGndPredsTruthValues(numChains_);
00145 }
00146
00147
00148
00149 initNumTrueLits(numChains_);
00150
00151 int numGndPreds = state_->getNumAtoms();
00152
00153 initConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00154 gamma_, epsilonError_, numGndPreds, numChains_);
00155 }
00156
00160 void infer()
00161 {
00162 Timer timer;
00163
00164 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00165 double secondsElapsed = 0;
00166 double startTimeSec = timer.time();
00167 double currentTimeSec;
00168
00169
00170 if (trackClauseTrueCnts_)
00171 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00172 (*clauseTrueCnts_)[clauseno] = 0;
00173
00174
00175 GroundPredicateHashArray affectedGndPreds;
00176 Array<int> affectedGndPredIndices;
00177
00178 int numAtoms = state_->getNumAtoms();
00179 for (int i = 0; i < numAtoms; i++)
00180 {
00181 affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00182 affectedGndPredIndices.append(i);
00183 }
00184 for (int c = 0; c < numChains_; c++)
00185 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00186 affectedGndPreds.clear();
00187 affectedGndPredIndices.clear();
00188
00189 cout << "Running Gibbs sampling..." << endl;
00190
00191 int sample = 0;
00192 int numSamplesPerPred = 0;
00193 bool done = false;
00194 while (!done)
00195 {
00196 ++sample;
00197
00198 if (sample % samplesPerTest_ == 0)
00199 {
00200 currentTimeSec = timer.time();
00201 secondsElapsed = currentTimeSec-startTimeSec;
00202 cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00203 Timer::printTime(cout, secondsElapsed); cout << endl;
00204 }
00205
00206
00207 for (int c = 0; c < numChains_; c++)
00208 {
00209 performGibbsStep(c, burningIn, affectedGndPreds,
00210 affectedGndPredIndices);
00211 if (!burningIn) numSamplesPerPred++;
00212 }
00213
00214
00215 for (int i = 0; i < state_->getNumAtoms(); i++)
00216 {
00217 const bool* vals = truthValues_[i].getItems();
00218
00219 if (burningIn) burnConvergenceTests_[i]->appendNewValues(vals);
00220 else gibbsConvergenceTests_[i]->appendNewValues(vals);
00221 }
00222
00223 if (sample % samplesPerTest_ != 0) continue;
00224 if (burningIn)
00225 {
00226
00227
00228 bool burnConverged
00229 = GelmanConvergenceTest::checkConvergenceOfAll(burnConvergenceTests_,
00230 state_->getNumAtoms(),
00231 true);
00232 if ( (sample >= burnMinSteps_ && burnConverged)
00233 || (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00234 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00235 {
00236 cout << "Done burning. " << sample << " samples per pred per chain ("
00237 << (burnConverged? "converged":"didn't converge")
00238 <<" at total of " << numChains_*sample << " samples per pred)"
00239 << endl;
00240 burningIn = false;
00241 sample = 0;
00242 }
00243 }
00244 else
00245 {
00246 bool gibbsConverged
00247 = ConvergenceTest::checkConvergenceOfAtLeast(gibbsConvergenceTests_,
00248 state_->getNumAtoms(),
00249 sample, fracConverged_,
00250 true);
00251 if ( (sample >= minSteps_ && gibbsConverged)
00252 || (maxSteps_ >= 0 && sample >= maxSteps_)
00253 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00254 {
00255 cout << "Done Gibbs sampling. " << sample
00256 << " samples per pred per chain ("
00257 << (gibbsConverged? "converged":"didn't converge")
00258 <<" at total of " << numSamplesPerPred << " samples per pred)"
00259 << endl;
00260 done = true;
00261 }
00262 }
00263 cout.flush();
00264 }
00265
00266 cout<< "Time taken for Gibbs sampling = ";
00267 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00268
00269
00270 for (int i = 0; i < state_->getNumAtoms(); i++)
00271 {
00272 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00273 }
00274
00275
00276 if (trackClauseTrueCnts_)
00277 {
00278
00279 for (int i = 0; i < clauseTrueCnts_->size(); i++)
00280 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00281 }
00282 }
00283
00284 private:
00285
00289 void initConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00290 ConvergenceTest**& gibbsConvergenceTests,
00291 const double& gamma, const double& epsilonFrac,
00292 const int& numGndPreds, const int& numChains)
00293 {
00294 burnConvergenceTests = new GelmanConvergenceTest*[numGndPreds];
00295 gibbsConvergenceTests = new ConvergenceTest*[numGndPreds];
00296 for (int i = 0; i < numGndPreds; i++)
00297 {
00298 burnConvergenceTests[i] = new GelmanConvergenceTest(numChains);
00299 gibbsConvergenceTests[i] = new ConvergenceTest(numChains, gamma,
00300 epsilonFrac);
00301 }
00302 }
00303
00307 void deleteConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00308 ConvergenceTest**& gibbsConvergenceTests,
00309 const int& numGndPreds)
00310 {
00311 for (int i = 0; i < numGndPreds; i++)
00312 {
00313 delete burnConvergenceTests[i];
00314 delete gibbsConvergenceTests[i];
00315 }
00316 delete [] burnConvergenceTests;
00317 delete [] gibbsConvergenceTests;
00318 }
00319
00320 private:
00321
00322 double gamma_;
00323
00324 double epsilonError_;
00325
00326 double fracConverged_;
00327
00328 int walksatType_;
00329
00330 int samplesPerTest_;
00331
00332 GelmanConvergenceTest** burnConvergenceTests_;
00333
00334 ConvergenceTest** gibbsConvergenceTests_;
00335
00336
00337 MaxWalkSat* mws_;
00338 };
00339
00340 #endif