gibbssampler.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef GIBBSSAMPLER_H_
00067 #define GIBBSSAMPLER_H_
00068 
00069 #include "mcmc.h"
00070 #include "gibbsparams.h"
00071 #include "maxwalksat.h"
00072 #include "convergencetest.h"
00073 #include "gelmanconvergencetest.h"
00074 
// Strategy used to produce the initial truth assignment of each Gibbs chain.
enum WalksatType
{
  NONE       = 0,  // initialize truth values randomly
  MAXWALKSAT = 1   // initialize truth values from a MaxWalkSat run
};

// Debug switch: set to true for more output.
const bool gibbsdebug = false;
00079 
00083 class GibbsSampler : public MCMC
00084 {
00085  public:
00086 
00093   GibbsSampler(VariableState* state, long int seed,
00094                const bool& trackClauseTrueCnts, GibbsParams* gibbsParams)
00095     : MCMC(state, seed, trackClauseTrueCnts, gibbsParams)
00096   {
00097       // User-set parameters
00098     gamma_ = gibbsParams->gamma;
00099     epsilonError_ = gibbsParams->epsilonError;
00100     fracConverged_ = gibbsParams->fracConverged;
00101     walksatType_ = gibbsParams->walksatType;
00102     testConvergence_ = gibbsParams->testConvergence;
00103     samplesPerTest_ = gibbsParams->samplesPerTest;
00104     
00105       // We don't need to track clause true counts in MWS
00106     mws_ = new MaxWalkSat(state_, seed, false, gibbsParams->mwsParams);
00107   }
00108 
00112   ~GibbsSampler()
00113   {
00114     deleteConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00115                            state_->getNumAtoms());
00116     delete mws_;
00117   }
00118   
00122   void init()
00123   {
00124       // Initialize gndPreds' truthValues & wts
00125     initTruthValuesAndWts(numChains_);
00126 
00127     cout << "Initializing Gibbs sampling " ;
00128       // Initialize with MWS
00129     if (walksatType_ == 1)
00130     {
00131       cout << "with MaxWalksat" << endl;
00132       for (int c = 0; c < numChains_; c++)
00133       {
00134         cout << "for chain " << c << "..." << endl;
00135         mws_->init();
00136         mws_->infer();
00137         saveLowStateToChain(c);
00138       }
00139       if (numChains_ == 1) state_->saveLowStateToGndPreds();
00140     }
00141       // Initialize randomly
00142     else
00143     {
00144       cout << "randomly" << endl;
00145       randomInitGndPredsTruthValues(numChains_);
00146     }
00147       
00148       // Initialize gndClauses' number of satisfied literals
00149     //int start = 0;
00150     initNumTrueLits(numChains_);
00151 
00152     int numGndPreds = state_->getNumAtoms();
00153       // Initialize convergence test
00154     initConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00155                          gamma_, epsilonError_, numGndPreds, numChains_);
00156   }
00157 
00161   void infer()
00162   {
00163     initNumTrue();
00164     Timer timer;
00165       // Burn-in only if burnMaxSteps positive
00166     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00167     double secondsElapsed = 0;
00168     double startTimeSec = timer.time();
00169     double currentTimeSec;
00170     
00171       // If keeping track of true clause groundings, then init to zero
00172     if (trackClauseTrueCnts_)
00173       for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00174         (*clauseTrueCnts_)[clauseno] = 0;
00175 
00176       // Holds the ground preds which have currently been affected
00177     GroundPredicateHashArray affectedGndPreds;
00178     Array<int> affectedGndPredIndices;
00179 
00180     int numAtoms = state_->getNumAtoms();
00181     for (int i = 0; i < numAtoms; i++)
00182     {
00183       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00184       affectedGndPredIndices.append(i);
00185     }
00186     for (int c = 0; c < numChains_; c++)
00187       updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00188     affectedGndPreds.clear();
00189     affectedGndPredIndices.clear();
00190 
00191     cout << "Running Gibbs sampling..." << endl;
00192       // Sampling loop
00193     int sample = 0;
00194     int numSamplesPerPred = 0;
00195     bool done = false;
00196     while (!done)
00197     {
00198       ++sample;
00199 
00200       if (sample % samplesPerTest_ == 0)
00201       { 
00202         currentTimeSec = timer.time();
00203         secondsElapsed = currentTimeSec-startTimeSec;
00204         cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00205         Timer::printTime(cout, secondsElapsed); cout << endl;
00206       }
00207 
00208         // For each chain, for each node, generate the node's new truth value
00209       for (int c = 0; c < numChains_; c++) 
00210       {
00211         performGibbsStep(c, burningIn, affectedGndPreds,
00212                          affectedGndPredIndices);
00213         if (!burningIn) numSamplesPerPred++;
00214       }
00215   
00216         // Add current truth values to the convergence testers
00217       for (int i = 0; i < state_->getNumAtoms(); i++) 
00218       {
00219         const bool* vals = truthValues_[i].getItems();
00220           //WARNING: implicit cast from bool* to double*
00221         if (burningIn) burnConvergenceTests_[i]->appendNewValues(vals);
00222         else           gibbsConvergenceTests_[i]->appendNewValues(vals);
00223       }
00224 
00225       if (sample % samplesPerTest_ != 0) continue;      
00226       if (burningIn) 
00227       {
00228           // Use convergence criteria stated in "Probability and Statistics",
00229           // DeGroot and Schervish
00230         bool burnConverged = false;
00231         
00232         if (testConvergence_)
00233           burnConverged = 
00234             GelmanConvergenceTest::checkConvergenceOfAll(burnConvergenceTests_,
00235                                                          state_->getNumAtoms(),
00236                                                          true);
00237         if (   (sample >= burnMinSteps_ && burnConverged)
00238             || (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00239             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00240         {
00241           cout << "Done burning. " << sample << " samples per pred per chain";
00242           if (testConvergence_)
00243           {
00244             cout << " (" << (burnConverged? "converged":"didn't converge") 
00245                  <<" at total of " << numChains_*sample << " samples per pred)";
00246           }
00247           cout << endl;
00248           burningIn = false;
00249           sample = 0;          
00250         }
00251       }
00252       else
00253       {  // Doing actual gibbs sampling
00254         bool gibbsConverged = false;
00255         
00256         if (testConvergence_)
00257           gibbsConverged =
00258             ConvergenceTest::checkConvergenceOfAtLeast(gibbsConvergenceTests_, 
00259                                                        state_->getNumAtoms(),
00260                                                        sample, fracConverged_,
00261                                                        true);
00262 
00263         if (   (sample >= minSteps_ && gibbsConverged) 
00264             || (maxSteps_ >= 0 && sample >= maxSteps_)
00265             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00266         {
00267           cout << "Done Gibbs sampling. " << sample 
00268                << " samples per pred per chain";
00269           if (testConvergence_)
00270           {
00271             cout << " (" << (gibbsConverged? "converged":"didn't converge") 
00272                  <<" at total of " << numSamplesPerPred << " samples per pred)";
00273           }
00274           cout << endl;
00275           done = true;
00276         }
00277       }
00278       cout.flush();
00279     } // while (!done)
00280     
00281     cout<< "Time taken for Gibbs sampling = "; 
00282     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00283 
00284       // update gndPreds probability that it is true
00285     for (int i = 0; i < state_->getNumAtoms(); i++)
00286     {
00287       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00288     }
00289 
00290       // If keeping track of true clause groundings
00291     if (trackClauseTrueCnts_)
00292     {
00293         // Set the true counts to the average over all samples
00294       for (int i = 0; i < clauseTrueCnts_->size(); i++)
00295         (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00296     }
00297   }
00298   
00299  private:
00300  
00304   void initConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00305                             ConvergenceTest**& gibbsConvergenceTests, 
00306                             const double& gamma, const double& epsilonFrac, 
00307                             const int& numGndPreds, const int& numChains)
00308   {
00309     burnConvergenceTests = new GelmanConvergenceTest*[numGndPreds];
00310     gibbsConvergenceTests = new ConvergenceTest*[numGndPreds];
00311     for (int i = 0; i < numGndPreds; i++) 
00312     {
00313       burnConvergenceTests[i]  = new GelmanConvergenceTest(numChains);
00314       gibbsConvergenceTests[i] = new ConvergenceTest(numChains, gamma,
00315                                                      epsilonFrac);
00316     }
00317   }
00318 
00322   void deleteConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00323                               ConvergenceTest**& gibbsConvergenceTests, 
00324                               const int& numGndPreds)
00325   {
00326     for (int i = 0; i < numGndPreds; i++) 
00327     {
00328       delete burnConvergenceTests[i];
00329       delete gibbsConvergenceTests[i];
00330     }
00331     delete [] burnConvergenceTests;
00332     delete [] gibbsConvergenceTests;
00333   }  
00334   
00335  private:
00336     // Gamma used by convergence test
00337   double gamma_;
00338     // Epsilon used by convergence test
00339   double epsilonError_;
00340     // Fraction of samples needed to converge
00341   double fracConverged_;
00342     // 0 = Initialize randomly, 1 = initialize with MaxWalksat
00343   int walksatType_;
00344     // If true, test for convergence, otherwise do not test
00345   int testConvergence_;
00346     // Number of samples between checking for convergence
00347   int samplesPerTest_;
00348     // Convergence test for burning in
00349   GelmanConvergenceTest** burnConvergenceTests_;
00350     // Convergence test for sampling
00351   ConvergenceTest** gibbsConvergenceTests_;
00352     
00353     // MaxWalksat is used for initialization if walksatType_ = 1
00354   MaxWalkSat* mws_;  
00355 };
00356 
00357 #endif /*GIBBSSAMPLER_H_*/

Generated on Sun Jun 7 11:55:11 2009 for Alchemy by  doxygen 1.5.1