hmcsat.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-08] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef HMCSAT_H_                                          
00067 #define HMCSAT_H_
00068 
00069 #include "mcmc.h"
00070 #include "mcsatparams.h"
00071 #include "hmaxwalksat.h"
00072 #include "hvariablestate.h"
00073 
// Compile-time debug switches for HMCSAT. Both are plain on/off flags, so
// they are declared as bool rather than int.
//  - hmsdebug:   gates the trace output in HMCSAT::init(),
//                HMCSAT::performMCSatStep() and HMCSAT::updateClauses().
//  - wjhmsdebug: gates the per-sample low-state dump and the diagnostics on
//                unsatisfied constraints in HMCSAT::performMCSatStep().
const bool hmsdebug = false;
const bool wjhmsdebug = false;
00076 
00083 class HMCSAT : public MCMC
00084 {
00085  public:
00089   HMCSAT(HVariableState* state, long int seed, const bool& trackClauseTrueCnts,
00090          MCSatParams* mcsatParams) 
00091        : MCMC(state, seed, trackClauseTrueCnts, mcsatParams)
00092   {
00093       // We don't need to track clause true counts in up and ss
00094     mws_ = new HMaxWalkSat(hstate_, seed, false, mcsatParams->mwsParams);
00095     bMaxWalkSat_ = false;
00096     contSamples_.growToSize(hstate_->contAtomNum_);
00097     print_vars_per_sample_ = false;
00098   }
00099 
00103   ~HMCSAT()
00104   {
00105     delete mws_;
00106   }
00107 
00108   void initNumTrueTotal()
00109   {
00110     int numDisPreds = hstate_->getNumAtoms();
00111     numTrue_.growToSize(numDisPreds, 0);
00112   }
00113 
00117   void init()
00118   {
00119     assert(numChains_ == 1);
00120     initNumTrueTotal();
00121     hstate_->eliminateSoftClauses();
00122       // Set num. of solutions temporarily to 1
00123     int numSolutions = mws_->getNumSolutions();
00124     mws_->setNumSolutions(1);
00125       // Initialize with MWS
00126     if (!bMaxWalkSat_)
00127     {
00128       hstate_->makeUnitCosts();
00129     }
00130 
00131       // get the initialization by satisfying only hard clauses
00132       // Since the hybrid constraints' threshold are intialized to very large negative values, they are all satisfied and wound't be involved in inference here.
00133     mws_->setHeuristic(TABU);
00134     mws_->init();
00135     mws_->infer();
00136 
00137     if (hmsdebug) 
00138     {
00139       cout << "Low state:" << endl;
00140       hstate_->printLowState(cout);
00141     }
00142 
00143     hstate_->saveLowStateToGndPreds();
00144     hstate_->saveCurrentAsLastAssignment();
00145     hstate_->UpdateHybridConstraintTh();
00146       // Set heuristic to SampleSat (Initialization could have been different)
00147     mws_->setHeuristic(SS);
00148     mws_->setNumSolutions(numSolutions);
00149     mws_->setTargetCost(0.0);
00150     hstate_->resetDeadClauses();
00151     sample_ = 0;
00152   }
00153 
00154         void printContSamples()
00155         {
00156                 for(int i = 0; i < contSamples_.size(); i ++)
00157                 {
00158                         for (int j = 0; j < contSamples_[i].size(); j++)
00159                         {
00160                                 contSampleLog_ << contSamples_[i][j] << " ";
00161                         }
00162                         contSampleLog_ << endl;
00163                 }
00164         }
00165 
00169         void infer()
00170         {
00171                 Timer timer;
00172                 // Burn-in only if burnMaxSteps positive
00173                 bool burningIn = (burnMaxSteps_ > 0) ? true : false;
00174                 double secondsElapsed = 0;
00175                 upSecondsElapsed_ = 0;
00176                 ssSecondsElapsed_ = 0;
00177                 double startTimeSec = timer.time();
00178                 double currentTimeSec;
00179                 int samplesPerOutput = 100;
00180 
00181                 // If keeping track of true clause groundings, then init to zero
00182                 if (trackClauseTrueCnts_)
00183                 {
00184                         for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00185                                 (*clauseTrueCnts_)[clauseno] = 0;
00186                         for (int clauseno = 0; clauseno < clauseTrueCntsCont_->size(); clauseno++)
00187                                 (*clauseTrueCntsCont_)[clauseno] = 0;
00188                 }
00189 
00190                 // Holds the ground preds which have currently been affected
00191                 GroundPredicateHashArray affectedGndPreds;
00192                 Array<int> affectedGndPredIndices;
00193                 
00194                 // Update the weights for Gibbs step
00195                 int numAtoms = hstate_->getNumAtoms();
00196 
00197                 for (int i = 0; i < numAtoms; i++)
00198                 {
00199                         affectedGndPreds.append(hstate_->getGndPred(i), numAtoms);
00200                         affectedGndPredIndices.append(i);
00201                 }
00202 
00203                 updateWtsForGndPredsH(affectedGndPreds, affectedGndPredIndices, 0);
00204                 affectedGndPreds.clear();
00205                 affectedGndPredIndices.clear();
00206                 cout << "Running MC-SAT sampling..." << endl;
00207                 // Sampling loop
00208                 int numSamplesPerPred = 0;
00209                 bool done = false;
00210                 while (!done)
00211                 {
00212                         ++sample_;
00213                         if (sample_ % samplesPerOutput == 0)
00214                         { 
00215                                 currentTimeSec = timer.time();
00216                                 secondsElapsed = currentTimeSec - startTimeSec;
00217                                 cout << "sample_ (per pred) " << sample_ << ", time elapsed = ";
00218                                 Timer::printTime(cout, secondsElapsed); cout << endl;
00219                         }
00220 
00221           performMCSatStep(burningIn);
00222 
00223                         if (!burningIn) numSamplesPerPred++;
00224 
00225                         if (burningIn) 
00226                         {
00227                                 if ((burnMaxSteps_ >= 0 && sample_ >= burnMaxSteps_) || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_))
00228                                 {
00229                                         cout << "Done burning. " << sample_ << " samples." << endl;
00230                                         burningIn = false;
00231                                         sample_ = 0;
00232                                 }
00233                         }
00234                         else 
00235                         {
00236                                 if (   (maxSteps_ >= 0 && sample_ >= maxSteps_)
00237                                         || (maxSeconds_ > 0 && secondsElapsed >= maxSeconds_)) 
00238                                 {
00239                                         cout << "Done MC-SAT sampling. " << sample_ << " samples."             
00240                                                 << endl;
00241                                         done = true;
00242                                 }
00243                         }
00244                         cout.flush();
00245                 } // while (!done)
00246 
00247                 cout<< "Time taken for MC-SAT sampling = "; 
00248                 Timer::printTime(cout, timer.time() - startTimeSec); cout << endl;
00249 
00250                 cout<< "Time taken for unit propagation = "; 
00251                 Timer::printTime(cout, upSecondsElapsed_); cout << endl;
00252 
00253                 cout<< "Time taken for SampleSat = "; 
00254                 Timer::printTime(cout, ssSecondsElapsed_); cout << endl;
00255 
00256                 // Update gndPreds probability that it is true
00257                 for (int i = 0; i < hstate_->getNumAtoms(); i++)
00258                 {
00259                         setProbTrue(i, numTrue_[i] / numSamplesPerPred);
00260                 }
00261 
00262                 // If keeping track of true clause groundings
00263                 if (trackClauseTrueCnts_)
00264                 {
00265                         // Set the true counts to the average over all samples
00266                         for (int i = 0; i < clauseTrueCnts_->size(); i++)
00267                                 (*clauseTrueCnts_)[i] = (*clauseTrueCnts_)[i] / numSamplesPerPred;
00268                         for (int i = 0; i < clauseTrueCntsCont_->size(); i++)
00269                                 (*clauseTrueCntsCont_)[i] = (*clauseTrueCntsCont_)[i] / numSamplesPerPred;
00270                 }
00271         }
00272 
00273 
00274         void SetContSampleFile(const char* contsamplelog)
00275         {
00276                 contSampleLog_.open(contsamplelog);
00277                 if (!contSampleLog_.is_open())
00278                 {
00279                         cout << "cont sample log file is not specified." << endl;
00280                 }
00281         }
00282 
00283 
00284 
00285         void SetPrintVarsPerSample(bool pps) {
00286             print_vars_per_sample_ = pps; 
00287         }
00288 
00289 
00290         bool bMaxWalkSat_;  // Variable indicating whether this inference is for hybrid MaxWalkSAT or hybrid MCSAT.
00291         
00292 private:
00298         void performMCSatStep(const bool& burningIn)
00299         {
00300                 if (hmsdebug)
00301                 {
00302                         cout << "Num of clauses " << hstate_->getNumClauses() << endl;
00303                         cout << "Num of dead clauses " << hstate_->getNumDeadClauses() << endl;
00304                 }
00305                 Timer timer;
00306                 double startTime;
00307                 if (hmsdebug) cout << "Entering MC-SAT step" << endl;
00308                 // Clause selection
00309                 hstate_->setUseThreshold(true);
00310                 int start = 0;          
00311                 
00312                 //at this point, 
00313                 //the dis and conti atom states needs to be set according to last round solution
00314                 hstate_->killClauses(start); //update the set of constraints of discrete variables
00315                 
00316                 // SampleSat on the clauses
00317                 startTime = timer.time();
00318                 mws_->init();
00319                 mws_->infer();
00320                 ssSecondsElapsed_ += (timer.time() - startTime);
00321                 //here we save the lowest state as the current state
00322                 hstate_->saveLowStateToGndPreds();
00323                 if (print_vars_per_sample_) {
00324                     cout << "HMCS Iter#" << sample_ << ":\t";
00325                     for (int i = 0; i < hstate_->getNumAtoms(); ++i) {
00326                         cout << hstate_->atom_[i + 1] << "\t";
00327                     }
00328 
00329                     for (int i = 0; i < hstate_->getNumContAtoms(); ++i) {
00330                         cout << hstate_->contAtoms_[i+1] << "\t";
00331                     }
00332                     cout << endl;
00333                 }
00334                 if (wjhmsdebug && hstate_->costOfTotalFalseConstraints_ < mws_->getTargetCost() + SMALLVALUE)
00335                 {
00336                         hstate_->printLowStateAll(cout);
00337                 }
00338                 if (wjhmsdebug && hstate_->costOfTotalFalseConstraints_ > mws_->getTargetCost() + SMALLVALUE)
00339                 {
00340                         cout << "not all satisfied, at sample " << sample_ << ". " << endl;
00341                         cout << "Error at sample: " << sample_ << ", " << hstate_->costOfTotalFalseConstraints_ << endl;
00342                         cout << "False cont constraints: " << hstate_->costHybridFalseConstraint_ << ", false discrete: " << hstate_->getCostOfFalseClauses() << endl;
00343                         //print out false clauses & constraints
00344                         hstate_->printFalseClauses(cout);
00345                 }
00346                 // Reset parameters needed for HMCSAT step              
00347                 int numAtoms = hstate_->getNumAtoms();
00348                 for (int i = 0; i < numAtoms; i++)
00349                 {
00350                         GroundPredicate* gndPred = hstate_->getGndPred(i);
00351                         bool newAssignment = hstate_->getValueOfLowAtom(i + 1);
00352                         // No need to update weight but still need to update truth/NumSat
00353                         if (newAssignment != gndPred->getTruthValue()) {
00354                                 gndPred->setTruthValue(newAssignment);
00355                                 updateClauses(i);
00356                         }
00357                 }
00358                 // Write the current sample of continuous variables to file.
00359                 for(int i = 0; i < hstate_->contAtomNum_; i++)
00360                 {
00361                         contSamples_[i].append(hstate_->contAtoms_[i+1]);
00362                         contSampleLog_ << hstate_->contAtoms_[i+1] << " ";
00363                 }
00364                 contSampleLog_ << endl;
00365 
00366                 for( int i = 0; i < hstate_->getNumAtoms(); i++)
00367                 {
00368                         bool newAssignment = hstate_->getValueOfAtom(i+1);
00369                         if (!burningIn && newAssignment) numTrue_[i]++;
00370                 }       
00371                 hstate_->resetFixedAtoms();
00372                 hstate_->resetDeadClauses(); // update true lit num for each dis clause
00373                 hstate_->saveCurrentAsLastAssignment(); // backup
00374                 // Continuous thresholds are updated after each sample round.
00375                 hstate_->UpdateHybridConstraintTh();
00376                 hstate_->setUseThreshold(false);
00377                 // If keeping track of true clause groundings
00378                 if (!burningIn && trackClauseTrueCnts_)
00379                 {
00380                         hstate_->getNumClauseGndings(clauseTrueCnts_, true);
00381                         hstate_->getContClauseGndings(clauseTrueCntsCont_);                     
00382                 }               
00383                 if (hmsdebug) cout << "Leaving MC-SAT step" << endl;
00384         }
00385 
00393         void updateClauses(const int& gndPredIdx)
00394         {
00395                 if (hmsdebug) cout << "Entering updateClauses" << endl;
00396                 GroundPredicate* gndPred = hstate_->getGndPred(gndPredIdx);
00397                 Array<int>& negGndClauses =
00398                         hstate_->getNegOccurenceArray(gndPredIdx + 1);
00399                 Array<int>& posGndClauses =
00400                         hstate_->getPosOccurenceArray(gndPredIdx + 1);
00401                 int gndClauseIdx;
00402                 bool sense;
00403 
00404                 for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00405                 {
00406                         if (i < negGndClauses.size())
00407                         {
00408                                 gndClauseIdx = negGndClauses[i];
00409                                 sense = false;
00410                         }
00411                         else
00412                         {
00413                                 gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00414                                 sense = true;
00415                         }
00416 
00417                         if (gndPred->getTruthValue() == sense)
00418                                 hstate_->incrementNumTrueLits(gndClauseIdx);
00419                         else
00420                                 hstate_->decrementNumTrueLits(gndClauseIdx);
00421                 }
00422                 if (hmsdebug) cout << "Leaving updateClauses" << endl;
00423         }
00424 
00425 private:
00426         ofstream contSampleLog_;
00427         int sample_;
00428         bool print_vars_per_sample_;
00429 
00430         Array<Array<double> > contSamples_;
00431 
00432         // Unit propagation is performed in MC-SAT  
00433         //HUnitPropagation* up_;
00434         // The base algorithm is SampleSat (MaxWalkSat with SS parameters)
00435         HMaxWalkSat* mws_;
00436 
00437         // Time spent on UnitPropagation
00438         double upSecondsElapsed_;
00439         // Time spent on SampleSat
00440         double ssSecondsElapsed_;
00441 };
00442 
00443 #endif /*HMCSAT_H_*/

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1