00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef MCSAT_H_
00067 #define MCSAT_H_
00068
00069 #include "mcmc.h"
00070 #include "mcsatparams.h"
00071 #include "unitpropagation.h"
00072 #include "maxwalksat.h"
00073
// Compile-time switch for MC-SAT debug tracing; only ever used as a condition,
// so it is declared as a bool rather than an int initialized from `false`.
const bool msdebug = false;
00075
00082 class MCSAT : public MCMC
00083 {
00084 public:
00085
00089 MCSAT(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090 MCSatParams* mcsatParams)
00091 : MCMC(state, seed, trackClauseTrueCnts, mcsatParams)
00092 {
00093 numStepsEveryMCSat_ = mcsatParams->numStepsEveryMCSat;
00094
00095 up_ = new UnitPropagation(state_, seed, false);
00096 mws_ = new MaxWalkSat(state_, seed, false, mcsatParams->mwsParams);
00097 }
00098
00102 ~MCSAT()
00103 {
00104 delete up_;
00105 delete mws_;
00106 }
00107
00111 void init()
00112 {
00113 assert(numChains_ == 1);
00114 initNumTrue();
00115
00116 cout << "Initializing MC-SAT with MaxWalksat on hard clauses..." << endl;
00117 state_->eliminateSoftClauses();
00118
00119 int numSolutions = mws_->getNumSolutions();
00120 mws_->setNumSolutions(1);
00121
00122
00123 mws_->init();
00124 mws_->infer();
00125
00126 if (msdebug)
00127 {
00128 cout << "Low state:" << endl;
00129 state_->printLowState(cout);
00130 }
00131 state_->saveLowStateToGndPreds();
00132
00133
00134 mws_->setHeuristic(SS);
00135 mws_->setNumSolutions(numSolutions);
00136 mws_->setTargetCost(0.0);
00137 state_->resetDeadClauses();
00138 state_->makeUnitCosts();
00139 }
00140
00144 void infer()
00145 {
00146 Timer timer;
00147
00148 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00149 double secondsElapsed = 0;
00150 upSecondsElapsed_ = 0;
00151 ssSecondsElapsed_ = 0;
00152 double startTimeSec = timer.time();
00153 double currentTimeSec;
00154 int samplesPerOutput = 100;
00155
00156
00157 if (trackClauseTrueCnts_)
00158 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00159 (*clauseTrueCnts_)[clauseno] = 0;
00160
00161
00162 GroundPredicateHashArray affectedGndPreds;
00163 Array<int> affectedGndPredIndices;
00164
00165 int numAtoms = state_->getNumAtoms();
00166 for (int i = 0; i < numAtoms; i++)
00167 {
00168 affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00169 affectedGndPredIndices.append(i);
00170 }
00171 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 0);
00172 affectedGndPreds.clear();
00173 affectedGndPredIndices.clear();
00174
00175 cout << "Running MC-SAT sampling..." << endl;
00176
00177 int sample = 0;
00178 int numSamplesPerPred = 0;
00179 bool done = false;
00180 while (!done)
00181 {
00182 ++sample;
00183 bool mcSatStep = (sample % numStepsEveryMCSat_ == 0);
00184
00185 if (sample % samplesPerOutput == 0)
00186 {
00187 currentTimeSec = timer.time();
00188 secondsElapsed = currentTimeSec - startTimeSec;
00189 cout << "Sample (per pred) " << sample << ", time elapsed = ";
00190 Timer::printTime(cout, secondsElapsed); cout << endl;
00191 }
00192
00193
00194 if (mcSatStep) performMCSatStep(burningIn);
00195
00196
00197 else performGibbsStep(0, burningIn, affectedGndPreds,
00198 affectedGndPredIndices);
00199
00200 if (!burningIn) numSamplesPerPred++;
00201
00202 if (burningIn)
00203 {
00204 if ( (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00205 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00206 {
00207 cout << "Done burning. " << sample << " samples." << endl;
00208 burningIn = false;
00209 sample = 0;
00210 }
00211 }
00212 else
00213 {
00214 if ( (maxSteps_ >= 0 && sample >= maxSteps_)
00215 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00216 {
00217 cout << "Done MC-SAT sampling. " << sample << " samples."
00218 << endl;
00219 done = true;
00220 }
00221 }
00222 cout.flush();
00223 }
00224
00225 cout<< "Time taken for MC-SAT sampling = ";
00226 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00227
00228 cout<< "Time taken for unit propagation = ";
00229 Timer::printTime(cout, upSecondsElapsed_); cout << endl;
00230
00231 cout<< "Time taken for SampleSat = ";
00232 Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00233
00234
00235 for (int i = 0; i < state_->getNumAtoms(); i++)
00236 {
00237 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00238 }
00239
00240
00241 if (trackClauseTrueCnts_)
00242 {
00243
00244 for (int i = 0; i < clauseTrueCnts_->size(); i++)
00245 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00246 }
00247 }
00248
00249 private:
00250
00256 void performMCSatStep(const bool& burningIn)
00257 {
00258 Timer timer;
00259 double startTime;
00260 if (msdebug) cout << "Entering MC-SAT step" << endl;
00261
00262 state_->setUseThreshold(true);
00263 int start = 0;
00264 state_->killClauses(start);
00265
00266 if (msdebug)
00267 {
00268 cout << "Num of clauses " << state_->getNumClauses() << endl;
00269 cout << "Num of dead clauses " << state_->getNumDeadClauses() << endl;
00270 }
00271
00272
00273 startTime = timer.time();
00274 up_->init();
00275 up_->infer();
00276 upSecondsElapsed_ += (timer.time() - startTime);
00277
00278 startTime = timer.time();
00279 mws_->init();
00280 mws_->infer();
00281 ssSecondsElapsed_ += (timer.time() - startTime);
00282
00283 if (msdebug)
00284 {
00285 cout << "Low state:" << endl;
00286 state_->printLowState(cout);
00287 }
00288 state_->saveLowStateToGndPreds();
00289
00290
00291 state_->resetFixedAtoms();
00292 state_->resetDeadClauses();
00293 state_->setUseThreshold(false);
00294 int numAtoms = state_->getNumAtoms();
00295 for (int i = 0; i < numAtoms; i++)
00296 {
00297 GroundPredicate* gndPred = state_->getGndPred(i);
00298 bool newAssignment = state_->getValueOfLowAtom(i + 1);
00299
00300
00301 if (newAssignment != gndPred->getTruthValue())
00302 {
00303 gndPred->setTruthValue(newAssignment);
00304 updateClauses(i);
00305 }
00306
00307
00308
00309 if (!burningIn && newAssignment) numTrue_[i]++;
00310 }
00311
00312
00313 if (!burningIn && trackClauseTrueCnts_)
00314 state_->getNumClauseGndings(clauseTrueCnts_, true);
00315
00316 if (msdebug) cout << "Leaving MC-SAT step" << endl;
00317 }
00318
00326 void updateClauses(const int& gndPredIdx)
00327 {
00328 if (msdebug) cout << "Entering updateClauses" << endl;
00329 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00330 Array<int>& negGndClauses =
00331 state_->getNegOccurenceArray(gndPredIdx + 1);
00332 Array<int>& posGndClauses =
00333 state_->getPosOccurenceArray(gndPredIdx + 1);
00334 int gndClauseIdx;
00335 bool sense;
00336
00337 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00338 {
00339 if (i < negGndClauses.size())
00340 {
00341 gndClauseIdx = negGndClauses[i];
00342 sense = false;
00343 }
00344 else
00345 {
00346 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00347 sense = true;
00348 }
00349
00350 if (gndPred->getTruthValue() == sense)
00351 state_->incrementNumTrueLits(gndClauseIdx);
00352 else
00353 state_->decrementNumTrueLits(gndClauseIdx);
00354 }
00355 if (msdebug) cout << "Leaving updateClauses" << endl;
00356 }
00357
00358 private:
00359
00360
00361 int numStepsEveryMCSat_;
00362
00363
00364 UnitPropagation* up_;
00365
00366 MaxWalkSat* mws_;
00367
00368
00369 double upSecondsElapsed_;
00370
00371 double ssSecondsElapsed_;
00372 };
00373
00374 #endif