gibbssampler.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef GIBBSSAMPLER_H_
00067 #define GIBBSSAMPLER_H_
00068 
#include <cstddef>

#include "mcmc.h"
#include "gibbsparams.h"
#include "maxwalksat.h"
#include "convergencetest.h"
#include "gelmanconvergencetest.h"
00074 
/**
 * How GibbsSampler::init() chooses the starting truth values of the chains:
 * NONE (0) initializes them randomly, MAXWALKSAT (1) runs MaxWalkSat first.
 */
enum WalksatType
{
  NONE = 0,
  MAXWALKSAT = 1
};

  // Switch for extra (debug) output: set to true for more output.
const bool gibbsdebug = false;
00079 
00083 class GibbsSampler : public MCMC
00084 {
00085  public:
00086 
00093   GibbsSampler(VariableState* state, long int seed,
00094                const bool& trackClauseTrueCnts, GibbsParams* gibbsParams)
00095     : MCMC(state, seed, trackClauseTrueCnts, gibbsParams)
00096   {
00097       // User-set parameters
00098     gamma_ = gibbsParams->gamma;
00099     epsilonError_ = gibbsParams->epsilonError;
00100     fracConverged_ = gibbsParams->fracConverged;
00101     walksatType_ = gibbsParams->walksatType;
00102     samplesPerTest_ = gibbsParams->samplesPerTest;
00103     
00104       // We don't need to track clause true counts in up and ss
00105     mws_ = new MaxWalkSat(state_, seed, false, gibbsParams->mwsParams);
00106   }
00107 
00111   ~GibbsSampler()
00112   {
00113     deleteConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00114                            state_->getNumAtoms());
00115     delete mws_;
00116   }
00117   
00121   void init()
00122   {
00123       // Initialize gndPreds' truthValues & wts
00124     initTruthValuesAndWts(numChains_);
00125     initNumTrue();
00126 
00127     cout << "Initializing Gibbs sampling " ;
00128       // Initialize with MWS
00129     if (walksatType_ == 1)
00130     {
00131       cout << "with MaxWalksat" << endl;
00132       for (int c = 0; c < numChains_; c++)
00133       {
00134         cout << "for chain " << c << "..." << endl;
00135         mws_->init();
00136         mws_->infer();
00137         saveLowStateToChain(c);
00138       }
00139     }
00140       // Initialize randomly
00141     else
00142     {
00143       cout << "randomly" << endl;
00144       randomInitGndPredsTruthValues(numChains_);
00145     }
00146       
00147       // Initialize gndClauses' number of satisfied literals
00148     //int start = 0;
00149     initNumTrueLits(numChains_);
00150 
00151     int numGndPreds = state_->getNumAtoms();
00152       // Initialize convergence test
00153     initConvergenceTests(burnConvergenceTests_, gibbsConvergenceTests_,
00154                          gamma_, epsilonError_, numGndPreds, numChains_);
00155   }
00156 
00160   void infer()
00161   {
00162     Timer timer;
00163       // Burn-in only if burnMaxSteps positive
00164     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00165     double secondsElapsed = 0;
00166     double startTimeSec = timer.time();
00167     double currentTimeSec;
00168     
00169       // If keeping track of true clause groundings, then init to zero
00170     if (trackClauseTrueCnts_)
00171       for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00172         (*clauseTrueCnts_)[clauseno] = 0;
00173 
00174       // Holds the ground preds which have currently been affected
00175     GroundPredicateHashArray affectedGndPreds;
00176     Array<int> affectedGndPredIndices;
00177 
00178     int numAtoms = state_->getNumAtoms();
00179     for (int i = 0; i < numAtoms; i++)
00180     {
00181       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00182       affectedGndPredIndices.append(i);
00183     }
00184     for (int c = 0; c < numChains_; c++)
00185       updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, c);
00186     affectedGndPreds.clear();
00187     affectedGndPredIndices.clear();
00188 
00189     cout << "Running Gibbs sampling..." << endl;
00190       // Sampling loop
00191     int sample = 0;
00192     int numSamplesPerPred = 0;
00193     bool done = false;
00194     while (!done)
00195     {
00196       ++sample;
00197 
00198       if (sample % samplesPerTest_ == 0)
00199       { 
00200         currentTimeSec = timer.time();
00201         secondsElapsed = currentTimeSec-startTimeSec;
00202         cout << "Sample (per pred per chain) " << sample << ", time elapsed = ";
00203         Timer::printTime(cout, secondsElapsed); cout << endl;
00204       }
00205 
00206         // For each chain, for each node, generate the node's new truth value
00207       for (int c = 0; c < numChains_; c++) 
00208       {
00209         performGibbsStep(c, burningIn, affectedGndPreds,
00210                          affectedGndPredIndices);
00211         if (!burningIn) numSamplesPerPred++;
00212       }
00213   
00214         // Add current truth values to the convergence testers
00215       for (int i = 0; i < state_->getNumAtoms(); i++) 
00216       {
00217         const bool* vals = truthValues_[i].getItems();
00218           //WARNING: implicit cast from bool* to double*
00219         if (burningIn) burnConvergenceTests_[i]->appendNewValues(vals);
00220         else           gibbsConvergenceTests_[i]->appendNewValues(vals);
00221       }
00222 
00223       if (sample % samplesPerTest_ != 0) continue;      
00224       if (burningIn) 
00225       {
00226           // Use convergence criteria stated in "Probability and Statistics",
00227           // DeGroot and Schervish
00228         bool burnConverged 
00229           = GelmanConvergenceTest::checkConvergenceOfAll(burnConvergenceTests_,
00230                                                          state_->getNumAtoms(),
00231                                                          true);
00232         if (   (sample >= burnMinSteps_ && burnConverged)
00233             || (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00234             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00235         {
00236           cout << "Done burning. " << sample << " samples per pred per chain (" 
00237                << (burnConverged? "converged":"didn't converge") 
00238                <<" at total of " << numChains_*sample << " samples per pred)" 
00239                << endl;
00240           burningIn = false;
00241           sample = 0;          
00242         }
00243       }
00244       else
00245       {  // Doing actual gibbs sampling
00246         bool gibbsConverged 
00247           = ConvergenceTest::checkConvergenceOfAtLeast(gibbsConvergenceTests_, 
00248                                                        state_->getNumAtoms(),
00249                                                        sample, fracConverged_,
00250                                                        true);
00251         if (   (sample >= minSteps_ && gibbsConverged) 
00252             || (maxSteps_ >= 0 && sample >= maxSteps_)
00253             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00254         {
00255           cout << "Done Gibbs sampling. " << sample 
00256                << " samples per pred per chain ("
00257                << (gibbsConverged? "converged":"didn't converge") 
00258                <<" at total of " << numSamplesPerPred << " samples per pred)" 
00259                << endl;
00260           done = true;
00261         }
00262       }
00263       cout.flush();
00264     } // while (!done)
00265     
00266     cout<< "Time taken for Gibbs sampling = "; 
00267     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00268 
00269       // update gndPreds probability that it is true
00270     for (int i = 0; i < state_->getNumAtoms(); i++)
00271     {
00272       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00273     }
00274 
00275       // If keeping track of true clause groundings
00276     if (trackClauseTrueCnts_)
00277     {
00278         // Set the true counts to the average over all samples
00279       for (int i = 0; i < clauseTrueCnts_->size(); i++)
00280         (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00281     }
00282   }
00283   
00284  private:
00285  
00289   void initConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00290                             ConvergenceTest**& gibbsConvergenceTests, 
00291                             const double& gamma, const double& epsilonFrac, 
00292                             const int& numGndPreds, const int& numChains)
00293   {
00294     burnConvergenceTests = new GelmanConvergenceTest*[numGndPreds];
00295     gibbsConvergenceTests = new ConvergenceTest*[numGndPreds];
00296     for (int i = 0; i < numGndPreds; i++) 
00297     {
00298       burnConvergenceTests[i]  = new GelmanConvergenceTest(numChains);
00299       gibbsConvergenceTests[i] = new ConvergenceTest(numChains, gamma,
00300                                                      epsilonFrac);
00301     }
00302   }
00303 
00307   void deleteConvergenceTests(GelmanConvergenceTest**& burnConvergenceTests,
00308                               ConvergenceTest**& gibbsConvergenceTests, 
00309                               const int& numGndPreds)
00310   {
00311     for (int i = 0; i < numGndPreds; i++) 
00312     {
00313       delete burnConvergenceTests[i];
00314       delete gibbsConvergenceTests[i];
00315     }
00316     delete [] burnConvergenceTests;
00317     delete [] gibbsConvergenceTests;
00318   }  
00319   
00320  private:
00321     // Gamma used by convergence test
00322   double gamma_;
00323     // Epsilon used by convergence test
00324   double epsilonError_;
00325     // Fraction of samples needed to converge
00326   double fracConverged_;
00327     // 0 = Initialize randomly, 1 = initialize with MaxWalksat
00328   int walksatType_;
00329     // Number of samples between checking for convergence
00330   int samplesPerTest_;
00331     // Convergence test for burning in
00332   GelmanConvergenceTest** burnConvergenceTests_;
00333     // Convergence test for sampling
00334   ConvergenceTest** gibbsConvergenceTests_;
00335     
00336     // MaxWalksat is used for initialization if walksatType_ = 1
00337   MaxWalkSat* mws_;  
00338 };
00339 
00340 #endif /*GIBBSSAMPLER_H_*/

Generated on Tue Jan 16 05:30:02 2007 for Alchemy by  doxygen 1.5.1