mcsat.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef MCSAT_H_
00067 #define MCSAT_H_
00068 
00069 #include "mcmc.h"
00070 #include "mcsatparams.h"
00071 #include "unitpropagation.h"
00072 #include "maxwalksat.h"
00073 
00074 const int msdebug = false;
00075 
00082 class MCSAT : public MCMC
00083 {
00084  public:
00085 
00089   MCSAT(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090         MCSatParams* mcsatParams)
00091     : MCMC(state, seed, trackClauseTrueCnts, mcsatParams)
00092   {
00093     numStepsEveryMCSat_ = mcsatParams->numStepsEveryMCSat;
00094       // We don't need to track clause true counts in up and ss
00095     up_ = new UnitPropagation(state_, seed, false);
00096     mws_ = new MaxWalkSat(state_, seed, false, mcsatParams->mwsParams);
00097   }
00098 
00102   ~MCSAT()
00103   {
00104     delete up_;
00105     delete mws_;
00106   }
00107   
00111   void init()
00112   {
00113     assert(numChains_ == 1);
00114     initNumTrue();
00115 
00116     cout << "Initializing MC-SAT with MaxWalksat on hard clauses..." << endl;
00117     state_->eliminateSoftClauses();
00118       // Set num. of solutions temporarily to 1
00119     int numSolutions = mws_->getNumSolutions();
00120     mws_->setNumSolutions(1);
00121 
00122       // Initialize with MWS
00123     mws_->init();
00124     mws_->infer();
00125 
00126     if (msdebug) 
00127     {
00128       cout << "Low state:" << endl;
00129       state_->printLowState(cout);
00130     }
00131     state_->saveLowStateToGndPreds();
00132 
00133       // Set heuristic to SampleSat (Initialization could have been different)
00134     mws_->setHeuristic(SS);
00135     mws_->setNumSolutions(numSolutions);
00136     mws_->setTargetCost(0.0);
00137     state_->resetDeadClauses();
00138     state_->makeUnitCosts();
00139   }
00140 
00144   void infer()
00145   {
00146     Timer timer;
00147       // Burn-in only if burnMaxSteps positive
00148     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00149     double secondsElapsed = 0;
00150     upSecondsElapsed_ = 0;
00151     ssSecondsElapsed_ = 0;
00152     double startTimeSec = timer.time();
00153     double currentTimeSec;
00154     int samplesPerOutput = 100;
00155 
00156       // If keeping track of true clause groundings, then init to zero
00157     if (trackClauseTrueCnts_)
00158       for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00159         (*clauseTrueCnts_)[clauseno] = 0;
00160 
00161       // Holds the ground preds which have currently been affected
00162     GroundPredicateHashArray affectedGndPreds;
00163     Array<int> affectedGndPredIndices;
00164       // Update the weights for Gibbs step
00165     int numAtoms = state_->getNumAtoms();
00166     for (int i = 0; i < numAtoms; i++)
00167     {
00168       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00169       affectedGndPredIndices.append(i);
00170     }
00171     updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 0);
00172     affectedGndPreds.clear();
00173     affectedGndPredIndices.clear();
00174 
00175     cout << "Running MC-SAT sampling..." << endl;
00176       // Sampling loop
00177     int sample = 0;
00178     int numSamplesPerPred = 0;
00179     bool done = false;
00180     while (!done)
00181     {
00182       ++sample;
00183       bool mcSatStep = (sample % numStepsEveryMCSat_ == 0);
00184     
00185       if (sample % samplesPerOutput == 0)
00186       { 
00187         currentTimeSec = timer.time();
00188         secondsElapsed = currentTimeSec - startTimeSec;
00189         cout << "Sample (per pred) " << sample << ", time elapsed = ";
00190         Timer::printTime(cout, secondsElapsed); cout << endl;
00191       }
00192 
00193         // For each node, generate the node's new truth value
00194       if (mcSatStep) performMCSatStep(burningIn);
00195         // Defined in MCMC. Chain is set to 0, but single chain is considered
00196         // in performGibbsStep
00197       else performGibbsStep(0, burningIn, affectedGndPreds,
00198                             affectedGndPredIndices);
00199         
00200       if (!burningIn) numSamplesPerPred++;
00201 
00202       if (burningIn) 
00203       {
00204         if (   (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00205             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00206         {
00207           cout << "Done burning. " << sample << " samples." << endl;
00208           burningIn = false;
00209           sample = 0;
00210         }
00211       }
00212       else 
00213       {
00214         if (   (maxSteps_ >= 0 && sample >= maxSteps_)
00215             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00216         {
00217           cout << "Done MC-SAT sampling. " << sample << " samples."             
00218                << endl;
00219           done = true;
00220         }
00221       }
00222       cout.flush();
00223     } // while (!done)
00224 
00225     cout<< "Time taken for MC-SAT sampling = "; 
00226     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00227 
00228     cout<< "Time taken for unit propagation = "; 
00229     Timer::printTime(cout, upSecondsElapsed_); cout << endl;
00230 
00231     cout<< "Time taken for SampleSat = "; 
00232     Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00233 
00234       // Update gndPreds probability that it is true
00235     for (int i = 0; i < state_->getNumAtoms(); i++)
00236     {
00237       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00238     }
00239 
00240       // If keeping track of true clause groundings
00241     if (trackClauseTrueCnts_)
00242     {
00243         // Set the true counts to the average over all samples
00244       for (int i = 0; i < clauseTrueCnts_->size(); i++)
00245         (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00246     }
00247   }
00248 
00249  private:
00250  
00256   void performMCSatStep(const bool& burningIn)
00257   {
00258     Timer timer;
00259     double startTime;
00260     if (msdebug) cout << "Entering MC-SAT step" << endl;
00261       // Clause selection
00262     state_->setUseThreshold(true);
00263     int start = 0;
00264     state_->killClauses(start);
00265 
00266     if (msdebug)
00267     {
00268       cout << "Num of clauses " << state_->getNumClauses() << endl;
00269       cout << "Num of dead clauses " << state_->getNumDeadClauses() << endl;
00270     }
00271     
00272       // Unit propagation on the clauses
00273     startTime = timer.time();
00274     up_->init();
00275     up_->infer();
00276     upSecondsElapsed_ += (timer.time() - startTime);
00277       // SampleSat on the clauses
00278     startTime = timer.time();
00279     mws_->init();
00280     mws_->infer();
00281     ssSecondsElapsed_ += (timer.time() - startTime);
00282 
00283     if (msdebug) 
00284     {
00285       cout << "Low state:" << endl;
00286       state_->printLowState(cout);
00287     }
00288     state_->saveLowStateToGndPreds();
00289 
00290       // Reset parameters needed for MCSat step
00291     state_->resetFixedAtoms();
00292     state_->resetDeadClauses();
00293     state_->setUseThreshold(false);
00294     int numAtoms = state_->getNumAtoms();
00295     for (int i = 0; i < numAtoms; i++)
00296     {
00297       GroundPredicate* gndPred = state_->getGndPred(i);
00298       bool newAssignment = state_->getValueOfLowAtom(i + 1);
00299       
00300         // No need to update weight but still need to update truth/NumSat
00301       if (newAssignment != gndPred->getTruthValue())
00302       {
00303         gndPred->setTruthValue(newAssignment);
00304         updateClauses(i);
00305       }
00306 
00307         // If in actual sampling phase, track the num of times
00308         // the ground predicate is set to true
00309       if (!burningIn && newAssignment) numTrue_[i]++;
00310     }
00311     
00312       // If keeping track of true clause groundings
00313     if (!burningIn && trackClauseTrueCnts_)
00314       state_->getNumClauseGndings(clauseTrueCnts_, true);
00315     
00316     if (msdebug) cout << "Leaving MC-SAT step" << endl;
00317   }
00318   
00326   void updateClauses(const int& gndPredIdx)
00327   {
00328     if (msdebug) cout << "Entering updateClauses" << endl;
00329     GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00330     Array<int>& negGndClauses =
00331       state_->getNegOccurenceArray(gndPredIdx + 1);
00332     Array<int>& posGndClauses =
00333       state_->getPosOccurenceArray(gndPredIdx + 1);
00334     int gndClauseIdx;
00335     bool sense;
00336 
00337     for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00338     {
00339       if (i < negGndClauses.size())
00340       {
00341         gndClauseIdx = negGndClauses[i];
00342         sense = false;
00343       }
00344       else
00345       {
00346         gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00347         sense = true;
00348       }
00349 
00350       if (gndPred->getTruthValue() == sense)
00351         state_->incrementNumTrueLits(gndClauseIdx);
00352       else
00353         state_->decrementNumTrueLits(gndClauseIdx);
00354     }
00355     if (msdebug) cout << "Leaving updateClauses" << endl;
00356   }
00357   
00358  private:
00359 
00360     // Number of total steps (MC-SAT & Gibbs) for each MC-SAT step
00361   int numStepsEveryMCSat_;
00362 
00363     // Unit propagation is performed in MC-SAT  
00364   UnitPropagation* up_;
00365     // The base algorithm is SampleSat (MaxWalkSat with SS parameters)
00366   MaxWalkSat* mws_;
00367 
00368     // Time spent on UnitPropagation
00369   double upSecondsElapsed_;
00370     // Time spent on SampleSat
00371   double ssSecondsElapsed_;
00372 };
00373 
00374 #endif /*MCSAT_H_*/

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1