00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef MCSAT_H_
00068 #define MCSAT_H_
00069
00070 #include "mcmc.h"
00071 #include "mcsatparams.h"
00072 #include "unitpropagation.h"
00073 #include "maxwalksat.h"
00074
// Compile-time switch for MC-SAT debug tracing (timings, low state dumps,
// entry/exit messages). It is a boolean flag, so it is declared as bool.
const bool msdebug = false;
00076
00083 class MCSAT : public MCMC
00084 {
00085 public:
00086
00090 MCSAT(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00091 MCSatParams* mcsatParams,
00092 Array<Array<Predicate* >* >* queryFormulas = NULL)
00093 : MCMC(state, seed, trackClauseTrueCnts, mcsatParams, queryFormulas)
00094 {
00095 Timer timer1;
00096
00097
00098
00099 mws_ = new MaxWalkSat(state_, seed, false, mcsatParams->mwsParams);
00100 mws_->setPrintInfo(false);
00101
00102 if (msdebug)
00103 {
00104 cout << "[MCSAT] ";
00105 Timer::printTime(cout, timer1.time());
00106 cout << endl;
00107 timer1.reset();
00108 }
00109 }
00110
00114 ~MCSAT()
00115 {
00116
00117 delete mws_;
00118 }
00119
00123 void init()
00124 {
00125 Timer timer1;
00126 assert(numChains_ == 1);
00127
00128 cout << "Initializing MC-SAT with MaxWalksat on hard clauses..." << endl;
00129
00130 state_->eliminateSoftClauses();
00131 state_->setInferenceMode(state_->MODE_HARD);
00132
00133
00134 int numSolutions = mws_->getNumSolutions();
00135 mws_->setNumSolutions(1);
00136
00137
00138 mws_->init();
00139 mws_->infer();
00140
00141 if (msdebug)
00142 {
00143 cout << "Low state:" << endl;
00144 state_->printLowState(cout);
00145 }
00146 state_->saveLowStateToGndPreds();
00147
00148
00149 mws_->setHeuristic(SS);
00150 mws_->setNumSolutions(numSolutions);
00151 mws_->setTargetCost(0.0);
00152 state_->resetDeadClauses();
00153
00154
00155 state_->setInferenceMode(state_->MODE_SAMPLESAT);
00156
00157 if (msdebug)
00158 {
00159 cout << "[MCSAT.init] ";
00160 Timer::printTime(cout, timer1.time());
00161 cout << endl;
00162 timer1.reset();
00163 }
00164 }
00165
00169 void infer()
00170 {
00171 Timer timer1;
00172
00173 initNumTrue();
00174 Timer timer;
00175
00176 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00177 double secondsElapsed = 0;
00178
00179 ssSecondsElapsed_ = 0;
00180 double startTimeSec = timer.time();
00181 double currentTimeSec;
00182 int samplesPerOutput = 100;
00183
00184
00185 GroundPredicateHashArray affectedGndPreds;
00186 Array<int> affectedGndPredIndices;
00187
00188 int numAtoms = state_->getNumAtoms();
00189 for (int i = 0; i < numAtoms; i++)
00190 {
00191 affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00192 affectedGndPredIndices.append(i);
00193 }
00194 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 0);
00195 affectedGndPreds.clear();
00196 affectedGndPredIndices.clear();
00197
00198 if (msdebug)
00199 {
00200 cout << "[MCSAT.infer.prep] ";
00201 Timer::printTime(cout, timer1.time());
00202 cout << endl;
00203 timer1.reset();
00204 }
00205
00206 cout << "Running MC-SAT sampling..." << endl;
00207
00208 int sample = 0;
00209 int numSamplesPerPred = 0;
00210 bool done = false;
00211 while (!done)
00212 {
00213 ++sample;
00214 if (sample % samplesPerOutput == 0)
00215 {
00216 currentTimeSec = timer.time();
00217 secondsElapsed = currentTimeSec - startTimeSec;
00218 cout << "Sample (per pred) " << sample << ", time elapsed = ";
00219 Timer::printTime(cout, secondsElapsed);
00220 cout << ", num. preds = " << state_->getNumAtoms();
00221 cout << ", num. clauses = " << state_->getNumClauses();
00222 cout << endl;
00223 }
00224
00225
00226 performMCSatStep(burningIn);
00227
00228 if (!burningIn) numSamplesPerPred++;
00229
00230 if (burningIn)
00231 {
00232 if ( (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00233 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00234 {
00235 cout << "Done burning. " << sample << " samples." << endl;
00236 burningIn = false;
00237 sample = 0;
00238 }
00239 }
00240 else
00241 {
00242 if ( (maxSteps_ >= 0 && sample >= maxSteps_)
00243 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00244 {
00245 cout << "Done MC-SAT sampling. " << sample << " samples."
00246 << endl;
00247 done = true;
00248 }
00249 }
00250 cout.flush();
00251 }
00252
00253 cout << "Final ground predicate number: " << state_->getNumAtoms() << endl;
00254 cout << "Final ground clause number: " << state_->getNumClauses() << endl;
00255
00256 cout<< "Time taken for MC-SAT sampling = ";
00257 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00258
00259
00260
00261
00262 cout<< "Time taken for SampleSat = ";
00263 Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00264
00265
00266 for (int i = 0; i < state_->getNumAtoms(); i++)
00267 {
00268 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00269 }
00270
00271
00272 if (qfProbs_)
00273 {
00274 for (int j = 0; j < qfProbs_->size(); j++)
00275 (*qfProbs_)[j] = (*qfProbs_)[j] / numSamplesPerPred;
00276 }
00277 }
00278
00279 private:
00280
00286 void performMCSatStep(const bool& burningIn)
00287 {
00288 Timer timer;
00289 double startTime;
00290 if (msdebug) cout << "Entering MC-SAT step" << endl;
00291
00292 state_->setUseThreshold(true);
00293 state_->updatePrevSatisfied();
00294
00295 int start = 0;
00296
00297 state_->resetMakeBreakCostWatch();
00298
00299 state_->killClauses(start);
00300
00301 if (msdebug)
00302 {
00303 cout << "Num of clauses " << state_->getNumClauses() << endl;
00304 cout << "Num of dead clauses " << state_->getNumDeadClauses() << endl;
00305 }
00306
00307
00308 startTime = timer.time();
00309
00310
00311
00312
00313 startTime = timer.time();
00314 mws_->init();
00315 mws_->infer();
00316 ssSecondsElapsed_ += (timer.time() - startTime);
00317
00318 if (msdebug)
00319 {
00320 cout << "Low state:" << endl;
00321 state_->printLowState(cout);
00322 }
00323
00324
00325
00326 state_->resetFixedAtoms();
00327 state_->resetDeadClauses();
00328 state_->setUseThreshold(false);
00329 int numAtoms = state_->getNumAtoms();
00330
00331 numTrue_.growToSize(numAtoms, 0);
00332
00333 for (int i = 0; i < numAtoms; i++)
00334 {
00335 GroundPredicate* gndPred = state_->getGndPred(i);
00336 bool newAssignment = state_->getValueOfLowAtom(i + 1);
00337
00338
00339 if (newAssignment != gndPred->getTruthValue())
00340 {
00341 gndPred->setTruthValue(newAssignment);
00342 updateClauses(i);
00343 }
00344
00345
00346
00347 if (!burningIn && newAssignment) numTrue_[i]++;
00348 }
00349
00350
00351 if (!burningIn && trackClauseTrueCnts_)
00352 tallyCntsFromState();
00353
00354
00355 if (!burningIn && qfProbs_)
00356 {
00357 for (int i = 0; i < queryFormulas_->size(); i++)
00358 {
00359 Array<Predicate* >* formula = (*queryFormulas_)[i];
00360 bool satisfied = true;
00361 for (int j = 0; j < formula->size(); j++)
00362 {
00363 bool sense = (*formula)[j]->getSense();
00364 GroundPredicate* pred = new GroundPredicate((*formula)[j]);
00365 TruthValue tv = state_->getDomain()->getDB()->getValue(pred);
00366 if ((tv == TRUE && !sense) || (tv == FALSE && sense))
00367 satisfied = false;
00368 delete pred;
00369 if (!satisfied) break;
00370 }
00371 if (satisfied) (*qfProbs_)[i]++;
00372 }
00373 }
00374
00375 if (msdebug) cout << "Leaving MC-SAT step" << endl;
00376 }
00377
00385 void updateClauses(const int& gndPredIdx)
00386 {
00387 if (msdebug) cout << "Entering updateClauses" << endl;
00388 GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00389 Array<int>& negGndClauses =
00390 state_->getNegOccurenceArray(gndPredIdx + 1);
00391 Array<int>& posGndClauses =
00392 state_->getPosOccurenceArray(gndPredIdx + 1);
00393 int gndClauseIdx;
00394 bool sense;
00395
00396 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00397 {
00398 if (i < negGndClauses.size())
00399 {
00400 gndClauseIdx = negGndClauses[i];
00401 sense = false;
00402 }
00403 else
00404 {
00405 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00406 sense = true;
00407 }
00408
00409 if (gndPred->getTruthValue() == sense)
00410 state_->incrementNumTrueLits(gndClauseIdx);
00411 else
00412 state_->decrementNumTrueLits(gndClauseIdx);
00413 }
00414 if (msdebug) cout << "Leaving updateClauses" << endl;
00415 }
00416
00417 private:
00418
00419
00420
00421
00422
00423
00424
00425 MaxWalkSat* mws_;
00426
00427
00428
00429
00430 double ssSecondsElapsed_;
00431 };
00432
00433 #endif