00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef HMCSAT_H_
00067 #define HMCSAT_H_
00068
00069 #include "mcmc.h"
00070 #include "mcsatparams.h"
00071 #include "hmaxwalksat.h"
00072 #include "hvariablestate.h"
00073
/** Compile-time switch for the general HMCSAT debug traces (off by default). */
const bool hmsdebug = false;
/**
 * Compile-time switch for the extra per-sample cost diagnostics printed by
 * HMCSAT::performMCSatStep() (off by default).
 */
const bool wjhmsdebug = false;
00076
00083 class HMCSAT : public MCMC
00084 {
00085 public:
00089 HMCSAT(HVariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090 MCSatParams* mcsatParams)
00091 : MCMC(state, seed, trackClauseTrueCnts, mcsatParams)
00092 {
00093
00094 mws_ = new HMaxWalkSat(hstate_, seed, false, mcsatParams->mwsParams);
00095 bMaxWalkSat_ = false;
00096 contSamples_.growToSize(hstate_->contAtomNum_);
00097 print_vars_per_sample_ = false;
00098 }
00099
00103 ~HMCSAT()
00104 {
00105 delete mws_;
00106 }
00107
00108 void initNumTrueTotal()
00109 {
00110 int numDisPreds = hstate_->getNumAtoms();
00111 numTrue_.growToSize(numDisPreds, 0);
00112 }
00113
00117 void init()
00118 {
00119 assert(numChains_ == 1);
00120 initNumTrueTotal();
00121 hstate_->eliminateSoftClauses();
00122
00123 int numSolutions = mws_->getNumSolutions();
00124 mws_->setNumSolutions(1);
00125
00126 if (!bMaxWalkSat_)
00127 {
00128 hstate_->makeUnitCosts();
00129 }
00130
00131
00132
00133 mws_->setHeuristic(TABU);
00134 mws_->init();
00135 mws_->infer();
00136
00137 if (hmsdebug)
00138 {
00139 cout << "Low state:" << endl;
00140 hstate_->printLowState(cout);
00141 }
00142
00143 hstate_->saveLowStateToGndPreds();
00144 hstate_->saveCurrentAsLastAssignment();
00145 hstate_->UpdateHybridConstraintTh();
00146
00147 mws_->setHeuristic(SS);
00148 mws_->setNumSolutions(numSolutions);
00149 mws_->setTargetCost(0.0);
00150 hstate_->resetDeadClauses();
00151 sample_ = 0;
00152 }
00153
00154 void printContSamples()
00155 {
00156 for(int i = 0; i < contSamples_.size(); i ++)
00157 {
00158 for (int j = 0; j < contSamples_[i].size(); j++)
00159 {
00160 contSampleLog_ << contSamples_[i][j] << " ";
00161 }
00162 contSampleLog_ << endl;
00163 }
00164 }
00165
00169 void infer()
00170 {
00171 Timer timer;
00172
00173 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00174 double secondsElapsed = 0;
00175 upSecondsElapsed_ = 0;
00176 ssSecondsElapsed_ = 0;
00177 double startTimeSec = timer.time();
00178 double currentTimeSec;
00179 int samplesPerOutput = 100;
00180
00181
00182 if (trackClauseTrueCnts_)
00183 {
00184 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00185 (*clauseTrueCnts_)[clauseno] = 0;
00186 for (int clauseno = 0; clauseno < clauseTrueCntsCont_->size(); clauseno++)
00187 (*clauseTrueCntsCont_)[clauseno] = 0;
00188 }
00189
00190
00191 GroundPredicateHashArray affectedGndPreds;
00192 Array<int> affectedGndPredIndices;
00193
00194
00195 int numAtoms = hstate_->getNumAtoms();
00196
00197 for (int i = 0; i < numAtoms; i++)
00198 {
00199 affectedGndPreds.append(hstate_->getGndPred(i), numAtoms);
00200 affectedGndPredIndices.append(i);
00201 }
00202
00203 updateWtsForGndPredsH(affectedGndPreds, affectedGndPredIndices, 0);
00204 affectedGndPreds.clear();
00205 affectedGndPredIndices.clear();
00206 cout << "Running MC-SAT sampling..." << endl;
00207
00208 int numSamplesPerPred = 0;
00209 bool done = false;
00210 while (!done)
00211 {
00212 ++sample_;
00213 if (sample_ % samplesPerOutput == 0)
00214 {
00215 currentTimeSec = timer.time();
00216 secondsElapsed = currentTimeSec - startTimeSec;
00217 cout << "sample_ (per pred) " << sample_ << ", time elapsed = ";
00218 Timer::printTime(cout, secondsElapsed); cout << endl;
00219 }
00220
00221 performMCSatStep(burningIn);
00222
00223 if (!burningIn) numSamplesPerPred++;
00224
00225 if (burningIn)
00226 {
00227 if ((burnMaxSteps_ >= 0 && sample_ >= burnMaxSteps_) || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00228 {
00229 cout << "Done burning. " << sample_ << " samples." << endl;
00230 burningIn = false;
00231 sample_ = 0;
00232 }
00233 }
00234 else
00235 {
00236 if ( (maxSteps_ >= 0 && sample_ >= maxSteps_)
00237 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00238 {
00239 cout << "Done MC-SAT sampling. " << sample_ << " samples."
00240 << endl;
00241 done = true;
00242 }
00243 }
00244 cout.flush();
00245 }
00246
00247 cout<< "Time taken for MC-SAT sampling = ";
00248 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00249
00250 cout<< "Time taken for unit propagation = ";
00251 Timer::printTime(cout, upSecondsElapsed_); cout << endl;
00252
00253 cout<< "Time taken for SampleSat = ";
00254 Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00255
00256
00257 for (int i = 0; i < hstate_->getNumAtoms(); i++)
00258 {
00259 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00260 }
00261
00262
00263 if (trackClauseTrueCnts_)
00264 {
00265
00266 for (int i = 0; i < clauseTrueCnts_->size(); i++)
00267 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00268 for (int i = 0; i < clauseTrueCntsCont_->size(); i++)
00269 (*clauseTrueCntsCont_)[i] = (*clauseTrueCntsCont_)[i] / numSamplesPerPred;
00270 }
00271 }
00272
00273
00274 void SetContSampleFile(const char* contsamplelog)
00275 {
00276 contSampleLog_.open(contsamplelog);
00277 if (!contSampleLog_.is_open())
00278 {
00279 cout << "cont sample log file is not specified." << endl;
00280 }
00281 }
00282
00283
00284
00285 void SetPrintVarsPerSample(bool pps) {
00286 print_vars_per_sample_ = pps;
00287 }
00288
00289
00290 bool bMaxWalkSat_;
00291
00292 private:
00298 void performMCSatStep(const bool& burningIn)
00299 {
00300 if (hmsdebug)
00301 {
00302 cout << "Num of clauses " << hstate_->getNumClauses() << endl;
00303 cout << "Num of dead clauses " << hstate_->getNumDeadClauses() << endl;
00304 }
00305 Timer timer;
00306 double startTime;
00307 if (hmsdebug) cout << "Entering MC-SAT step" << endl;
00308
00309 hstate_->setUseThreshold(true);
00310 int start = 0;
00311
00312
00313
00314 hstate_->killClauses(start);
00315
00316
00317 startTime = timer.time();
00318 mws_->init();
00319 mws_->infer();
00320 ssSecondsElapsed_ += (timer.time() - startTime);
00321
00322 hstate_->saveLowStateToGndPreds();
00323 if (print_vars_per_sample_) {
00324 cout << "HMCS Iter#" << sample_ << ":\t";
00325 for (int i = 0; i < hstate_->getNumAtoms(); ++i) {
00326 cout << hstate_->atom_[i + 1] << "\t";
00327 }
00328
00329 for (int i = 0; i < hstate_->getNumContAtoms(); ++i) {
00330 cout << hstate_->contAtoms_[i+1] << "\t";
00331 }
00332 cout << endl;
00333 }
00334 if (wjhmsdebug && hstate_->costOfTotalFalseConstraints_ < mws_->getTargetCost() + SMALLVALUE)
00335 {
00336 hstate_->printLowStateAll(cout);
00337 }
00338 if (wjhmsdebug && hstate_->costOfTotalFalseConstraints_ > mws_->getTargetCost() + SMALLVALUE)
00339 {
00340 cout << "not all satisfied, at sample " << sample_ << ". " << endl;
00341 cout << "Error at sample: " << sample_ << ", " << hstate_->costOfTotalFalseConstraints_ << endl;
00342 cout << "False cont constraints: " << hstate_->costHybridFalseConstraint_ << ", false discrete: " << hstate_->getCostOfFalseClauses() << endl;
00343
00344 hstate_->printFalseClauses(cout);
00345 }
00346
00347 int numAtoms = hstate_->getNumAtoms();
00348 for (int i = 0; i < numAtoms; i++)
00349 {
00350 GroundPredicate* gndPred = hstate_->getGndPred(i);
00351 bool newAssignment = hstate_->getValueOfLowAtom(i + 1);
00352
00353 if (newAssignment != gndPred->getTruthValue()) {
00354 gndPred->setTruthValue(newAssignment);
00355 updateClauses(i);
00356 }
00357 }
00358
00359 for(int i = 0; i < hstate_->contAtomNum_; i++)
00360 {
00361 contSamples_[i].append(hstate_->contAtoms_[i+1]);
00362 contSampleLog_ << hstate_->contAtoms_[i+1] << " ";
00363 }
00364 contSampleLog_ << endl;
00365
00366 for( int i = 0; i < hstate_->getNumAtoms(); i++)
00367 {
00368 bool newAssignment = hstate_->getValueOfAtom(i+1);
00369 if (!burningIn && newAssignment) numTrue_[i]++;
00370 }
00371 hstate_->resetFixedAtoms();
00372 hstate_->resetDeadClauses();
00373 hstate_->saveCurrentAsLastAssignment();
00374
00375 hstate_->UpdateHybridConstraintTh();
00376 hstate_->setUseThreshold(false);
00377
00378 if (!burningIn && trackClauseTrueCnts_)
00379 {
00380 hstate_->getNumClauseGndings(clauseTrueCnts_, true);
00381 hstate_->getContClauseGndings(clauseTrueCntsCont_);
00382 }
00383 if (hmsdebug) cout << "Leaving MC-SAT step" << endl;
00384 }
00385
00393 void updateClauses(const int& gndPredIdx)
00394 {
00395 if (hmsdebug) cout << "Entering updateClauses" << endl;
00396 GroundPredicate* gndPred = hstate_->getGndPred(gndPredIdx);
00397 Array<int>& negGndClauses =
00398 hstate_->getNegOccurenceArray(gndPredIdx + 1);
00399 Array<int>& posGndClauses =
00400 hstate_->getPosOccurenceArray(gndPredIdx + 1);
00401 int gndClauseIdx;
00402 bool sense;
00403
00404 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00405 {
00406 if (i < negGndClauses.size())
00407 {
00408 gndClauseIdx = negGndClauses[i];
00409 sense = false;
00410 }
00411 else
00412 {
00413 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00414 sense = true;
00415 }
00416
00417 if (gndPred->getTruthValue() == sense)
00418 hstate_->incrementNumTrueLits(gndClauseIdx);
00419 else
00420 hstate_->decrementNumTrueLits(gndClauseIdx);
00421 }
00422 if (hmsdebug) cout << "Leaving updateClauses" << endl;
00423 }
00424
00425 private:
00426 ofstream contSampleLog_;
00427 int sample_;
00428 bool print_vars_per_sample_;
00429
00430 Array<Array<double> > contSamples_;
00431
00432
00433
00434
00435 HMaxWalkSat* mws_;
00436
00437
00438 double upSecondsElapsed_;
00439
00440 double ssSecondsElapsed_;
00441 };
00442
00443 #endif