00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef SIMULATEDTEMPERING_H_
00067 #define SIMULATEDTEMPERING_H_
00068
00069 #include "mcmc.h"
00070 #include "simulatedtemperingparams.h"
00071 #include "maxwalksat.h"
00072 #include "convergencetest.h"
00073 #include "gelmanconvergencetest.h"
00074
00078 class SimulatedTempering : public MCMC
00079 {
00080 public:
00081
00085 SimulatedTempering(VariableState* state, long int seed,
00086 const bool& trackClauseTrueCnts,
00087 SimulatedTemperingParams* stParams)
00088 : MCMC(state, seed, trackClauseTrueCnts, stParams)
00089 {
00090
00091 subInterval_ = stParams->subInterval;
00092 numST_ = stParams->numST;
00093 numSwap_ = stParams->numSwap;
00094
00095 numChains_ = numSwap_*numST_;
00096
00097
00098
00099
00100
00101
00102 selInterval_ = subInterval_*(numSwap_ - 1);
00103
00104
00105 invTemps_ = new double*[numST_];
00106
00107 chainIds_ = new int*[numST_];
00108
00109 tempIds_ = new int*[numST_];
00110
00111 mws_ = new MaxWalkSat(state_, seed, false, stParams->mwsParams);
00112 }
00113
00117 ~SimulatedTempering()
00118 {
00119 for (int i = 0; i < numST_; i++)
00120 {
00121 delete [] invTemps_[i];
00122 delete [] chainIds_[i];
00123 delete [] tempIds_[i];
00124 }
00125 delete [] invTemps_;
00126 delete [] chainIds_;
00127 delete [] tempIds_;
00128 delete mws_;
00129 }
00130
00134 void init()
00135 {
00136
00137
00138 initTruthValuesAndWts(numChains_);
00139
00140
00141 cout << "Initializing Simulated Tempering with MaxWalksat" << endl;
00142 state_->eliminateSoftClauses();
00143
00144 int numSolutions = mws_->getNumSolutions();
00145 mws_->setNumSolutions(1);
00146 for (int c = 0; c < numChains_; c++)
00147 {
00148 cout << "for chain " << c << "..." << endl;
00149
00150 mws_->init();
00151 mws_->infer();
00152 saveLowStateToChain(c);
00153 }
00154 mws_->setNumSolutions(numSolutions);
00155 state_->resetDeadClauses();
00156
00157
00158 double maxWt = state_->getMaxClauseWeight();
00159 double maxWtForEvenSchedule = 100.0;
00160 double base = log(maxWt) / log((double)numSwap_);
00161 double* divs = new double[numSwap_];
00162 divs[0] = 1.0;
00163
00164 for (int i = 1; i < numSwap_; i++)
00165 {
00166 divs[i] = divs[i - 1] / base;
00167 }
00168
00169 for (int i = 0; i < numST_; i++)
00170 {
00171 invTemps_[i] = new double[numSwap_];
00172 chainIds_[i] = new int[numSwap_];
00173 tempIds_[i] = new int[numSwap_];
00174 for (int j = 0; j < numSwap_; j++)
00175 {
00176 chainIds_[i][j] = j;
00177 tempIds_[i][j] = j;
00178
00179 if (maxWt > maxWtForEvenSchedule)
00180 {
00181 invTemps_[i][j] = divs[j];
00182 }
00183 else
00184 {
00185 invTemps_[i][j] = 1.0-((double)j)/((double) numSwap_);
00186 }
00187 }
00188 }
00189 delete [] divs;
00190
00191
00192
00193 initNumTrueLits(numChains_);
00194 }
00195
00199 void infer()
00200 {
00201 initNumTrue();
00202 Timer timer;
00203
00204 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00205 double secondsElapsed = 0;
00206 double startTimeSec = timer.time();
00207 double currentTimeSec;
00208 int samplesPerOutput = 100;
00209
00210
00211 if (trackClauseTrueCnts_)
00212 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00213 (*clauseTrueCnts_)[clauseno] = 0;
00214
00215
00216 GroundPredicateHashArray affectedGndPreds;
00217 Array<int> affectedGndPredIndices;
00218
00219 int numAtoms = state_->getNumAtoms();
00220 for (int i = 0; i < numAtoms; i++)
00221 {
00222 affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00223 affectedGndPredIndices.append(i);
00224 }
00225 for (int c = 0; c < numChains_; c++)
00226 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00227 affectedGndPreds.clear();
00228 affectedGndPredIndices.clear();
00229
00230 cout << "Running Simulated Tempering sampling..." << endl;
00231
00232 int sample = 0;
00233 int numSamplesPerPred = 0;
00234 bool done = false;
00235 while (!done)
00236 {
00237 ++sample;
00238
00239 if (sample % samplesPerOutput == 0)
00240 {
00241 currentTimeSec = timer.time();
00242 secondsElapsed = currentTimeSec-startTimeSec;
00243 cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00244 Timer::printTime(cout, secondsElapsed); cout << endl;
00245 }
00246
00247
00248 if ((sample % selInterval_) % subInterval_ == 0)
00249 {
00250 int attemptTempId = (sample % selInterval_) / subInterval_;
00251 if (attemptTempId < numSwap_ - 1)
00252 {
00253 double wl, wh, itl, ith;
00254 for (int i = 0; i < numST_; i++)
00255 {
00256 int lChainId = chainIds_[i][attemptTempId];
00257 int hChainId = chainIds_[i][attemptTempId + 1];
00258
00259
00260 wl = getWeightSum(i*numSwap_ + lChainId);
00261 wh = getWeightSum(i*numSwap_ + hChainId);
00262 itl = invTemps_[i][attemptTempId];
00263 ith = invTemps_[i][attemptTempId + 1];
00264
00265 if (wl <= wh || random() <= RAND_MAX*exp((itl - ith)*(wh - wl)))
00266 {
00267 chainIds_[i][attemptTempId] = hChainId;
00268 chainIds_[i][attemptTempId+1] = lChainId;
00269 tempIds_[i][hChainId] = attemptTempId;
00270 tempIds_[i][lChainId] = attemptTempId + 1;
00271 }
00272 }
00273 }
00274 }
00275
00276
00277 for (int c = 0; c < numChains_; c++)
00278 {
00279
00280 for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00281 {
00282
00283 if (state_->getDomain()->getBlockEvidence(i)) continue;
00284
00285 double invTemp =
00286 invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00287
00288 int chosen = gibbsSampleFromBlock(c, i, invTemp);
00289
00290 const Predicate* pred =
00291 state_->getDomain()->getPredInBlock(chosen, i);
00292 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00293 int idx = state_->getIndexOfGroundPredicate(gndPred);
00294
00295 delete gndPred;
00296 delete pred;
00297
00298
00299 if (idx >= 0)
00300 {
00301 bool truthValue = truthValues_[idx][c];
00302
00303
00304 if (!truthValue)
00305 {
00306 int blockSize = state_->getDomain()->getBlockSize(i);
00307 for (int j = 0; j < blockSize; j++)
00308 {
00309 const Predicate* otherPred =
00310 state_->getDomain()->getPredInBlock(j, i);
00311 GroundPredicate* otherGndPred =
00312 new GroundPredicate((Predicate*)otherPred);
00313 int otherIdx = state_->getIndexOfGroundPredicate(gndPred);
00314
00315 delete otherGndPred;
00316 delete otherPred;
00317
00318
00319 if (otherIdx >= 0)
00320 {
00321 bool otherTruthValue = truthValues_[otherIdx][c];
00322 if (otherTruthValue)
00323 {
00324 truthValues_[otherIdx][c] = false;
00325
00326 affectedGndPreds.clear();
00327 affectedGndPredIndices.clear();
00328 gndPredFlippedUpdates(otherIdx, c, affectedGndPreds,
00329 affectedGndPredIndices);
00330 updateWtsForGndPreds(affectedGndPreds,
00331 affectedGndPredIndices, c);
00332 }
00333 }
00334 }
00335
00336 truthValues_[idx][c] = true;
00337 affectedGndPreds.clear();
00338 affectedGndPredIndices.clear();
00339 gndPredFlippedUpdates(idx, c, affectedGndPreds,
00340 affectedGndPredIndices);
00341 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00342 }
00343
00344
00345
00346 if (!burningIn && tempIds_[c/numSwap_][c%numSwap_] == 0)
00347 numTrue_[idx]++;
00348 }
00349 }
00350
00351
00352 for (int i = 0; i < state_->getNumAtoms(); i++)
00353 {
00354
00355 if (state_->getBlockIndex(i) >= 0) continue;
00356
00357 double invTemp =
00358 invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00359 double p = getProbabilityOfPred(i, c, invTemp);
00360
00361
00362 bool newAssignment = genTruthValueForProb(p);
00363
00364 if (newAssignment != truthValues_[i][c])
00365 {
00366
00367 truthValues_[i][c] = newAssignment;
00368 affectedGndPreds.clear();
00369 affectedGndPredIndices.clear();
00370 gndPredFlippedUpdates(i, c, affectedGndPreds,
00371 affectedGndPredIndices);
00372 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00373 }
00374
00375
00376
00377 if (!burningIn && newAssignment &&
00378 tempIds_[c/numSwap_][c%numSwap_] == 0)
00379
00380 numTrue_[i]++;
00381 }
00382 }
00383 if (!burningIn) numSamplesPerPred += numST_;
00384
00385
00386 if (!burningIn && trackClauseTrueCnts_)
00387 state_->getNumClauseGndings(clauseTrueCnts_, true);
00388
00389 if (burningIn)
00390 {
00391 if ( (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00392 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00393 {
00394 cout << "Done burning. " << sample << " samples per chain " << endl;
00395 burningIn = false;
00396 sample = 0;
00397 }
00398 }
00399 else
00400 {
00401 if ( (maxSteps_ >= 0 && sample >= maxSteps_)
00402 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00403 {
00404 cout << "Done simulated tempering sampling. " << sample
00405 << " samples per chain" << endl;
00406 done = true;
00407 }
00408 }
00409 cout.flush();
00410 }
00411
00412 cout<< "Time taken for Simulated Tempering sampling = ";
00413 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00414
00415
00416 for (int i = 0; i < state_->getNumAtoms(); i++)
00417 {
00418
00419
00420 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00421 }
00422
00423
00424 if (trackClauseTrueCnts_)
00425 {
00426
00427 for (int i = 0; i < clauseTrueCnts_->size(); i++)
00428 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00429 }
00430 }
00431
00432 private:
00433
00441 long double getWeightSum(const int& chainIdx)
00442 {
00443 long double w = 0;
00444 for (int i = 0; i < state_->getNumClauses(); i++)
00445 {
00446 long double wt = state_->getClauseCost(i);
00447 if ((wt > 0 && numTrueLits_[i][chainIdx] > 0) ||
00448 (wt < 0 && numTrueLits_[i][chainIdx] == 0))
00449 w += abs(wt);
00450 }
00451 return w;
00452 }
00453
00454 private:
00455
00456
00457
00458 int subInterval_;
00459
00460 int numST_;
00461
00462 int numSwap_;
00463
00464
00465 MaxWalkSat* mws_;
00466
00467
00468 int selInterval_;
00469
00470 double** invTemps_;
00471
00472 int** chainIds_;
00473
00474 int** tempIds_;
00475 };
00476
00477 #endif