mcsat.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef MCSAT_H_
00068 #define MCSAT_H_
00069 
00070 #include "mcmc.h"
00071 #include "mcsatparams.h"
00072 #include "unitpropagation.h"
00073 #include "maxwalksat.h"
00074 
00075 const int msdebug = false;
00076 
00083 class MCSAT : public MCMC
00084 {
00085  public:
00086 
00090   MCSAT(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00091         MCSatParams* mcsatParams,
00092         Array<Array<Predicate* >* >* queryFormulas = NULL)
00093     : MCMC(state, seed, trackClauseTrueCnts, mcsatParams, queryFormulas)
00094   {
00095         Timer timer1;
00096 
00097       // We don't need to track clause true counts in up and ss
00098     //up_ = new UnitPropagation(state_, seed, false);
00099     mws_ = new MaxWalkSat(state_, seed, false, mcsatParams->mwsParams);
00100     mws_->setPrintInfo(false);
00101 
00102     if (msdebug)
00103     {
00104       cout << "[MCSAT] ";
00105       Timer::printTime(cout, timer1.time());
00106       cout << endl;
00107       timer1.reset();
00108     }
00109   }
00110 
00114   ~MCSAT()
00115   {
00116     //delete up_;
00117     delete mws_;
00118   }
00119   
00123   void init()
00124   {
00125     Timer timer1;
00126     assert(numChains_ == 1);
00127 
00128     cout << "Initializing MC-SAT with MaxWalksat on hard clauses..." << endl;
00129     
00130     state_->eliminateSoftClauses();
00131     state_->setInferenceMode(state_->MODE_HARD);
00132 
00133       // Set num. of solutions temporarily to 1
00134     int numSolutions = mws_->getNumSolutions();
00135     mws_->setNumSolutions(1);
00136 
00137       // Initialize with MWS
00138     mws_->init();
00139     mws_->infer();
00140 
00141     if (msdebug) 
00142     {
00143       cout << "Low state:" << endl;
00144       state_->printLowState(cout);
00145     }
00146     state_->saveLowStateToGndPreds();
00147 
00148       // Set heuristic to SampleSat (Initialization could have been different)
00149     mws_->setHeuristic(SS);
00150     mws_->setNumSolutions(numSolutions);
00151     mws_->setTargetCost(0.0);
00152     state_->resetDeadClauses();
00153     
00154         // state_->makeUnitCosts();
00155     state_->setInferenceMode(state_->MODE_SAMPLESAT);   
00156 
00157     if (msdebug)
00158     {
00159       cout << "[MCSAT.init] ";
00160       Timer::printTime(cout, timer1.time());
00161       cout << endl;
00162       timer1.reset();
00163     }
00164   }
00165 
00169   void infer()
00170   {
00171         Timer timer1;
00172 
00173     initNumTrue();
00174     Timer timer;
00175       // Burn-in only if burnMaxSteps positive
00176     bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00177     double secondsElapsed = 0;
00178     //upSecondsElapsed_ = 0;
00179     ssSecondsElapsed_ = 0;
00180     double startTimeSec = timer.time();
00181     double currentTimeSec;
00182     int samplesPerOutput = 100;
00183 
00184       // Holds the ground preds which have currently been affected
00185     GroundPredicateHashArray affectedGndPreds;
00186     Array<int> affectedGndPredIndices;
00187       // Update the weights for Gibbs step
00188     int numAtoms = state_->getNumAtoms();
00189     for (int i = 0; i < numAtoms; i++)
00190     {
00191       affectedGndPreds.append(state_->getGndPred(i), numAtoms);
00192       affectedGndPredIndices.append(i);
00193     }
00194     updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices, 0);
00195     affectedGndPreds.clear();
00196     affectedGndPredIndices.clear();
00197 
00198     if (msdebug)
00199     {   
00200       cout << "[MCSAT.infer.prep] "; 
00201       Timer::printTime(cout, timer1.time());
00202       cout << endl;
00203       timer1.reset();
00204     }
00205 
00206     cout << "Running MC-SAT sampling..." << endl;
00207       // Sampling loop
00208     int sample = 0;
00209     int numSamplesPerPred = 0;
00210     bool done = false;
00211     while (!done)
00212     {
00213       ++sample;
00214       if (sample % samplesPerOutput == 0)
00215       { 
00216         currentTimeSec = timer.time();
00217         secondsElapsed = currentTimeSec - startTimeSec;
00218         cout << "Sample (per pred) " << sample << ", time elapsed = ";
00219         Timer::printTime(cout, secondsElapsed);
00220         cout << ", num. preds = " << state_->getNumAtoms();
00221                 cout << ", num. clauses = " << state_->getNumClauses();
00222                 cout << endl;           
00223       }
00224 
00225         // For each node, generate the node's new truth value
00226       performMCSatStep(burningIn);
00227         
00228       if (!burningIn) numSamplesPerPred++;
00229 
00230       if (burningIn) 
00231       {
00232         if (   (burnMaxSteps_ >= 0 && sample >= burnMaxSteps_)
00233             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00234         {
00235           cout << "Done burning. " << sample << " samples." << endl;
00236           burningIn = false;
00237           sample = 0;
00238         }
00239       }
00240       else 
00241       {
00242         if (   (maxSteps_ >= 0 && sample >= maxSteps_)
00243             || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00244         {
00245           cout << "Done MC-SAT sampling. " << sample << " samples."             
00246                << endl;
00247           done = true;
00248         }
00249       }
00250           cout.flush();
00251     } // while (!done)
00252 
00253     cout << "Final ground predicate number: " << state_->getNumAtoms() << endl;
00254     cout << "Final ground clause number: " << state_->getNumClauses() << endl;
00255 
00256     cout<< "Time taken for MC-SAT sampling = "; 
00257     Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00258 
00259     //cout<< "Time taken for unit propagation = "; 
00260     //Timer::printTime(cout, upSecondsElapsed_); cout << endl;
00261 
00262     cout<< "Time taken for SampleSat = "; 
00263     Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00264 
00265       // Update gndPreds probability that it is true
00266     for (int i = 0; i < state_->getNumAtoms(); i++)
00267     {
00268       setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00269     }
00270     
00271       // Update query formula probs
00272     if (qfProbs_)
00273     {
00274       for (int j = 0; j < qfProbs_->size(); j++)
00275         (*qfProbs_)[j] = (*qfProbs_)[j] / numSamplesPerPred;
00276     }    
00277   }
00278 
00279  private:
00280  
00286   void performMCSatStep(const bool& burningIn)
00287   {
00288     Timer timer;
00289     double startTime;
00290     if (msdebug) cout << "Entering MC-SAT step" << endl;
00291       // Clause selection
00292     state_->setUseThreshold(true);
00293     state_->updatePrevSatisfied();
00294 
00295         int start = 0;
00296     
00297     state_->resetMakeBreakCostWatch();
00298 
00299         state_->killClauses(start);
00300 
00301     if (msdebug)
00302     {
00303       cout << "Num of clauses " << state_->getNumClauses() << endl;
00304       cout << "Num of dead clauses " << state_->getNumDeadClauses() << endl;
00305     }
00306     
00307       // Unit propagation on the clauses
00308     startTime = timer.time();    
00309         //up_->init();
00310     //up_->infer();
00311     //upSecondsElapsed_ += (timer.time() - startTime);
00312       // SampleSat on the clauses
00313     startTime = timer.time();
00314     mws_->init();
00315     mws_->infer();
00316     ssSecondsElapsed_ += (timer.time() - startTime);
00317 
00318     if (msdebug) 
00319     {
00320       cout << "Low state:" << endl;
00321       state_->printLowState(cout);
00322     }
00323     //state_->saveLowStateToGndPreds();
00324 
00325       // Reset parameters needed for MCSat step
00326     state_->resetFixedAtoms();
00327     state_->resetDeadClauses();
00328     state_->setUseThreshold(false);
00329     int numAtoms = state_->getNumAtoms();
00330       // If lazy, atoms may have been added, so have to grow numTrue_
00331     numTrue_.growToSize(numAtoms, 0);
00332 
00333     for (int i = 0; i < numAtoms; i++)
00334     {
00335       GroundPredicate* gndPred = state_->getGndPred(i);
00336       bool newAssignment = state_->getValueOfLowAtom(i + 1);
00337       
00338         // No need to update weight but still need to update truth/NumSat
00339       if (newAssignment != gndPred->getTruthValue())
00340       {
00341         gndPred->setTruthValue(newAssignment);
00342         updateClauses(i);
00343       }
00344 
00345         // If in actual sampling phase, track the num of times
00346         // the ground predicate is set to true
00347       if (!burningIn && newAssignment) numTrue_[i]++;
00348     }
00349     
00350       // If keeping track of true clause groundings
00351     if (!burningIn && trackClauseTrueCnts_)
00352       tallyCntsFromState();
00353       
00354       // If there are query formulas
00355     if (!burningIn && qfProbs_)
00356     {
00357       for (int i = 0; i < queryFormulas_->size(); i++)
00358       {
00359         Array<Predicate* >* formula = (*queryFormulas_)[i];
00360         bool satisfied = true;
00361         for (int j = 0; j < formula->size(); j++)
00362         {
00363           bool sense = (*formula)[j]->getSense();
00364           GroundPredicate* pred = new GroundPredicate((*formula)[j]);
00365           TruthValue tv = state_->getDomain()->getDB()->getValue(pred);
00366           if ((tv == TRUE && !sense) || (tv == FALSE && sense))
00367             satisfied = false;
00368           delete pred;
00369           if (!satisfied) break;
00370         }
00371         if (satisfied) (*qfProbs_)[i]++;
00372       }
00373     }
00374     
00375     if (msdebug) cout << "Leaving MC-SAT step" << endl;
00376   }
00377   
00385   void updateClauses(const int& gndPredIdx)
00386   {
00387     if (msdebug) cout << "Entering updateClauses" << endl;
00388     GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
00389     Array<int>& negGndClauses =
00390       state_->getNegOccurenceArray(gndPredIdx + 1);
00391     Array<int>& posGndClauses =
00392       state_->getPosOccurenceArray(gndPredIdx + 1);
00393     int gndClauseIdx;
00394     bool sense;
00395 
00396     for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00397     {
00398       if (i < negGndClauses.size())
00399       {
00400         gndClauseIdx = negGndClauses[i];
00401         sense = false;
00402       }
00403       else
00404       {
00405         gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00406         sense = true;
00407       }
00408 
00409       if (gndPred->getTruthValue() == sense)
00410         state_->incrementNumTrueLits(gndClauseIdx);
00411       else
00412         state_->decrementNumTrueLits(gndClauseIdx);
00413     }
00414     if (msdebug) cout << "Leaving updateClauses" << endl;
00415   }
00416   
00417  private:
00418 
00419     // Number of total steps (MC-SAT & Gibbs) for each MC-SAT step
00420   //int numStepsEveryMCSat_;
00421 
00422     // Unit propagation is performed in MC-SAT  
00423   //UnitPropagation* up_;
00424     // The base algorithm is SampleSat (MaxWalkSat with SS parameters)
00425   MaxWalkSat* mws_;
00426 
00427     // Time spent on UnitPropagation
00428   //double upSecondsElapsed_;
00429     // Time spent on SampleSat
00430   double ssSecondsElapsed_;
00431 };
00432 
00433 #endif /*MCSAT_H_*/

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1