simulatedtempering.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef SIMULATEDTEMPERING_H_
00067 #define SIMULATEDTEMPERING_H_
00068 
00069 #include "mcmc.h"
00070 #include "simulatedtemperingparams.h"
00071 #include "maxwalksat.h"
00072 #include "convergencetest.h"
00073 #include "gelmanconvergencetest.h"
00074 
00078 class SimulatedTempering : public MCMC
00079 {
00080  public:
00081 
00085   SimulatedTempering(VariableState* state, long int seed,
00086                      const bool& trackClauseTrueCnts, 
00087                      SimulatedTemperingParams* stParams)
00088     : MCMC(state, seed, trackClauseTrueCnts, stParams)
00089   {
00090       // User-set parameters
00091     subInterval_ = stParams->subInterval;
00092     numST_ = stParams->numST;
00093     numSwap_ = stParams->numSwap;
00094       // Number of chains is determined here
00095     numChains_ = numSwap_*numST_;        
00096     // ------------------------------------------ //
00097     // Chained method
00098     //  10 chains: i and i+1 swap attempt at
00099     //      selInterval*k + selInterval/10*i
00100     // ------------------------------------------ //
00101       // 9 possible swaps out of 10 chains
00102     selInterval_ = subInterval_*(numSwap_ - 1);
00103 
00104       // invTemp for chain chainIds_[i]
00105     invTemps_ = new double*[numST_];
00106       // curr chainId for ith temperature
00107     chainIds_ = new int*[numST_];
00108       // curr tempId for ith chain
00109     tempIds_ = new int*[numST_];
00110       // We don't need to track clause true counts in mws
00111     mws_ = new MaxWalkSat(state_, seed, false, stParams->mwsParams);
00112   }
00113 
00117   ~SimulatedTempering()
00118   {
00119     for (int i = 0; i < numST_; i++)
00120     {
00121       delete [] invTemps_[i];
00122       delete [] chainIds_[i];
00123       delete [] tempIds_[i];
00124     }
00125     delete [] invTemps_;
00126     delete [] chainIds_;
00127     delete [] tempIds_;
00128     delete mws_;
00129   }
00130   
00134   void init()
00135   {
00136       // Initialize gndPreds' truthValues & wts
00137     //state_->initTruthValuesAndWts(numChains_, start);
00138     initTruthValuesAndWts(numChains_);
00139     initNumTrue();
00140 
00141       // Initialize with MWS
00142     cout << "Initializing Simulated Tempering with MaxWalksat" << endl;
00143     state_->eliminateSoftClauses();
00144       // Set num. of solutions temporarily to 1
00145     int numSolutions = mws_->getNumSolutions();
00146     mws_->setNumSolutions(1);
00147     for (int c = 0; c < numChains_; c++)
00148     {
00149       cout << "for chain " << c << "..." << endl;
00150         // Initialize with MWS
00151       mws_->init();
00152       mws_->infer();
00153       saveLowStateToChain(c);
00154     }
00155     mws_->setNumSolutions(numSolutions);
00156     state_->resetDeadClauses();
00157 
00158     // *** Initialize temperature schedule ***
00159     double maxWt = state_->getMaxClauseWeight();
00160     double maxWtForEvenSchedule = 100.0;
00161     double base = log(maxWt) / log(numSwap_);
00162     double* divs = new double[numSwap_];
00163     divs[0] = 1.0;
00164 
00165     for (int i = 1; i < numSwap_; i++)
00166     {
00167       divs[i] = divs[i - 1] / base;
00168     }
00169 
00170     for (int i = 0; i < numST_; i++)
00171     {
00172       invTemps_[i] = new double[numSwap_];
00173       chainIds_[i] = new int[numSwap_];
00174       tempIds_[i]  = new int[numSwap_];
00175       for (int j = 0; j < numSwap_; j++)
00176       {         
00177         chainIds_[i][j] = j;
00178         tempIds_[i][j] = j;
00179           // log vs even
00180         if (maxWt > maxWtForEvenSchedule)
00181         {
00182           invTemps_[i][j] = divs[j];
00183         }
00184         else
00185         {
00186           invTemps_[i][j] = 1.0-((double)j)/((double) numSwap_);
00187         }
00188       }
00189     }
00190     delete [] divs;
00191       
00192       // Initialize gndClauses' number of satisfied literals
00193     //int start = 0;
00194     initNumTrueLits(numChains_);
00195   }
00196 
00200   void infer()
00201   {
00202     Timer timer;
00203       // Burn-in only if burnMaxSteps positive
00204     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00205     double secondsElapsed = 0;
00206     double startTimeSec = timer.time();
00207     double currentTimeSec;
00208     int samplesPerOutput = 100;
00209 
00210       // If keeping track of true clause groundings, then init to zero
00211     if (trackClauseTrueCnts_)
00212       for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00213         (*clauseTrueCnts_)[clauseno] = 0;
00214 
00215       // Holds the ground preds which have currently been affected
00216     GroundPredicateHashArray affectedGndPreds;
00217     Array<int> affectedGndPredIndices;
00218 
00219     int numAtoms = state_->getNumAtoms();
00220     for (int i = 0; i < numAtoms; i++)
00221     {
00222       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00223       affectedGndPredIndices.append(i);
00224     }
00225     for (int c = 0; c < numChains_; c++)
00226       updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00227     affectedGndPreds.clear();
00228     affectedGndPredIndices.clear();
00229 
00230     cout << "Running Simulated Tempering sampling..." << endl;
00231       // Sampling loop
00232     int sample = 0;
00233     int numSamplesPerPred = 0;
00234     bool done = false;
00235     while (!done)
00236     {
00237       ++sample;
00238 
00239       if (sample % samplesPerOutput == 0)
00240       { 
00241         currentTimeSec = timer.time();
00242         secondsElapsed = currentTimeSec-startTimeSec;
00243         cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00244         Timer::printTime(cout, secondsElapsed); cout << endl;
00245       }
00246 
00247         // Attempt to swap temperature
00248       if ((sample % selInterval_) % subInterval_ == 0)
00249       {
00250         int attemptTempId = (sample % selInterval_) / subInterval_;
00251         if (attemptTempId < numSwap_ - 1)
00252         {
00253           double wl, wh, itl, ith;
00254           for (int i = 0; i < numST_; i++)
00255           {
00256             int lChainId = chainIds_[i][attemptTempId];
00257             int hChainId = chainIds_[i][attemptTempId + 1];
00258               // compute w_low, w_high: e = -w
00259               // swap acceptance ratio=e^(0.1*(w_h-w_l))
00260             wl = getWeightSum(i*numSwap_ + lChainId);
00261             wh = getWeightSum(i*numSwap_ + hChainId);
00262             itl = invTemps_[i][attemptTempId];
00263             ith = invTemps_[i][attemptTempId + 1];
00264 
00265             if (wl <= wh || random() <= RAND_MAX*exp((itl - ith)*(wh - wl)))
00266             {
00267               chainIds_[i][attemptTempId] = hChainId;
00268               chainIds_[i][attemptTempId+1] = lChainId;
00269               tempIds_[i][hChainId] = attemptTempId;
00270               tempIds_[i][lChainId] = attemptTempId + 1;
00271             }
00272           }
00273         }
00274       }
00275 
00276         // Generate new truth value based on temperature
00277       for (int c = 0; c < numChains_; c++) 
00278       {
00279           // For each block: select one to set to true
00280         for (int i = 0; i < state_->getNumBlocks(); i++)
00281         {
00282             // If evidence atom exists, then all others stay false
00283           if (state_->getBlockEvidence(i)) continue;
00284  
00285           Array<int>& block = state_->getBlockArray(i);
00286           double invTemp =
00287             invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00288             // chosen is index in the block, block[chosen] is index in gndPreds_
00289           int chosen = gibbsSampleFromBlock(c, block, invTemp);
00290           bool truthValue = truthValues_[block[chosen]][c];
00291             // If chosen pred was false, then need to set previous true
00292             // one to false and update wts
00293           if (!truthValue)
00294           {
00295             for (int j = 0; j < block.size(); j++)
00296             {
00297               bool otherTruthValue = truthValues_[block[j]][c];
00298               if (otherTruthValue)
00299               {
00300                 truthValues_[block[j]][c] = false;
00301               
00302                 affectedGndPreds.clear();
00303                 affectedGndPredIndices.clear();
00304                 gndPredFlippedUpdates(block[j], c, affectedGndPreds,
00305                                       affectedGndPredIndices);
00306                 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00307                                      c);
00308               }
00309             }
00310               // Set truth value and update wts for chosen atom
00311             truthValues_[block[chosen]][c] = true;
00312             affectedGndPreds.clear();
00313             affectedGndPredIndices.clear();
00314             gndPredFlippedUpdates(block[chosen], c, affectedGndPreds,
00315                                   affectedGndPredIndices);
00316             updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00317           }
00318 
00319             // If in actual sampling phase, track the num of times
00320             // the ground predicate is set to true
00321           if (!burningIn && tempIds_[c/numSwap_][c%numSwap_] == 0)
00322             numTrue_[block[chosen]]++;
00323         }
00324 
00325           // Now go through all preds not in blocks
00326         for (int i = 0; i < state_->getNumAtoms(); i++) 
00327         {
00328             // Predicates in blocks have been handled above
00329           if (state_->getBlockIndex(i) >= 0) continue;
00330             // Calculate prob
00331           double invTemp =
00332             invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00333           double p = getProbabilityOfPred(i, c, invTemp);
00334           //double p = 1.0 /
00335           //           (1.0 + exp((wtsWhenFalse_[i][c] -
00336           //                       wtsWhenTrue_[i][c]) *
00337           //                       invTemps_[c/numSwap_]
00338           //                                [tempIds_[c/numSwap_][c%numSwap_]]));
00339 
00340             // Flip updates
00341           bool newAssignment = genTruthValueForProb(p);
00342           //if (newAssignment != pred->getTruthValue(c))
00343           if (newAssignment != truthValues_[i][c])
00344           {
00345             //pred->setTruthValue(c, newAssignment);
00346             truthValues_[i][c] = newAssignment;
00347             affectedGndPreds.clear();
00348             affectedGndPredIndices.clear();
00349             gndPredFlippedUpdates(i, c, affectedGndPreds,
00350                                   affectedGndPredIndices);
00351             updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00352           }
00353 
00354             // if in actual sim. tempering phase, track the num of times
00355             // the ground predicate is set to true
00356           if (!burningIn && newAssignment &&
00357               tempIds_[c/numSwap_][c%numSwap_] == 0)
00358             //pred->incrementNumTrue();
00359             numTrue_[i]++;
00360         }
00361       }
00362       if (!burningIn) numSamplesPerPred += numST_;
00363 
00364         // If keeping track of true clause groundings
00365       if (!burningIn && trackClauseTrueCnts_)
00366         state_->getNumClauseGndings(clauseTrueCnts_, true);
00367 
00368       if (burningIn) 
00369       {
00370         if (   (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00371             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00372         {
00373           cout << "Done burning. " << sample << " samples per chain " << endl;
00374           burningIn = false;
00375           sample = 0;
00376         }
00377       }
00378       else 
00379       {
00380         if (   (maxSteps_ >= 0 && sample >= maxSteps_)
00381             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00382         {
00383           cout << "Done simulated tempering sampling. " << sample
00384                << " samples per chain" << endl;
00385           done = true;
00386         }
00387       }
00388       cout.flush();
00389     } // while (!done)
00390     
00391     cout<< "Time taken for Simulated Tempering sampling = "; 
00392     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00393 
00394       // update gndPreds probability that it is true
00395     for (int i = 0; i < state_->getNumAtoms(); i++)
00396     {
00397       //GroundPredicate* gndPred = state_->getGndPred(i);
00398       //gndPred->setProbTrue(gndPred->getNumTrue() / numSamplesPerPred);
00399       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00400     }
00401     
00402       // If keeping track of true clause groundings
00403     if (trackClauseTrueCnts_)
00404     {
00405         // Set the true counts to the average over all samples
00406       for (int i = 0; i < clauseTrueCnts_->size(); i++)
00407         (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00408     }
00409   }
00410   
00411  private:
00412  
00420   long double getWeightSum(const int& chainIdx)
00421   {
00422     long double w = 0;
00423     for (int i = 0; i < state_->getNumClauses(); i++)
00424     {
00425       long double wt = state_->getClauseCost(i);
00426       if ((wt > 0 && numTrueLits_[i][chainIdx] > 0) ||
00427           (wt < 0 && numTrueLits_[i][chainIdx] == 0))
00428         w += abs(wt);
00429     }
00430     return w;
00431   }
00432  
00433  private:
00434  
00435     // User-set parameters:
00436     // Selection interval between swap attempts
00437   int subInterval_;
00438     // Number of simulated tempering runs
00439   int numST_;
00440     // Number of swapping chains
00441   int numSwap_;
00442 
00443     // MaxWalksat is used for initialization
00444   MaxWalkSat* mws_;  
00445 
00446     // 9 possible swaps out of 10 chains
00447   int selInterval_;
00448     // invTemp for chain chainIds_[i]
00449   double** invTemps_;
00450     // curr chainId for ith temperature
00451   int** chainIds_;
00452     // curr tempId for ith chain
00453   int** tempIds_; 
00454 };
00455 
00456 #endif /*SIMULATEDTEMPERING_H_*/

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1