lazyinfo.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef _Lazy_Info_H_
00067 #define _Lazy_Info_H_
00068 
00069 #include <fstream>
00070 #include <sys/time.h>
00071 
00072 #include "clause.h"
00073 #include "clausefactory.h"
00074 #include "lazyutil.h"
00075 #include "intclause.h"
00076 #include "mln.h"
00077 #include "timer.h"
00078 
00079 class LazyInfo
00080 {
00081 public:
00082 
00083  //constructor   
00084 LazyInfo(MLN *mln, Domain *domain)
00085 {
00086       // Randomizer
00087   int seed;
00088   struct timeval tv;
00089   struct timezone tzp;
00090   gettimeofday(&tv,&tzp);
00091   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00092   srandom(seed);
00093     
00094   this->mln_ = mln;
00095   this->domain_ = domain;
00096   copyMLN();
00097   setHardClauseWeight();
00098   sampleSat_ = false;
00099   prevDB_ = NULL;
00100   numDBAtoms_ = domain_->getNumNonEvidenceAtoms();
00101   initBlocks();
00102   inBlock_ = false;
00103 }
00104   
00105  //destructor
00106 ~LazyInfo()
00107 {
00108 
00109   for(int i = 0; i < predHashArray_.size(); i++)
00110         delete predHashArray_[i];
00111 
00112     // mln_ is a copy made in copyMLN()
00113   delete mln_;
00114 
00115   if (prevDB_) delete prevDB_;
00116 
00117   for(int i = 0; i < deadClauses_.size(); i++)
00118   {
00119         deadClauses_[i]->deleteIntPredicates();
00120         delete deadClauses_[i];
00121   }
00122 }
00123  
00124 inline int getVarCount() { return predHashArray_.size(); }
00125 
00126 /*
00127  * Makes a copy of the mln.
00128  */
00129 inline void copyMLN()
00130 {
00131 
00132   MLN* posmln = new MLN();
00133   int clauseCnt = mln_->getNumClauses();
00134 
00135         // For each clause, add it to copy
00136   for (int i = 0; i < clauseCnt; i++)
00137   {
00138         Clause* clause = (Clause *) mln_->getClause(i);
00139     if (clause->getWt() == 0) continue;
00140         int numPreds = clause->getNumPredicates();
00141 
00142       // Add clause to mln copy
00143     int idx;
00144     ostringstream oss;
00145     Clause* newClause = new Clause(*clause);
00146     newClause->printWithoutWtWithStrVar(oss, domain_);
00147     bool app = posmln->appendClause(oss.str(), false, newClause,
00148                                                                   clause->getWt(), clause->isHardClause(), idx);
00149     if (app)
00150     {
00151       posmln->setFormulaNumPreds(oss.str(), numPreds);
00152       posmln->setFormulaIsHard(oss.str(), clause->isHardClause());
00153       posmln->setFormulaPriorMean(oss.str(), clause->getWt());
00154     }
00155   }
00156   mln_ = posmln;
00157   //mln_->printMLN(cout, domain_);
00158 }
00159 
00160 /*
00161  * Computes the hard clause weight from the first order clauses in the mln.
00162  */
00163 inline void setHardClauseWeight()
00164 {
00165         // This is the weight used if no soft clauses are present
00166   LazyInfo::HARD_WT = 10.0;
00167   
00168   int clauseCnt = mln_->getNumClauses();
00169   double sumSoftWts = 0.0;
00170   double minWt = DBL_MAX;
00171   double maxWt = DBL_MIN;
00172     //this is a reasonable number to prevent the external executable 
00173     //MaxWalksat from have having overflow errors
00174   int maxAllowedWt = 4000;
00175   
00176         // Sum up the soft weights of all grounded clauses
00177   for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00178   {
00179     Clause* fclause = (Clause *) mln_->getClause(clauseno);
00180         
00181         if (fclause->isHardClause()) continue;
00182         
00183       // Weight could be negative
00184         double wt = abs(fclause->getWt());
00185 
00186     double numGndings = fclause->getNumGroundings(domain_);
00187 
00188         if (wt < minWt) minWt = wt;
00189     if (wt > maxWt) maxWt = wt;
00190     sumSoftWts += wt*numGndings;
00191     assert(minWt >= 0);
00192     assert(maxWt >= 0);
00193   } //for(clauseno < clauseCnt)
00194   assert(sumSoftWts >= 0);
00195 
00196         // If at least one soft clause
00197   if (sumSoftWts > 0)
00198   {
00199           //find out how much to scale weights by
00200         LazyInfo::WSCALE = 1.0;
00201     if (maxWt > maxAllowedWt) LazyInfo::WSCALE = maxAllowedWt/maxWt;
00202     else
00203     {
00204       if (minWt < 10)
00205       {
00206         LazyInfo::WSCALE = 10/minWt;
00207         if (LazyInfo::WSCALE*maxWt > maxAllowedWt) LazyInfo::WSCALE = maxAllowedWt/maxWt;
00208       }
00209     }
00210 
00211         LazyInfo::HARD_WT = (sumSoftWts + 10.0)*LazyInfo::WSCALE;
00212   }
00213   //cout << "Set hard weight to " << LazyInfo::HARD_WT << endl;
00214 }
00215 
00216 /*
00217  * Removes all soft clauses from the mln.
00218  */
00219 inline void removeSoftClauses()
00220 {
00221         // For each clause, remove it if not hard
00222   for (int i = 0; i < mln_->getNumClauses(); i++)
00223   {
00224         Clause* clause = (Clause *) mln_->getClause(i);
00225         if (!clause->isHardClause())
00226         {
00227           mln_->removeClause(i);
00228           i--;
00229         }
00230   }  
00231 }
00232 
00233 inline void reset()
00234 { 
00235   for(int i = 0; i < predHashArray_.size(); i++)
00236   {
00237         delete predHashArray_[i];
00238   }
00239 
00240   predHashArray_.clear();
00241   predArray_.clear();
00242 }
00243 
00244 /* check if the given atom is active or not */
00245 inline bool isActive(int atom)
00246 {
00247   return (domain_->getDB()->getActiveStatus(predArray_[atom]));
00248 }
00249 
00250 /* check if the given atom has been previously deactivated or not */
00251 inline bool isDeactivated(int atom)
00252 {
00253   return (domain_->getDB()->getDeactivatedStatus(predArray_[atom]));
00254 }
00255 
00256 /* set the given atom as active */
00257 inline void setActive(int atom)
00258 {
00259   domain_->getDB()->setActiveStatus(predArray_[atom], true);
00260 }
00261 
00262 /* set the given atom as inactive */
00263 inline void setInactive(int atom)
00264 {
00265   domain_->getDB()->setActiveStatus(predArray_[atom], false);
00266 }
00267 
00268 void updatePredArray()
00269 {
00270   int startindex = 0;
00271   if(predArray_.size() > 0)
00272     startindex = predArray_.size()-1;
00273          
00274   predArray_.growToSize(predHashArray_.size()+1);
00275   for(int i = startindex; i < predHashArray_.size(); i++)
00276     predArray_[i+1] = predHashArray_[i];
00277 }
00278 
00279   // Gets the set of all possible initial active clauses
00280 void getSupersetClauses(Array<IntClause *> &supersetClauses)
00281 {
00282     // Randomizer
00283   int seed;
00284   struct timeval tv;
00285   struct timezone tzp;
00286   gettimeofday(&tv,&tzp);
00287   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00288   srandom(seed);
00289   
00290   Clause *fclause;
00291   IntClause *intClause;
00292   int clauseCnt;
00293   IntClauseHashArray clauseHashArray;
00294 
00295   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00296   
00297   clauseCnt = mln_->getNumClauses();
00298 
00299   int clauseno = 0;
00300   while(clauseno < clauseCnt)
00301   {
00302     fclause = (Clause *) mln_->getClause(clauseno);
00303     
00304     double wt = fclause->getWt();
00305     intClauses->clear();
00306     bool ignoreActivePreds = false;
00307     fclause->getActiveClauses(NULL, domain_, intClauses,
00308                               &predHashArray_, ignoreActivePreds);
00309     updatePredArray();
00310 
00311     for (int i = 0; i < intClauses->size(); i++)
00312     {
00313       intClause = (*intClauses)[i];
00314 
00315       int pos = clauseHashArray.find(intClause);
00316       if(pos >= 0)
00317       {
00318         clauseHashArray[pos]->addWt(wt);
00319         intClause->deleteIntPredicates();
00320         delete intClause;
00321         continue;
00322       }
00323            
00324       intClause->setWt(wt);
00325       clauseHashArray.append(intClause);
00326     }
00327     clauseno++;
00328   } //while(clauseno < clauseCnt)
00329       
00330   for(int i = 0; i < clauseHashArray.size(); i++)
00331   {
00332     intClause = clauseHashArray[i];
00333     supersetClauses.append(intClause);
00334   }
00335   delete intClauses;
00336 }
00337 
00338 
00339   // Perform clause selection on the superset of initial active
00340   // clauses and convert weights
00341 void selectClauses(const Array<IntClause *> &supersetClauses,
00342                    Array<Array<int> *> &walksatClauses,
00343                    Array<int> &walksatClauseWts)
00344 {
00345   assert(sampleSat_);
00346   
00347     // Randomizer
00348   int seed;
00349   struct timeval tv;
00350   struct timezone tzp;
00351   gettimeofday(&tv,&tzp);
00352   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00353   srandom(seed);
00354 
00355   for(int i = 0; i < deadClauses_.size(); i++)
00356   {
00357     deadClauses_[i]->deleteIntPredicates();
00358     delete deadClauses_[i];
00359   }
00360   deadClauses_.clearAndCompress();
00361 
00362     // Look at each clause in the superset and do clause
00363     // selection and convert weight
00364   for (int i = 0; i < supersetClauses.size(); i++)
00365   {
00366     assert(prevDB_);
00367     IntClause* intClause = supersetClauses[i];
00368     double wt = intClause->getWt();
00369     if (wt == 0) continue;
00370 
00371       // Pos. clause not satisfied in prev. iteration: don't activate
00372       // Neg. clause satisfied in prev. iteration: don't activate
00373     bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00374     if ((wt > 0 && !sat) ||
00375         (wt < 0 && sat))
00376     {
00377       continue;
00378     }
00379 
00380       // With prob. exp(-wt) don't ever activate it
00381     double threshold =
00382       intClause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00383 
00384     if (random() > threshold)
00385     {
00386       deadClauses_.append(new IntClause(*intClause));
00387       continue;
00388     }
00389 
00390       // If we're here, then clause has been selected
00391     Array<int>* litClause = (Array<int> *)intClause->getIntPredicates();
00392     walksatClauses.append(new Array<int>(*litClause));
00393     if (wt >= 0) walksatClauseWts.append(1);
00394     else walksatClauseWts.append(-1);
00395   }
00396 }
00397 
00398 
00406 void getWalksatClauses(Predicate *inputPred,
00407                                    Array<Array<int> *> &walksatClauses,
00408                                        Array<int> &walksatClauseWts,
00409                        bool const & active)
00410 {
00411     // Randomizer
00412   int seed;
00413   struct timeval tv;
00414   struct timezone tzp;
00415   gettimeofday(&tv,&tzp);
00416   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00417   srandom(seed);
00418   
00419   Clause *fclause;
00420   IntClause *intClause;
00421   int clauseCnt;
00422   IntClauseHashArray clauseHashArray;
00423 
00424   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00425   
00426   const Array<IndexClause*>* indexClauses = NULL;
00427       
00428   if(inputPred == NULL)
00429   {
00430         clauseCnt = mln_->getNumClauses();
00431   }
00432   else
00433   {
00434         if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
00435         int predid = inputPred->getId(); 
00436         indexClauses = mln_->getClausesContainingPred(predid);
00437         clauseCnt = indexClauses->size();
00438   }
00439 
00440   int clauseno = 0;
00441   while(clauseno < clauseCnt)
00442   {
00443         if(inputPred)
00444           fclause = (Clause *) (*indexClauses)[clauseno]->clause;                       
00445         else
00446           fclause = (Clause *) mln_->getClause(clauseno);
00447         
00448         double wt = fclause->getWt();
00449     intClauses->clear();
00450         bool ignoreActivePreds = false;
00451 
00452     if (active)
00453     {
00454           fclause->getActiveClauses(inputPred, domain_, intClauses,
00455                                     &predHashArray_, ignoreActivePreds);
00456     }
00457     else
00458     {
00459       fclause->getInactiveClauses(inputPred, domain_, intClauses,
00460                                   &predHashArray_);
00461     }
00462     updatePredArray();
00463 cout << "intClauses size " << intClauses->size() << endl;
00464         for (int i = 0; i < intClauses->size(); i++)
00465     {
00466       intClause = (*intClauses)[i];
00467 
00468                 // If using samplesat, then do clause selection
00469           if (sampleSat_)
00470           {
00471                 assert(prevDB_);
00472 
00473                   // Pos. clause not satisfied in prev. iteration: don't activate
00474           // Neg. clause satisfied in prev. iteration: don't activate
00475         bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00476                 if ((wt >= 0 && !sat) ||
00477             (wt < 0 && sat))
00478                 {
00479                   intClause->deleteIntPredicates();
00480                   delete intClause;
00481                   continue;
00482                 }
00483 
00484           // In dead clauses: don't activate
00485         if (deadClauses_.contains(intClause))
00486         {
00487           intClause->deleteIntPredicates();
00488           delete intClause;
00489           continue;
00490         }
00491 
00492           // With prob. exp(-wt) don't ever activate it
00493                 double threshold =
00494                   fclause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00495 
00496                 if (random() > threshold)
00497                 {
00498                   deadClauses_.append(intClause);
00499                   continue;
00500                 }
00501           }
00502 
00503           int pos = clauseHashArray.find(intClause);
00504           if(pos >= 0)
00505           {
00506                 clauseHashArray[pos]->addWt(wt);
00507                 intClause->deleteIntPredicates();
00508                 delete intClause;
00509                 continue;
00510           }
00511            
00512           intClause->setWt(wt);
00513           clauseHashArray.append(intClause);
00514         }
00515         clauseno++;     
00516   } //while(clauseno < clauseCnt)
00517           
00518   Array<int>* litClause;
00519   for(int i = 0; i < clauseHashArray.size(); i++)
00520   {
00521         intClause = clauseHashArray[i];
00522         double weight = intClause->getWt();
00523         litClause = (Array<int> *)intClause->getIntPredicates();
00524         walksatClauses.append(litClause);
00525         if (sampleSat_)
00526     {
00527       if (weight >= 0) walksatClauseWts.append(1);
00528       else walksatClauseWts.append(-1);
00529     }
00530         else
00531         {
00532       if (weight >= 0)
00533         walksatClauseWts.append((int)(weight*LazyInfo::WSCALE + 0.5));
00534       else
00535         walksatClauseWts.append((int)(weight*LazyInfo::WSCALE - 0.5));
00536         }
00537         
00538     delete intClause;
00539   }
00540                  
00541   delete intClauses;
00542 }
00543 
00544 /* get all the active clauses and append to the
00545  * allClauses array */
00546 void getWalksatClauses(Array<Array<int> *> &allClauses, Array<int> &allClauseWeights)
00547 {
00548   getWalksatClauses(NULL, allClauses, allClauseWeights, true);
00549 }
00550 
00551 /* Get all the clauses which become active by flipping this atom */
00552 void getWalksatClausesWhenFlipped(int atom,            
00553                                                                   Array<Array<int> *> &walksatClauses,
00554                                   Array<int> &walksatClauseWts)
00555 {
00556   Predicate *pred = predArray_[atom];
00557   TruthValue oldval = domain_->getDB()->getValue(pred);
00558   TruthValue val;
00559   (oldval == TRUE)? val=FALSE : val = TRUE;
00560   
00561     // Used to store other pred flipped in block
00562   Predicate* otherPred = NULL;
00563   inBlock_ = false;
00564   
00565   int blockIdx = domain_->getBlock(pred);
00566   if (blockIdx >= 0)
00567   {
00568       // Dealing with pred in a block
00569     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00570     if (block->size() > 1)
00571     {
00572       inBlock_ = true;
00573       int chosen = -1;
00574         // 1->0: Pick one at random to flip
00575       if (oldval == TRUE)
00576       {
00577         bool ok = false;
00578         while(!ok)
00579         {
00580           chosen = random() % block->size();
00581           if (!pred->same((*block)[chosen]))
00582             ok = true;
00583         }
00584         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00585       }
00586         // 0->1: Flip the pred with value 1
00587       else
00588       {
00589         for (int i = 0; i < block->size(); i++)
00590         {
00591           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00592           {
00593             chosen = i;
00594             assert(!pred->same((*block)[chosen]));
00595             break;
00596           }
00597         }
00598       }
00599       assert(chosen >= 0);
00600       otherPred = (*block)[chosen];
00601         // Set new value of other atom in block
00602       domain_->getDB()->setValue(otherPred, oldval);  
00603     }
00604   }
00605     // Set new value of actual atom
00606   domain_->getDB()->setValue(pred, val);
00607 
00608   getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00609   if (inBlock_)
00610     getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
00611     
00612     // Set old value of atom and other in block
00613   domain_->getDB()->setValue(pred,oldval);
00614   if (inBlock_)
00615   {
00616     domain_->getDB()->setValue(otherPred, val);
00617     
00618       // Set index of other pred
00619     for(int atom = 1; atom <= getVarCount(); atom++)
00620     {
00621       if (otherPred->same(predArray_[atom]))
00622       {
00623         otherAtom_ = atom;
00624         break;
00625       }
00626     }
00627   }
00628 }
00629 
00630 /* Get the cost of all the clauses which become unsatisfied by flipping this atom */
00631 int getUnSatCostPerPred(Predicate* pred,
00632                         Array<Array<int> *> &walksatClauses,
00633                         Array<int> &walksatClauseWts)
00634 {
00635     // Randomizer
00636   int seed;
00637   struct timeval tv;
00638   struct timezone tzp;
00639   gettimeofday(&tv,&tzp);
00640   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00641   srandom(seed);
00642   
00643   int unSatCost = 0;
00644   Clause *clause;
00645   IntClause *intClause;
00646   IntClauseHashArray clauseHashArray;
00647 
00648   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00649 
00650   const Array<IndexClause*>* indclauses;
00651   int predid = pred->getId(); 
00652   indclauses = mln_->getClausesContainingPred(predid);
00653   
00654   for(int j = 0; j < indclauses->size(); j++)
00655   {      
00656     clause = (Clause *) (*indclauses)[j]->clause;
00657 /*    
00658     double weight = clause->getWt()*LazyInfo::WSCALE;
00659     //int wt = abs((int)(clause->getWt()*LazyInfo::WSCALE+0.5));
00660     int wt;
00661 
00662       // Samplesat: all weights are 1 or -1
00663     if (sampleSat_)
00664     {
00665       if (weight >= 0)
00666         wt = 1;
00667       else
00668         wt = -1;
00669     }
00670     else
00671     {
00672       if (weight >= 0)
00673         wt = (int)(weight + 0.5);
00674       else
00675         wt = (int)(weight - 0.5);
00676     }
00677 */
00678     double wt = clause->getWt();
00679     intClauses->clear();
00680 
00681     if(abs(wt) < WEIGHT_EPSILON)
00682       continue;
00683       
00684     bool ignoreActivePreds = true;
00685     clause->getActiveClauses(pred, domain_, intClauses,
00686                              &predHashArray_, ignoreActivePreds);
00687     updatePredArray();
00688     
00689     for (int i = 0; i < intClauses->size(); i++)
00690     {
00691       intClause = (*intClauses)[i];
00692 
00693         // If using samplesat, then do clause selection
00694       if (sampleSat_)
00695       {
00696         assert(prevDB_);
00697 
00698 //cout << i << " wt " << wt << endl;
00699 
00700           // Pos. clause not satisfied in prev. iteration: don't activate
00701           // Neg. clause satisfied in prev. iteration: don't activate
00702         bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00703         if ((wt >= 0 && !sat) ||
00704             (wt < 0 && sat))
00705         {
00706           intClause->deleteIntPredicates();
00707           delete intClause;
00708           continue;
00709         }
00710 
00711           // In dead clauses: don't activate
00712         if (deadClauses_.contains(intClause))
00713         {
00714           intClause->deleteIntPredicates();
00715           delete intClause;
00716           continue;
00717         }
00718 
00719           // With prob. exp(-wt) don't ever activate it
00720         double threshold =
00721           clause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00722 
00723         if (random() > threshold)
00724         {
00725           deadClauses_.append(intClause);
00726           continue;
00727         }
00728       }
00729 
00730       int pos = clauseHashArray.find(intClause);
00731       if(pos >= 0)
00732       {
00733         clauseHashArray[pos]->addWt(wt);
00734         intClause->deleteIntPredicates();
00735         delete intClause;
00736         continue;
00737       }
00738 
00739 //cout << "Hard wt. " << LazyInfo::HARD_WT << endl;           
00740 //cout << i << " wt2 " << wt << endl;
00741       if (wt == LazyInfo::HARD_WT) intClause->setWtToHardWt();
00742       else intClause->setWt(wt);
00743       clauseHashArray.append(intClause);
00744 //cout << i << " intClause->setWt " << intClause->getWt() << endl;
00745     }
00746     
00747   }
00748   
00749   Array<int>* litClause;
00750   for(int i = 0; i < clauseHashArray.size(); i++)
00751   {
00752     intClause = clauseHashArray[i];
00753 //cout << i << " intClause->getWt() " << intClause->getWt() << endl;
00754     int weight = (int)(intClause->getWt());
00755     litClause = (Array<int> *)intClause->getIntPredicates();
00756     walksatClauses.append(litClause);
00757     if (sampleSat_)
00758     {
00759       if (weight >= 0) walksatClauseWts.append(1);
00760       else walksatClauseWts.append(-1);
00761       unSatCost += 1;
00762     }
00763     else
00764     {
00765       if (weight >= 0)
00766       {
00767         walksatClauseWts.append((int)(weight*LazyInfo::WSCALE + 0.5));
00768         unSatCost += (int)(weight*LazyInfo::WSCALE + 0.5);
00769       }
00770       else
00771       {
00772         walksatClauseWts.append((int)(weight*LazyInfo::WSCALE - 0.5));
00773         unSatCost += (int)(weight*LazyInfo::WSCALE - 0.5);
00774       }
00775     }
00776 
00777     //walksatClauseWts.append(weight);
00778 //cout << i << " Weight " << weight << endl;
00779     //unSatCost += abs(weight);
00780     delete intClause;
00781   }
00782   delete intClauses;
00783   
00784   return unSatCost;
00785 }
00786 
00787 /* 
00788  * Get the cost of all the clauses which become unsatisfied by flipping this atom,
00789  * taking blocks into account
00790  */
00791 int getUnSatCostWhenFlipped(int atom,
00792                             Array<Array<int> *> &walksatClauses,
00793                             Array<int> &walksatClauseWts)
00794 {        
00795   int unSatCost = 0;
00796   TruthValue val;
00797   Predicate *pred = predArray_[atom];
00798   TruthValue oldval = domain_->getDB()->getValue(pred);
00799   (oldval == TRUE) ? val = FALSE: val = TRUE;
00800 
00801     // Used to store other pred in block and its previous value
00802   Predicate* otherPred = NULL;
00803   bool inBlock = false;
00804   
00805   int blockIdx = domain_->getBlock(pred);
00806   if (blockIdx >= 0)
00807   {
00808       // Dealing with pred in a block
00809     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00810     if (block->size() > 1)
00811     {
00812       inBlock = true;
00813       int chosen = -1;
00814         // 1->0: Pick one at random to flip
00815       if (oldval == TRUE)
00816       {
00817         bool ok = false;
00818         while(!ok)
00819         {
00820           chosen = random() % block->size();
00821           if (!pred->same((*block)[chosen]))
00822             ok = true;
00823         }
00824         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00825       }
00826         // 0->1: Flip the pred with value 1
00827       else
00828       {
00829         for (int i = 0; i < block->size(); i++)
00830         {
00831           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00832           {
00833             chosen = i;
00834             assert(!pred->same((*block)[chosen]));
00835             break;
00836           }
00837         }
00838       }
00839       assert(chosen >= 0);
00840       otherPred = (*block)[chosen];
00841         // Set new value of other atom in block
00842       domain_->getDB()->setValue(otherPred, oldval);  
00843     }
00844   }
00845     // Set new value of actual atom
00846   domain_->getDB()->setValue(pred, val);
00847 
00848     // Get unsat cost of pred and other pred
00849   unSatCost += getUnSatCostPerPred(pred, walksatClauses, walksatClauseWts);
00850   if (inBlock)
00851     unSatCost += getUnSatCostPerPred(otherPred, walksatClauses, walksatClauseWts);
00852     
00853     // Set old value of atom and other in block
00854   domain_->getDB()->setValue(pred,oldval);
00855   if (inBlock)
00856     domain_->getDB()->setValue(otherPred, val);  
00857   return unSatCost;
00858 }
00859 
00860 /* set Database Values to the given array of values */
00861 void setVarVals(int newVals[])
00862 {
00863   for(int atom = 1; atom <= getVarCount(); atom++)
00864   {
00865     Predicate *ipred = predArray_[atom];
00866         if(newVals[atom] == 1)
00867           domain_->getDB()->setValue(ipred,TRUE);
00868         else
00869           domain_->getDB()->setValue(ipred,FALSE);
00870   }
00871 }
00872 
00873 /* flip the predicate corresponding to given atom */
00874 void flipVar(int atom)
00875 {
00876   Predicate *ipred = predArray_[atom];
00877   TruthValue val;
00878   val = domain_->getDB()->getValue(ipred);
00879   if(val == TRUE)
00880     domain_->getDB()->setValue(ipred,FALSE);
00881   else
00882         domain_->getDB()->setValue(ipred,TRUE);
00883 }
00884 
00885 /* get the value (true/false) of the particular variable */
00886 bool getVarVal(int atom)
00887 {
00888   Predicate *ipred = predArray_[atom];
00889   bool val;
00890   (domain_->getDB()->getValue(ipred) == TRUE)? val = true : val = false;
00891   return val;
00892 }
00893 
00894 /* sets the value (true/false) of the particular variable */
00895 void setVarVal(int atom, bool val)
00896 {
00897   Predicate *ipred = predArray_[atom];
00898   if (val)
00899     domain_->getDB()->setValue(ipred, TRUE);
00900   else
00901     domain_->getDB()->setValue(ipred, FALSE);
00902 }
00903 
00904 Predicate* getVar(int atom)
00905 {
00906   return predArray_[atom];      
00907 }
00908 
00909 void setAllActive()
00910 {
00911   LazyUtil::setAllActive(domain_);
00912 }
00913 
00914 void setAllInactive()
00915 {
00916   for(int atom = 1; atom <= getVarCount(); atom++)
00917   {
00918         Predicate* ipred = predArray_[atom];
00919         domain_->getDB()->setActiveStatus(ipred, false);
00920   }
00921 }
00922 
00923 void setAllFalse()
00924 {
00925   for(int atom = 1; atom <= getVarCount(); atom++)
00926   {
00927         Predicate* ipred = predArray_[atom];
00928         domain_->getDB()->setValue(ipred, FALSE);
00929   }
00930 }
00931 
00932 /* Deactivates and removes atoms at indices indicated */
00933 void removeVars(Array<int> indices)
00934 {
00935   for (int i = 0; i < indices.size(); i++)
00936   {
00937         Predicate* pred = predArray_[indices[i]];
00938         domain_->getDB()->setActiveStatus(pred, false);
00939         domain_->getDB()->setDeactivatedStatus(pred, true);
00940   }
00941   reset();
00942   //cout << "reset " << endl;
00943 }
00944 
00945 void setSampleSat(bool s)
00946 {
00947   sampleSat_ = s;
00948 }
00949 
00950 bool getSampleSat()
00951 {
00952   return sampleSat_;
00953 }
00954 
00955   //Make a deep copy of the assignments from last iteration
00956 void setPrevDB()
00957 {
00958   if (prevDB_) { delete prevDB_; prevDB_ = NULL; }
00959   prevDB_ = new Database((*domain_->getDB()));
00960 }
00961 
00962 int getNumDBAtoms()
00963 {
00964   return numDBAtoms_;
00965 }
00966 
00967 /*
00968  * Attempts to activate a random atom not already in memory and gets the
00969  * clauses activated by doing this. Returns true if successful, otherwise false.
00970  */
00971 bool activateRandomAtom(Array<Array<int> *> &walksatClauses,
00972                         Array<int> &walksatClauseWts,
00973                         int& toflip)
00974 {
00975     // Randomizer
00976   int seed;
00977   struct timeval tv;
00978   struct timezone tzp;
00979   gettimeofday(&tv,&tzp);
00980   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00981   srandom(seed);
00982 
00983   Predicate* pred;
00984   pred = domain_->getNonEvidenceAtom(random() % numDBAtoms_);
00985 
00986   int pos = -1;
00987   if ((pos = predHashArray_.find(pred)) == -1)
00988   {
00989     domain_->getDB()->setActiveStatus(pred, true);
00990     predHashArray_.append(pred);
00991     updatePredArray();
00992     toflip = predHashArray_.size();
00993     getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00994     return true;
00995   }
00996   else
00997   {
00998     delete pred;
00999     toflip = pos + 1;
01000     return false;
01001   }
01002 }
01003 
01004 /*
 * Handles flipping `atom` when its predicate may belong to a predicate block
 * (a group in which, judging by the two branches below, exactly one member
 * is expected to be TRUE).
 *
 * - Resets inBlock_ to false and returns immediately if the predicate is in
 *   no block (domain_->getBlock() returns a negative index).
 * - If the block has more than one member, sets inBlock_ to true and picks
 *   the partner predicate that must flip together with `atom`:
 *     TRUE -> FALSE: a random other member (asserted to be FALSE);
 *     FALSE -> TRUE: the member that is currently TRUE.
 * - Temporarily writes the post-flip values of both predicates into the DB
 *   and calls getWalksatClauses(partner, walksatClauses, walksatClauseWts,
 *   true) to collect the partner's clauses under that assignment. The `true`
 *   argument presumably requests activation — confirm. Then it restores both
 *   original truth values.
 * - Stores the partner's index in predArray_ (searched over
 *   1..getVarCount()) in otherAtom_.
 *
 * NOTE(review): if the partner is not found in predArray_, otherAtom_ keeps
 *   its previous (possibly stale) value.
 * NOTE(review): in the FALSE -> TRUE branch, if no member is TRUE, `chosen`
 *   stays -1 and only assert() guards the (*block)[chosen] access; with
 *   NDEBUG that index is out of range.
 */
01005 void chooseOtherToFlip(int atom,
01006                        Array<Array<int> *> &walksatClauses,
01007                        Array<int> &walksatClauseWts)
01008 {
01009     // atom is assumed to be active (checked only by assert)
01010   assert(isActive(atom));
01011   Predicate *pred = predArray_[atom];
01012   TruthValue oldval = domain_->getDB()->getValue(pred);
01013   TruthValue val;
01014   (oldval == TRUE)? val = FALSE : val = TRUE;  // val = pred's flipped value
01015   
01016     // Partner predicate flipped together with pred (stays NULL if none)
01017   Predicate* otherPred = NULL;
01018   inBlock_ = false;
01019   int blockIdx = domain_->getBlock(pred);  // negative => pred is in no block
01020 
01021   if (blockIdx < 0)
01022   {
01023     return;
01024   }
01025   else
01026   {
01027       // Dealing with pred in a block
01028     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
01029     if (block->size() > 1)  // a single-member block has no partner to flip
01030     {
01031       inBlock_ = true;
01032       int chosen = -1;
01033         // TRUE -> FALSE: pick a random other member, which will become TRUE
01034       if (oldval == TRUE)
01035       {
01036         bool ok = false;
01037         while(!ok)  // ends only once a member not same() as pred is drawn
01038         {
01039           chosen = random() % block->size();
01040           if (!pred->same((*block)[chosen]))
01041             ok = true;
01042         }
01043         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
01044       }
01045         // FALSE -> TRUE: the currently TRUE member must become FALSE
01046       else
01047       {
01048         for (int i = 0; i < block->size(); i++)
01049         {
01050           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
01051           {
01052             chosen = i;
01053             assert(!pred->same((*block)[chosen]));
01054             break;
01055           }
01056         }
01057       }
01058       assert(chosen >= 0);  // NOTE(review): not re-checked when NDEBUG is set
01059       otherPred = (*block)[chosen];
01060         // Set new value of other atom in block (it takes pred's old value)
01061       domain_->getDB()->setValue(otherPred, oldval);  
01062     }
01063   }
01064     // Temporarily give pred its new value so clauses are gathered under the post-flip assignment
01065   domain_->getDB()->setValue(pred, val);
01066   if (inBlock_)
01067     getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
01068     
01069     // Restore original values: pred back to oldval; the partner's original value was val in both branches
01070   domain_->getDB()->setValue(pred,oldval);
01071   if (inBlock_)
01072   {
01073     domain_->getDB()->setValue(otherPred, val);
01074     
01075       // Record the partner's predArray_ index (this loop variable shadows the `atom` parameter)
01076     for(int atom = 1; atom <= getVarCount(); atom++)
01077     {
01078       if (otherPred->same(predArray_[atom]))
01079       {
01080         otherAtom_ = atom;
01081         break;
01082       }
01083     }
01084   }
01085 }
01086 
01087   // Sets the truth values of all atoms in the
01088   // block blockIdx except for the one with index atomIdx
01089 void setOthersInBlockToFalse(const int& atomIdx,
01090                              const int& blockIdx)
01091 {
01092   const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
01093   for (int i = 0; i < block->size(); i++)
01094   {
01095     if (i != atomIdx)
01096       domain_->getDB()->setValue((*block)[i], FALSE);
01097   }
01098 }
01099 
01100   // Makes an assignment to all preds in blocks which does not
01101   // violate the blocks. One randomly chosen pred in each box is set to true.
01102 void initBlocks()
01103 {
01104   const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01105   const Array<bool>* blockEvidence = domain_->getBlockEvidenceArray();
01106   
01107   for (int i = 0; i < blocks->size(); i++)
01108   {
01109     Array<Predicate*>* block = (*blocks)[i];
01110     int chosen = -1;
01111       // If evidence atom exists, then all others are false
01112     if ((*blockEvidence)[i])
01113     {
01114       chosen = domain_->getEvidenceIdxInBlock(i);
01115       setOthersInBlockToFalse(chosen, i);
01116       continue;
01117     }
01118     else
01119     {
01120         // Set one at random to true
01121       chosen = random() % block->size();
01122     }
01123     assert(chosen >= 0);
01124     domain_->getDB()->setValue((*block)[chosen], TRUE);
01125     setOthersInBlockToFalse(chosen, i);
01126   }
01127 }
01128 
01129 int getBlock(int atom)
01130 {
01131   const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01132   Predicate* pred = getVar(atom);
01133   
01134   for (int i = 0; i < blocks->size(); i++)
01135   {
01136     Array<Predicate*>* block = (*blocks)[i];
01137     for (int j = 0; j < block->size(); j++)
01138     {
01139       if ((*block)[j]->same(pred))
01140         return i;
01141     }
01142   }
01143   return -1;
01144 }
01145 
01146 bool inBlock(int atom)
01147 {
01148   return (getBlock(atom) >= 0);
01149 }
01150 
01151 bool inBlockWithEvidence(int atom)
01152 {
01153   int blockIdx = getBlock(atom);
01154   if (blockIdx >= 0 && domain_->getBlockEvidence(blockIdx))
01155     return true;
01156   return false;
01157 }
01158 
01159 bool getInBlock()
01160 {
01161   return inBlock_;
01162 }
01163 
01164 void setInBlock(const bool val)
01165 {
01166   inBlock_ = val;
01167 }
01168 
01169 int getOtherAtom()
01170 {
01171   assert(inBlock_ && otherAtom_ >= 0);
01172   return otherAtom_;
01173 }
01174 
01175 void setOtherAtom(const int val)
01176 {
01177   otherAtom_ = val;
01178 }
01179 
01180   // Set evidence status of given atom to given value
01181 void setEvidence(const int atom, const bool val)
01182 {
01183   domain_->getDB()->setEvidenceStatus(predArray_[atom], val);
01184 }
01185 
01186   // Get evidence status of given atom
01187 bool getEvidence(const int atom)
01188 {
01189   return domain_->getDB()->getEvidenceStatus(predArray_[atom]);
01190 }
01191 
01192 void incrementNumDBAtoms()
01193 {
01194   numDBAtoms_++;
01195 }
01196 
01197 void decrementNumDBAtoms()
01198 {
01199   numDBAtoms_--;
01200 }
01201 
01202 void printIntClauses(Array<IntClause *> clauses)
01203 {
01204   for (int i = 0; i < clauses.size(); i++)
01205   {
01206     clauses[i]->printWithWtAndStrVar(cout, domain_, &predHashArray_);
01207     cout << endl;
01208   }
01209 }
01210 
01211 
01212 /* 
01213  * Fixes atoms using unit propagation on non-active clauses.
01214  * Fixed atoms are stored in fixedAtoms_ and set as evidence in the DB.
01215  */
01216 void propagateFixedAtoms(Array<Array<int> *> &clauses,
01217                          Array<int> &clauseWeights,
01218                          bool* fixedAtoms,
01219                          int maxFixedAtoms)
01220 {
01221   Array<Array<int> *> tmpClauses;
01222   Array<int> tmpClauseWeights;
01223     
01224     //TEMP: count fixed atoms
01225   int count = 0;
01226   for (int i = 0; i < maxFixedAtoms; i++)
01227   {
01228     if (fixedAtoms[i]) count++;
01229   }
01230   cout << "Fixed atoms before propagating: " << count << endl;
01231   
01232     // Look at all non-active clauses containing a fixed atom
01233   for (int i = 0; i < maxFixedAtoms; i++)
01234   {
01235     if (fixedAtoms[i])
01236     {
01237         // Get all non-active clauses containing it,
01238         // filtering out dead clauses
01239       getWalksatClauses(predArray_[i], tmpClauses, tmpClauseWeights, false);
01240 cout << "Atom " << i << endl;
01241 cout << "Clauses " << tmpClauses.size() << endl;
01242 
01243         // clauses now contains potential clauses of all length
01244         // want to look at only unit clauses
01245       for (int j = 0; j < tmpClauses.size(); j++)
01246       {
01247           // Unit clause
01248         if (tmpClauses[j]->size() == 1)
01249         {
01250           clauses.append(tmpClauses[j]);
01251           clauseWeights.append(clauseWeights[j]);
01252           fixedAtoms[(*tmpClauses[j])[0]] = true;
01253         }
01254         else // Not unit clause
01255         {
01256           delete tmpClauses[j];
01257         }
01258       }
01259     }
01260     tmpClauses.clear();
01261     tmpClauseWeights.clear();
01262   }
01263   
01264       //TEMP: count fixed atoms
01265   count = 0;
01266   for (int i = 0; i < maxFixedAtoms; i++)
01267   {
01268     if (fixedAtoms[i]) count++;
01269   }
01270   cout << "Fixed atoms after propagating: " << count << endl;
01271   
01272   exit(0);
01273 }
01274 
01275 
01276 public:
01277         // Weight assigned to hard clauses
01278   static double HARD_WT;
01279         // Scale factor applied to weights
01280   static double WSCALE;
01281 
01282 private:
01283 
01284   MLN *mln_;
01285   Domain *domain_;  // domain whose DB (getDB()) holds truth/active/evidence state
01286   
01287   Array<Predicate *> predArray_;  // ground predicates indexed by atom number, used from 1 to getVarCount()
01288   PredicateHashArray predHashArray_;  // predicates in memory; atom number = position + 1
01289 
01290   Array<Array<int> *> predToClauseIds_;  // NOTE(review): presumably per-atom clause ids — not used in this view, confirm
01291 
01292         // Set to true when performing SampleSat
01293   bool sampleSat_;
01294         // Dead clauses not to be considered for activation
01295   IntClauseHashArray deadClauses_;
01296         // Copy of the DB holding the assignment from the previous iteration (replaced by setPrevDB(); may be NULL)
01297   Database* prevDB_;
01298     // Total number of non-evidence ground atoms in the DB (range for random atom choice)
01299   int numDBAtoms_;
01300     // True while the current flip involves a partner atom in the same block
01301   bool inBlock_;
01302     // predArray_ index of the partner atom flipped in the block (set by chooseOtherToFlip)
01303   int otherAtom_;
01304 
01305 };
01306 
01307 #endif
01308 

Generated on Tue Jan 16 05:30:02 2007 for Alchemy by  doxygen 1.5.1