00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef SIMULATEDTEMPERING_H_
00067 #define SIMULATEDTEMPERING_H_
00068
00069 #include "mcmc.h"
00070 #include "simulatedtemperingparams.h"
00071 #include "maxwalksat.h"
00072 #include "convergencetest.h"
00073 #include "gelmanconvergencetest.h"
00074
00078 class SimulatedTempering : public MCMC
00079 {
00080 public:
00081
00085 SimulatedTempering(VariableState* state, long int seed,
00086 const bool& trackClauseTrueCnts,
00087 SimulatedTemperingParams* stParams)
00088 : MCMC(state, seed, trackClauseTrueCnts, stParams)
00089 {
00090
00091 subInterval_ = stParams->subInterval;
00092 numST_ = stParams->numST;
00093 numSwap_ = stParams->numSwap;
00094
00095 numChains_ = numSwap_*numST_;
00096
00097
00098
00099
00100
00101
00102 selInterval_ = subInterval_*(numSwap_ - 1);
00103
00104
00105 invTemps_ = new double*[numST_];
00106
00107 chainIds_ = new int*[numST_];
00108
00109 tempIds_ = new int*[numST_];
00110
00111 mws_ = new MaxWalkSat(state_, seed, false, stParams->mwsParams);
00112 }
00113
00117 ~SimulatedTempering()
00118 {
00119 for (int i = 0; i < numST_; i++)
00120 {
00121 delete [] invTemps_[i];
00122 delete [] chainIds_[i];
00123 delete [] tempIds_[i];
00124 }
00125 delete [] invTemps_;
00126 delete [] chainIds_;
00127 delete [] tempIds_;
00128 delete mws_;
00129 }
00130
00134 void init()
00135 {
00136
00137
00138 initTruthValuesAndWts(numChains_);
00139 initNumTrue();
00140
00141
00142 cout << "Initializing Simulated Tempering with MaxWalksat" << endl;
00143 state_->eliminateSoftClauses();
00144
00145 int numSolutions = mws_->getNumSolutions();
00146 mws_->setNumSolutions(1);
00147 for (int c = 0; c < numChains_; c++)
00148 {
00149 cout << "for chain " << c << "..." << endl;
00150
00151 mws_->init();
00152 mws_->infer();
00153 saveLowStateToChain(c);
00154 }
00155 mws_->setNumSolutions(numSolutions);
00156 state_->resetDeadClauses();
00157
00158
00159 double maxWt = state_->getMaxClauseWeight();
00160 double maxWtForEvenSchedule = 100.0;
00161 double base = log(maxWt) / log(numSwap_);
00162 double* divs = new double[numSwap_];
00163 divs[0] = 1.0;
00164
00165 for (int i = 1; i < numSwap_; i++)
00166 {
00167 divs[i] = divs[i - 1] / base;
00168 }
00169
00170 for (int i = 0; i < numST_; i++)
00171 {
00172 invTemps_[i] = new double[numSwap_];
00173 chainIds_[i] = new int[numSwap_];
00174 tempIds_[i] = new int[numSwap_];
00175 for (int j = 0; j < numSwap_; j++)
00176 {
00177 chainIds_[i][j] = j;
00178 tempIds_[i][j] = j;
00179
00180 if (maxWt > maxWtForEvenSchedule)
00181 {
00182 invTemps_[i][j] = divs[j];
00183 }
00184 else
00185 {
00186 invTemps_[i][j] = 1.0-((double)j)/((double) numSwap_);
00187 }
00188 }
00189 }
00190 delete [] divs;
00191
00192
00193
00194 initNumTrueLits(numChains_);
00195 }
00196
00200 void infer()
00201 {
00202 Timer timer;
00203
00204 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00205 double secondsElapsed = 0;
00206 double startTimeSec = timer.time();
00207 double currentTimeSec;
00208 int samplesPerOutput = 100;
00209
00210
00211 if (trackClauseTrueCnts_)
00212 for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00213 (*clauseTrueCnts_)[clauseno] = 0;
00214
00215
00216 GroundPredicateHashArray affectedGndPreds;
00217 Array<int> affectedGndPredIndices;
00218
00219 int numAtoms = state_->getNumAtoms();
00220 for (int i = 0; i < numAtoms; i++)
00221 {
00222 affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00223 affectedGndPredIndices.append(i);
00224 }
00225 for (int c = 0; c < numChains_; c++)
00226 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00227 affectedGndPreds.clear();
00228 affectedGndPredIndices.clear();
00229
00230 cout << "Running Simulated Tempering sampling..." << endl;
00231
00232 int sample = 0;
00233 int numSamplesPerPred = 0;
00234 bool done = false;
00235 while (!done)
00236 {
00237 ++sample;
00238
00239 if (sample % samplesPerOutput == 0)
00240 {
00241 currentTimeSec = timer.time();
00242 secondsElapsed = currentTimeSec-startTimeSec;
00243 cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00244 Timer::printTime(cout, secondsElapsed); cout << endl;
00245 }
00246
00247
00248 if ((sample % selInterval_) % subInterval_ == 0)
00249 {
00250 int attemptTempId = (sample % selInterval_) / subInterval_;
00251 if (attemptTempId < numSwap_ - 1)
00252 {
00253 double wl, wh, itl, ith;
00254 for (int i = 0; i < numST_; i++)
00255 {
00256 int lChainId = chainIds_[i][attemptTempId];
00257 int hChainId = chainIds_[i][attemptTempId + 1];
00258
00259
00260 wl = getWeightSum(i*numSwap_ + lChainId);
00261 wh = getWeightSum(i*numSwap_ + hChainId);
00262 itl = invTemps_[i][attemptTempId];
00263 ith = invTemps_[i][attemptTempId + 1];
00264
00265 if (wl <= wh || random() <= RAND_MAX*exp((itl - ith)*(wh - wl)))
00266 {
00267 chainIds_[i][attemptTempId] = hChainId;
00268 chainIds_[i][attemptTempId+1] = lChainId;
00269 tempIds_[i][hChainId] = attemptTempId;
00270 tempIds_[i][lChainId] = attemptTempId + 1;
00271 }
00272 }
00273 }
00274 }
00275
00276
00277 for (int c = 0; c < numChains_; c++)
00278 {
00279
00280 for (int i = 0; i < state_->getNumBlocks(); i++)
00281 {
00282
00283 if (state_->getBlockEvidence(i)) continue;
00284
00285 Array<int>& block = state_->getBlockArray(i);
00286 double invTemp =
00287 invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00288
00289 int chosen = gibbsSampleFromBlock(c, block, invTemp);
00290 bool truthValue = truthValues_[block[chosen]][c];
00291
00292
00293 if (!truthValue)
00294 {
00295 for (int j = 0; j < block.size(); j++)
00296 {
00297 bool otherTruthValue = truthValues_[block[j]][c];
00298 if (otherTruthValue)
00299 {
00300 truthValues_[block[j]][c] = false;
00301
00302 affectedGndPreds.clear();
00303 affectedGndPredIndices.clear();
00304 gndPredFlippedUpdates(block[j], c, affectedGndPreds,
00305 affectedGndPredIndices);
00306 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00307 c);
00308 }
00309 }
00310
00311 truthValues_[block[chosen]][c] = true;
00312 affectedGndPreds.clear();
00313 affectedGndPredIndices.clear();
00314 gndPredFlippedUpdates(block[chosen], c, affectedGndPreds,
00315 affectedGndPredIndices);
00316 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00317 }
00318
00319
00320
00321 if (!burningIn && tempIds_[c/numSwap_][c%numSwap_] == 0)
00322 numTrue_[block[chosen]]++;
00323 }
00324
00325
00326 for (int i = 0; i < state_->getNumAtoms(); i++)
00327 {
00328
00329 if (state_->getBlockIndex(i) >= 0) continue;
00330
00331 double invTemp =
00332 invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00333 double p = getProbabilityOfPred(i, c, invTemp);
00334
00335
00336
00337
00338
00339
00340
00341 bool newAssignment = genTruthValueForProb(p);
00342
00343 if (newAssignment != truthValues_[i][c])
00344 {
00345
00346 truthValues_[i][c] = newAssignment;
00347 affectedGndPreds.clear();
00348 affectedGndPredIndices.clear();
00349 gndPredFlippedUpdates(i, c, affectedGndPreds,
00350 affectedGndPredIndices);
00351 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00352 }
00353
00354
00355
00356 if (!burningIn && newAssignment &&
00357 tempIds_[c/numSwap_][c%numSwap_] == 0)
00358
00359 numTrue_[i]++;
00360 }
00361 }
00362 if (!burningIn) numSamplesPerPred += numST_;
00363
00364
00365 if (!burningIn && trackClauseTrueCnts_)
00366 state_->getNumClauseGndings(clauseTrueCnts_, true);
00367
00368 if (burningIn)
00369 {
00370 if ( (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00371 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00372 {
00373 cout << "Done burning. " << sample << " samples per chain " << endl;
00374 burningIn = false;
00375 sample = 0;
00376 }
00377 }
00378 else
00379 {
00380 if ( (maxSteps_ >= 0 && sample >= maxSteps_)
00381 || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00382 {
00383 cout << "Done simulated tempering sampling. " << sample
00384 << " samples per chain" << endl;
00385 done = true;
00386 }
00387 }
00388 cout.flush();
00389 }
00390
00391 cout<< "Time taken for Simulated Tempering sampling = ";
00392 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00393
00394
00395 for (int i = 0; i < state_->getNumAtoms(); i++)
00396 {
00397
00398
00399 setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00400 }
00401
00402
00403 if (trackClauseTrueCnts_)
00404 {
00405
00406 for (int i = 0; i < clauseTrueCnts_->size(); i++)
00407 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00408 }
00409 }
00410
00411 private:
00412
00420 long double getWeightSum(const int& chainIdx)
00421 {
00422 long double w = 0;
00423 for (int i = 0; i < state_->getNumClauses(); i++)
00424 {
00425 long double wt = state_->getClauseCost(i);
00426 if ((wt > 0 && numTrueLits_[i][chainIdx] > 0) ||
00427 (wt < 0 && numTrueLits_[i][chainIdx] == 0))
00428 w += abs(wt);
00429 }
00430 return w;
00431 }
00432
00433 private:
00434
00435
00436
00437 int subInterval_;
00438
00439 int numST_;
00440
00441 int numSwap_;
00442
00443
00444 MaxWalkSat* mws_;
00445
00446
00447 int selInterval_;
00448
00449 double** invTemps_;
00450
00451 int** chainIds_;
00452
00453 int** tempIds_;
00454 };
00455
00456 #endif