simulatedtempering.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef SIMULATEDTEMPERING_H_
00067 #define SIMULATEDTEMPERING_H_
00068 
00069 #include "mcmc.h"
00070 #include "simulatedtemperingparams.h"
00071 #include "maxwalksat.h"
00072 #include "convergencetest.h"
00073 #include "gelmanconvergencetest.h"
00074 
00078 class SimulatedTempering : public MCMC
00079 {
00080  public:
00081 
00085   SimulatedTempering(VariableState* state, long int seed,
00086                      const bool& trackClauseTrueCnts, 
00087                      SimulatedTemperingParams* stParams)
00088     : MCMC(state, seed, trackClauseTrueCnts, stParams)
00089   {
00090       // User-set parameters
00091     subInterval_ = stParams->subInterval;
00092     numST_ = stParams->numST;
00093     numSwap_ = stParams->numSwap;
00094       // Number of chains is determined here
00095     numChains_ = numSwap_*numST_;        
00096     // ------------------------------------------ //
00097     // Chained method
00098     //  10 chains: i and i+1 swap attempt at
00099     //      selInterval*k + selInterval/10*i
00100     // ------------------------------------------ //
00101       // 9 possible swaps out of 10 chains
00102     selInterval_ = subInterval_*(numSwap_ - 1);
00103 
00104       // invTemp for chain chainIds_[i]
00105     invTemps_ = new double*[numST_];
00106       // curr chainId for ith temperature
00107     chainIds_ = new int*[numST_];
00108       // curr tempId for ith chain
00109     tempIds_ = new int*[numST_];
00110       // We don't need to track clause true counts in mws
00111     mws_ = new MaxWalkSat(state_, seed, false, stParams->mwsParams);
00112   }
00113 
00117   ~SimulatedTempering()
00118   {
00119     for (int i = 0; i < numST_; i++)
00120     {
00121       delete [] invTemps_[i];
00122       delete [] chainIds_[i];
00123       delete [] tempIds_[i];
00124     }
00125     delete [] invTemps_;
00126     delete [] chainIds_;
00127     delete [] tempIds_;
00128     delete mws_;
00129   }
00130   
00134   void init()
00135   {
00136       // Initialize gndPreds' truthValues & wts
00137     //state_->initTruthValuesAndWts(numChains_, start);
00138     initTruthValuesAndWts(numChains_);
00139 
00140       // Initialize with MWS
00141     cout << "Initializing Simulated Tempering with MaxWalksat" << endl;
00142     state_->eliminateSoftClauses();
00143       // Set num. of solutions temporarily to 1
00144     int numSolutions = mws_->getNumSolutions();
00145     mws_->setNumSolutions(1);
00146     for (int c = 0; c < numChains_; c++)
00147     {
00148       cout << "for chain " << c << "..." << endl;
00149         // Initialize with MWS
00150       mws_->init();
00151       mws_->infer();
00152       saveLowStateToChain(c);
00153     }
00154     mws_->setNumSolutions(numSolutions);
00155     state_->resetDeadClauses();
00156 
00157     // *** Initialize temperature schedule ***
00158     double maxWt = state_->getMaxClauseWeight();
00159     double maxWtForEvenSchedule = 100.0;
00160     double base = log(maxWt) / log((double)numSwap_);
00161     double* divs = new double[numSwap_];
00162     divs[0] = 1.0;
00163 
00164     for (int i = 1; i < numSwap_; i++)
00165     {
00166       divs[i] = divs[i - 1] / base;
00167     }
00168 
00169     for (int i = 0; i < numST_; i++)
00170     {
00171       invTemps_[i] = new double[numSwap_];
00172       chainIds_[i] = new int[numSwap_];
00173       tempIds_[i]  = new int[numSwap_];
00174       for (int j = 0; j < numSwap_; j++)
00175       {         
00176         chainIds_[i][j] = j;
00177         tempIds_[i][j] = j;
00178           // log vs even
00179         if (maxWt > maxWtForEvenSchedule)
00180         {
00181           invTemps_[i][j] = divs[j];
00182         }
00183         else
00184         {
00185           invTemps_[i][j] = 1.0-((double)j)/((double) numSwap_);
00186         }
00187       }
00188     }
00189     delete [] divs;
00190       
00191       // Initialize gndClauses' number of satisfied literals
00192     //int start = 0;
00193     initNumTrueLits(numChains_);
00194   }
00195 
00199   void infer()
00200   {
00201     initNumTrue();
00202     Timer timer;
00203       // Burn-in only if burnMaxSteps positive
00204     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00205     double secondsElapsed = 0;
00206     double startTimeSec = timer.time();
00207     double currentTimeSec;
00208     int samplesPerOutput = 100;
00209 
00210       // If keeping track of true clause groundings, then init to zero
00211     if (trackClauseTrueCnts_)
00212       for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00213         (*clauseTrueCnts_)[clauseno] = 0;
00214 
00215       // Holds the ground preds which have currently been affected
00216     GroundPredicateHashArray affectedGndPreds;
00217     Array<int> affectedGndPredIndices;
00218 
00219     int numAtoms = state_->getNumAtoms();
00220     for (int i = 0; i < numAtoms; i++)
00221     {
00222       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00223       affectedGndPredIndices.append(i);
00224     }
00225     for (int c = 0; c < numChains_; c++)
00226       updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00227     affectedGndPreds.clear();
00228     affectedGndPredIndices.clear();
00229 
00230     cout << "Running Simulated Tempering sampling..." << endl;
00231       // Sampling loop
00232     int sample = 0;
00233     int numSamplesPerPred = 0;
00234     bool done = false;
00235     while (!done)
00236     {
00237       ++sample;
00238 
00239       if (sample % samplesPerOutput == 0)
00240       { 
00241         currentTimeSec = timer.time();
00242         secondsElapsed = currentTimeSec-startTimeSec;
00243         cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00244         Timer::printTime(cout, secondsElapsed); cout << endl;
00245       }
00246 
00247         // Attempt to swap temperature
00248       if ((sample % selInterval_) % subInterval_ == 0)
00249       {
00250         int attemptTempId = (sample % selInterval_) / subInterval_;
00251         if (attemptTempId < numSwap_ - 1)
00252         {
00253           double wl, wh, itl, ith;
00254           for (int i = 0; i < numST_; i++)
00255           {
00256             int lChainId = chainIds_[i][attemptTempId];
00257             int hChainId = chainIds_[i][attemptTempId + 1];
00258               // compute w_low, w_high: e = -w
00259               // swap acceptance ratio=e^(0.1*(w_h-w_l))
00260             wl = getWeightSum(i*numSwap_ + lChainId);
00261             wh = getWeightSum(i*numSwap_ + hChainId);
00262             itl = invTemps_[i][attemptTempId];
00263             ith = invTemps_[i][attemptTempId + 1];
00264 
00265             if (wl <= wh || random() <= RAND_MAX*exp((itl - ith)*(wh - wl)))
00266             {
00267               chainIds_[i][attemptTempId] = hChainId;
00268               chainIds_[i][attemptTempId+1] = lChainId;
00269               tempIds_[i][hChainId] = attemptTempId;
00270               tempIds_[i][lChainId] = attemptTempId + 1;
00271             }
00272           }
00273         }
00274       }
00275 
00276         // Generate new truth value based on temperature
00277       for (int c = 0; c < numChains_; c++) 
00278       {
00279           // For each block: select one to set to true
00280         for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00281         {
00282             // If evidence atom exists, then all others stay false
00283           if (state_->getDomain()->getBlockEvidence(i)) continue;
00284  
00285           double invTemp =
00286             invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00287             // chosen is index in the block, block[chosen] is index in gndPreds_
00288           int chosen = gibbsSampleFromBlock(c, i, invTemp);
00289 
00290           const Predicate* pred =
00291             state_->getDomain()->getPredInBlock(chosen, i);
00292           GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00293           int idx = state_->getIndexOfGroundPredicate(gndPred);
00294 
00295           delete gndPred;
00296           delete pred;
00297       
00298             // If gnd pred in state:
00299           if (idx >= 0)
00300           {
00301             bool truthValue = truthValues_[idx][c];
00302               // If chosen pred was false, then need to set previous true
00303               // one to false and update wts
00304             if (!truthValue)
00305             {
00306               int blockSize = state_->getDomain()->getBlockSize(i);
00307               for (int j = 0; j < blockSize; j++)
00308               {
00309                 const Predicate* otherPred = 
00310                   state_->getDomain()->getPredInBlock(j, i);
00311                 GroundPredicate* otherGndPred =
00312                   new GroundPredicate((Predicate*)otherPred);
00313                 int otherIdx = state_->getIndexOfGroundPredicate(gndPred);
00314 
00315                 delete otherGndPred;
00316                 delete otherPred;
00317       
00318                   // If gnd pred in state:
00319                 if (otherIdx >= 0)
00320                 {
00321                   bool otherTruthValue = truthValues_[otherIdx][c];
00322                   if (otherTruthValue)
00323                   {
00324                     truthValues_[otherIdx][c] = false;
00325               
00326                     affectedGndPreds.clear();
00327                     affectedGndPredIndices.clear();
00328                     gndPredFlippedUpdates(otherIdx, c, affectedGndPreds,
00329                                           affectedGndPredIndices);
00330                     updateWtsForGndPreds(affectedGndPreds,
00331                                          affectedGndPredIndices, c);
00332                   }
00333                 }
00334               }
00335                 // Set truth value and update wts for chosen atom
00336               truthValues_[idx][c] = true;
00337               affectedGndPreds.clear();
00338               affectedGndPredIndices.clear();
00339               gndPredFlippedUpdates(idx, c, affectedGndPreds,
00340                                     affectedGndPredIndices);
00341               updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00342             }
00343 
00344               // If in actual sampling phase, track the num of times
00345               // the ground predicate is set to true
00346             if (!burningIn && tempIds_[c/numSwap_][c%numSwap_] == 0)
00347               numTrue_[idx]++;
00348           }
00349         }
00350 
00351           // Now go through all preds not in blocks
00352         for (int i = 0; i < state_->getNumAtoms(); i++) 
00353         {
00354             // Predicates in blocks have been handled above
00355           if (state_->getBlockIndex(i) >= 0) continue;
00356             // Calculate prob
00357           double invTemp =
00358             invTemps_[c/numSwap_][tempIds_[c/numSwap_][c%numSwap_]];
00359           double p = getProbabilityOfPred(i, c, invTemp);
00360 
00361             // Flip updates
00362           bool newAssignment = genTruthValueForProb(p);
00363           //if (newAssignment != pred->getTruthValue(c))
00364           if (newAssignment != truthValues_[i][c])
00365           {
00366             //pred->setTruthValue(c, newAssignment);
00367             truthValues_[i][c] = newAssignment;
00368             affectedGndPreds.clear();
00369             affectedGndPredIndices.clear();
00370             gndPredFlippedUpdates(i, c, affectedGndPreds,
00371                                   affectedGndPredIndices);
00372             updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00373           }
00374 
00375             // if in actual sim. tempering phase, track the num of times
00376             // the ground predicate is set to true
00377           if (!burningIn && newAssignment &&
00378               tempIds_[c/numSwap_][c%numSwap_] == 0)
00379             //pred->incrementNumTrue();
00380             numTrue_[i]++;
00381         }
00382       }
00383       if (!burningIn) numSamplesPerPred += numST_;
00384 
00385         // If keeping track of true clause groundings
00386       if (!burningIn && trackClauseTrueCnts_)
00387         state_->getNumClauseGndings(clauseTrueCnts_, true);
00388 
00389       if (burningIn) 
00390       {
00391         if (   (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00392             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00393         {
00394           cout << "Done burning. " << sample << " samples per chain " << endl;
00395           burningIn = false;
00396           sample = 0;
00397         }
00398       }
00399       else 
00400       {
00401         if (   (maxSteps_ >= 0 && sample >= maxSteps_)
00402             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00403         {
00404           cout << "Done simulated tempering sampling. " << sample
00405                << " samples per chain" << endl;
00406           done = true;
00407         }
00408       }
00409       cout.flush();
00410     } // while (!done)
00411     
00412     cout<< "Time taken for Simulated Tempering sampling = "; 
00413     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00414 
00415       // update gndPreds probability that it is true
00416     for (int i = 0; i < state_->getNumAtoms(); i++)
00417     {
00418       //GroundPredicate* gndPred = state_->getGndPred(i);
00419       //gndPred->setProbTrue(gndPred->getNumTrue() / numSamplesPerPred);
00420       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00421     }
00422     
00423       // If keeping track of true clause groundings
00424     if (trackClauseTrueCnts_)
00425     {
00426         // Set the true counts to the average over all samples
00427       for (int i = 0; i < clauseTrueCnts_->size(); i++)
00428         (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00429     }
00430   }
00431   
00432  private:
00433  
00441   long double getWeightSum(const int& chainIdx)
00442   {
00443     long double w = 0;
00444     for (int i = 0; i < state_->getNumClauses(); i++)
00445     {
00446       long double wt = state_->getClauseCost(i);
00447       if ((wt > 0 && numTrueLits_[i][chainIdx] > 0) ||
00448           (wt < 0 && numTrueLits_[i][chainIdx] == 0))
00449         w += abs(wt);
00450     }
00451     return w;
00452   }
00453  
00454  private:
00455  
00456     // User-set parameters:
00457     // Selection interval between swap attempts
00458   int subInterval_;
00459     // Number of simulated tempering runs
00460   int numST_;
00461     // Number of swapping chains
00462   int numSwap_;
00463 
00464     // MaxWalksat is used for initialization
00465   MaxWalkSat* mws_;  
00466 
00467     // 9 possible swaps out of 10 chains
00468   int selInterval_;
00469     // invTemp for chain chainIds_[i]
00470   double** invTemps_;
00471     // curr chainId for ith temperature
00472   int** chainIds_;
00473     // curr tempId for ith chain
00474   int** tempIds_; 
00475 };
00476 
00477 #endif /*SIMULATEDTEMPERING_H_*/

Generated on Sun Jun 7 11:55:13 2009 for Alchemy by  doxygen 1.5.1