lwinfo.h

00001 #ifndef _LW_Info_H_
00002 #define _LW_Info_H_
00003 
00004 #include<fstream>
00005 #include <sys/time.h>
00006 
00007 #include "clause.h"
00008 #include "clausefactory.h"
00009 #include "lwutil.h"
00010 #include "intclause.h"
00011 #include "mln.h"
00012 #include "timer.h"
00013 
00014 class LWInfo
00015 {
00016 public:
00017 
00018  //constructor   
00019 LWInfo(MLN *mln, Domain *domain)
00020 {
00021       // Randomizer
00022   int seed;
00023   struct timeval tv;
00024   struct timezone tzp;
00025   gettimeofday(&tv,&tzp);
00026   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00027   srandom(seed);
00028     
00029   this->mln_ = mln;
00030   this->domain_ = domain;
00031   copyMLN();
00032   setHardClauseWeight();
00033   sampleSat_ = false;
00034   prevDB_ = NULL;
00035   numDBAtoms_ = domain_->getNumNonEvidenceAtoms();
00036   initBlocks();
00037   inBlock_ = false;
00038 }
00039   
00040  //destructor
00041 ~LWInfo()
00042 {
00043 
00044   for(int i = 0; i < predHashArray_.size(); i++)
00045         delete predHashArray_[i];
00046 
00047     // mln_ is a copy made in copyMLN()
00048   delete mln_;
00049 
00050   if (prevDB_) delete prevDB_;
00051 
00052   for(int i = 0; i < deadClauses_.size(); i++)
00053   {
00054         deadClauses_[i]->deleteIntPredicates();
00055         delete deadClauses_[i];
00056   }
00057 }
00058  
00059 inline int getVarCount() { return predHashArray_.size(); }
00060 
00061 /*
00062  * Makes a copy of the mln.
00063  */
00064 inline void copyMLN()
00065 {
00066 
00067   MLN* posmln = new MLN();
00068   int clauseCnt = mln_->getNumClauses();
00069 
00070         // For each clause, add it to copy
00071   for (int i = 0; i < clauseCnt; i++)
00072   {
00073         Clause* clause = (Clause *) mln_->getClause(i);
00074     if (clause->getWt() == 0) continue;
00075         int numPreds = clause->getNumPredicates();
00076 
00077       // Add clause to mln copy
00078     int idx;
00079     ostringstream oss;
00080     Clause* newClause = new Clause(*clause);
00081     newClause->printWithoutWtWithStrVar(oss, domain_);
00082     bool app = posmln->appendClause(oss.str(), false, newClause,
00083                                                                   clause->getWt(), clause->isHardClause(), idx);
00084     if (app)
00085     {
00086       posmln->setFormulaNumPreds(oss.str(), numPreds);
00087       posmln->setFormulaIsHard(oss.str(), clause->isHardClause());
00088       posmln->setFormulaPriorMean(oss.str(), clause->getWt());
00089     }
00090   }
00091   mln_ = posmln;
00092   //mln_->printMLN(cout, domain_);
00093 }
00094 
00095 /*
00096  * Computes the hard clause weight from the first order clauses in the mln.
00097  */
00098 inline void setHardClauseWeight()
00099 {
00100         // This is the weight used if no soft clauses are present
00101   LWInfo::HARD_WT = 10.0;
00102   
00103   int clauseCnt = mln_->getNumClauses();
00104   double sumSoftWts = 0.0;
00105   double minWt = DBL_MAX;
00106   double maxWt = DBL_MIN;
00107     //this is a reasonable number to prevent the external executable 
00108     //MaxWalksat from have having overflow errors
00109   int maxAllowedWt = 4000;
00110   
00111         // Sum up the soft weights of all grounded clauses
00112   for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00113   {
00114     Clause* fclause = (Clause *) mln_->getClause(clauseno);
00115         
00116         if (fclause->isHardClause()) continue;
00117         
00118       // Weight could be negative
00119         double wt = abs(fclause->getWt());
00120 
00121     double numGndings = fclause->getNumGroundings(domain_);
00122 
00123         if (wt < minWt) minWt = wt;
00124     if (wt > maxWt) maxWt = wt;
00125     sumSoftWts += wt*numGndings;
00126     assert(minWt >= 0);
00127     assert(maxWt >= 0);
00128   } //for(clauseno < clauseCnt)
00129   assert(sumSoftWts >= 0);
00130 
00131         // If at least one soft clause
00132   if (sumSoftWts > 0)
00133   {
00134           //find out how much to scale weights by
00135         LWInfo::WSCALE = 1.0;
00136     if (maxWt > maxAllowedWt) LWInfo::WSCALE = maxAllowedWt/maxWt;
00137     else
00138     {
00139       if (minWt < 10)
00140       {
00141         LWInfo::WSCALE = 10/minWt;
00142         if (LWInfo::WSCALE*maxWt > maxAllowedWt) LWInfo::WSCALE = maxAllowedWt/maxWt;
00143       }
00144     }
00145 
00146         LWInfo::HARD_WT = (sumSoftWts + 10.0)*LWInfo::WSCALE;
00147   }
00148   //cout << "Set hard weight to " << LWInfo::HARD_WT << endl;
00149 }
00150 
00151 /*
00152  * Removes all soft clauses from the mln.
00153  */
00154 inline void removeSoftClauses()
00155 {
00156         // For each clause, remove it if not hard
00157   for (int i = 0; i < mln_->getNumClauses(); i++)
00158   {
00159         Clause* clause = (Clause *) mln_->getClause(i);
00160         if (!clause->isHardClause())
00161         {
00162           mln_->removeClause(i);
00163           i--;
00164         }
00165   }  
00166 }
00167 
00168 inline void reset()
00169 { 
00170   for(int i = 0; i < predHashArray_.size(); i++)
00171   {
00172         delete predHashArray_[i];
00173   }
00174 
00175   predHashArray_.clear();
00176   predArray_.clear();
00177 }
00178 
00179 /* check if the given atom is active or not */
00180 inline bool isActive(int atom)
00181 {
00182   return (domain_->getDB()->getActiveStatus(predArray_[atom]));
00183 }
00184 
00185 /* check if the given atom has been previously deactivated or not */
00186 inline bool isDeactivated(int atom)
00187 {
00188   return (domain_->getDB()->getDeactivatedStatus(predArray_[atom]));
00189 }
00190 
00191 /* set the given atom as active */
00192 inline void setActive(int atom)
00193 {
00194   domain_->getDB()->setActiveStatus(predArray_[atom], true);
00195 }
00196 
00197 /* set the given atom as inactive */
00198 inline void setInactive(int atom)
00199 {
00200   domain_->getDB()->setActiveStatus(predArray_[atom], false);
00201 }
00202 
00203 void updatePredArray()
00204 {
00205   int startindex = 0;
00206   if(predArray_.size() > 0)
00207     startindex = predArray_.size()-1;
00208          
00209   predArray_.growToSize(predHashArray_.size()+1);
00210   for(int i = startindex; i < predHashArray_.size(); i++)
00211     predArray_[i+1] = predHashArray_[i];
00212 }
00213 
00214   // Gets the set of all possible initial active clauses
00215 void getSupersetClauses(Array<IntClause *> &supersetClauses)
00216 {
00217     // Randomizer
00218   int seed;
00219   struct timeval tv;
00220   struct timezone tzp;
00221   gettimeofday(&tv,&tzp);
00222   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00223   srandom(seed);
00224   
00225   Clause *fclause;
00226   IntClause *intClause;
00227   int clauseCnt;
00228   IntClauseHashArray clauseHashArray;
00229 
00230   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00231   
00232   clauseCnt = mln_->getNumClauses();
00233 
00234   int clauseno = 0;
00235   while(clauseno < clauseCnt)
00236   {
00237     fclause = (Clause *) mln_->getClause(clauseno);
00238     
00239     double wt = fclause->getWt();
00240     intClauses->clear();
00241     bool ignoreActivePreds = false;
00242     fclause->getActiveClauses(NULL, domain_, intClauses,
00243                               &predHashArray_, ignoreActivePreds);
00244     updatePredArray();
00245 
00246     for (int i = 0; i < intClauses->size(); i++)
00247     {
00248       intClause = (*intClauses)[i];
00249 
00250       int pos = clauseHashArray.find(intClause);
00251       if(pos >= 0)
00252       {
00253         clauseHashArray[pos]->addWt(wt);
00254         intClause->deleteIntPredicates();
00255         delete intClause;
00256         continue;
00257       }
00258            
00259       intClause->setWt(wt);
00260       clauseHashArray.append(intClause);
00261     }
00262     clauseno++;
00263   } //while(clauseno < clauseCnt)
00264       
00265   for(int i = 0; i < clauseHashArray.size(); i++)
00266   {
00267     intClause = clauseHashArray[i];
00268     supersetClauses.append(intClause);
00269   }
00270   delete intClauses;
00271 }
00272 
00273 
00274   // Perform clause selection on the superset of initial active
00275   // clauses and convert weights
00276 void selectClauses(const Array<IntClause *> &supersetClauses,
00277                    Array<Array<int> *> &walksatClauses,
00278                    Array<int> &walksatClauseWts)
00279 {
00280   assert(sampleSat_);
00281   
00282     // Randomizer
00283   int seed;
00284   struct timeval tv;
00285   struct timezone tzp;
00286   gettimeofday(&tv,&tzp);
00287   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00288   srandom(seed);
00289 
00290   for(int i = 0; i < deadClauses_.size(); i++)
00291   {
00292     deadClauses_[i]->deleteIntPredicates();
00293     delete deadClauses_[i];
00294   }
00295   deadClauses_.clearAndCompress();
00296 
00297     // Look at each clause in the superset and do clause
00298     // selection and convert weight
00299   for (int i = 0; i < supersetClauses.size(); i++)
00300   {
00301     assert(prevDB_);
00302     IntClause* intClause = supersetClauses[i];
00303     double wt = intClause->getWt();
00304     if (wt == 0) continue;
00305 
00306       // Pos. clause not satisfied in prev. iteration: don't activate
00307       // Neg. clause satisfied in prev. iteration: don't activate
00308     bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00309     if ((wt > 0 && !sat) ||
00310         (wt < 0 && sat))
00311     {
00312       continue;
00313     }
00314 
00315       // With prob. exp(-wt) don't ever activate it
00316     double threshold =
00317       intClause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00318 
00319     if (random() > threshold)
00320     {
00321       deadClauses_.append(new IntClause(*intClause));
00322       continue;
00323     }
00324 
00325       // If we're here, then clause has been selected
00326     Array<int>* litClause = (Array<int> *)intClause->getIntPredicates();
00327     walksatClauses.append(new Array<int>(*litClause));
00328     if (wt >= 0) walksatClauseWts.append(1);
00329     else walksatClauseWts.append(-1);
00330   }
00331 }
00332 
00333 
00341 void getWalksatClauses(Predicate *inputPred,
00342                                    Array<Array<int> *> &walksatClauses,
00343                                        Array<int> &walksatClauseWts,
00344                        bool const & active)
00345 {
00346     // Randomizer
00347   int seed;
00348   struct timeval tv;
00349   struct timezone tzp;
00350   gettimeofday(&tv,&tzp);
00351   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00352   srandom(seed);
00353   
00354   Clause *fclause;
00355   IntClause *intClause;
00356   int clauseCnt;
00357   IntClauseHashArray clauseHashArray;
00358 
00359   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00360   
00361   const Array<IndexClause*>* indexClauses = NULL;
00362       
00363   if(inputPred == NULL)
00364   {
00365         clauseCnt = mln_->getNumClauses();
00366   }
00367   else
00368   {
00369         if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
00370         int predid = inputPred->getId(); 
00371         indexClauses = mln_->getClausesContainingPred(predid);
00372         clauseCnt = indexClauses->size();
00373   }
00374 
00375   int clauseno = 0;
00376   while(clauseno < clauseCnt)
00377   {
00378         if(inputPred)
00379           fclause = (Clause *) (*indexClauses)[clauseno]->clause;                       
00380         else
00381           fclause = (Clause *) mln_->getClause(clauseno);
00382         
00383         double wt = fclause->getWt();
00384     intClauses->clear();
00385         bool ignoreActivePreds = false;
00386 
00387     if (active)
00388     {
00389           fclause->getActiveClauses(inputPred, domain_, intClauses,
00390                                     &predHashArray_, ignoreActivePreds);
00391     }
00392     else
00393     {
00394       fclause->getInactiveClauses(inputPred, domain_, intClauses,
00395                                   &predHashArray_);
00396     }
00397     updatePredArray();
00398 cout << "intClauses size " << intClauses->size() << endl;
00399         for (int i = 0; i < intClauses->size(); i++)
00400     {
00401       intClause = (*intClauses)[i];
00402 
00403                 // If using samplesat, then do clause selection
00404           if (sampleSat_)
00405           {
00406                 assert(prevDB_);
00407 
00408                   // Pos. clause not satisfied in prev. iteration: don't activate
00409           // Neg. clause satisfied in prev. iteration: don't activate
00410         bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00411                 if ((wt >= 0 && !sat) ||
00412             (wt < 0 && sat))
00413                 {
00414                   intClause->deleteIntPredicates();
00415                   delete intClause;
00416                   continue;
00417                 }
00418 
00419           // In dead clauses: don't activate
00420         if (deadClauses_.contains(intClause))
00421         {
00422           intClause->deleteIntPredicates();
00423           delete intClause;
00424           continue;
00425         }
00426 
00427           // With prob. exp(-wt) don't ever activate it
00428                 double threshold =
00429                   fclause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00430 
00431                 if (random() > threshold)
00432                 {
00433                   deadClauses_.append(intClause);
00434                   continue;
00435                 }
00436           }
00437 
00438           int pos = clauseHashArray.find(intClause);
00439           if(pos >= 0)
00440           {
00441                 clauseHashArray[pos]->addWt(wt);
00442                 intClause->deleteIntPredicates();
00443                 delete intClause;
00444                 continue;
00445           }
00446            
00447           intClause->setWt(wt);
00448           clauseHashArray.append(intClause);
00449         }
00450         clauseno++;     
00451   } //while(clauseno < clauseCnt)
00452           
00453   Array<int>* litClause;
00454   for(int i = 0; i < clauseHashArray.size(); i++)
00455   {
00456         intClause = clauseHashArray[i];
00457         double weight = intClause->getWt();
00458         litClause = (Array<int> *)intClause->getIntPredicates();
00459         walksatClauses.append(litClause);
00460         if (sampleSat_)
00461     {
00462       if (weight >= 0) walksatClauseWts.append(1);
00463       else walksatClauseWts.append(-1);
00464     }
00465         else
00466         {
00467       if (weight >= 0)
00468         walksatClauseWts.append((int)(weight*LWInfo::WSCALE + 0.5));
00469       else
00470         walksatClauseWts.append((int)(weight*LWInfo::WSCALE - 0.5));
00471         }
00472         
00473     delete intClause;
00474   }
00475                  
00476   delete intClauses;
00477 }
00478 
00479 /* get all the active clauses and append to the
00480  * allClauses array */
00481 void getWalksatClauses(Array<Array<int> *> &allClauses, Array<int> &allClauseWeights)
00482 {
00483   getWalksatClauses(NULL, allClauses, allClauseWeights, true);
00484 }
00485 
00486 /* Get all the clauses which become active by flipping this atom */
00487 void getWalksatClausesWhenFlipped(int atom,            
00488                                                                   Array<Array<int> *> &walksatClauses,
00489                                   Array<int> &walksatClauseWts)
00490 {
00491   Predicate *pred = predArray_[atom];
00492   TruthValue oldval = domain_->getDB()->getValue(pred);
00493   TruthValue val;
00494   (oldval == TRUE)? val=FALSE : val = TRUE;
00495   
00496     // Used to store other pred flipped in block
00497   Predicate* otherPred = NULL;
00498   inBlock_ = false;
00499   
00500   int blockIdx = domain_->getBlock(pred);
00501   if (blockIdx >= 0)
00502   {
00503       // Dealing with pred in a block
00504     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00505     if (block->size() > 1)
00506     {
00507       inBlock_ = true;
00508       int chosen = -1;
00509         // 1->0: Pick one at random to flip
00510       if (oldval == TRUE)
00511       {
00512         bool ok = false;
00513         while(!ok)
00514         {
00515           chosen = random() % block->size();
00516           if (!pred->same((*block)[chosen]))
00517             ok = true;
00518         }
00519         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00520       }
00521         // 0->1: Flip the pred with value 1
00522       else
00523       {
00524         for (int i = 0; i < block->size(); i++)
00525         {
00526           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00527           {
00528             chosen = i;
00529             assert(!pred->same((*block)[chosen]));
00530             break;
00531           }
00532         }
00533       }
00534       assert(chosen >= 0);
00535       otherPred = (*block)[chosen];
00536         // Set new value of other atom in block
00537       domain_->getDB()->setValue(otherPred, oldval);  
00538     }
00539   }
00540     // Set new value of actual atom
00541   domain_->getDB()->setValue(pred, val);
00542 
00543   getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00544   if (inBlock_)
00545     getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
00546     
00547     // Set old value of atom and other in block
00548   domain_->getDB()->setValue(pred,oldval);
00549   if (inBlock_)
00550   {
00551     domain_->getDB()->setValue(otherPred, val);
00552     
00553       // Set index of other pred
00554     for(int atom = 1; atom <= getVarCount(); atom++)
00555     {
00556       if (otherPred->same(predArray_[atom]))
00557       {
00558         otherAtom_ = atom;
00559         break;
00560       }
00561     }
00562   }
00563 }
00564 
00565 /* Get the cost of all the clauses which become unsatisfied by flipping this atom */
00566 int getUnSatCostPerPred(Predicate* pred,
00567                         Array<Array<int> *> &walksatClauses,
00568                         Array<int> &walksatClauseWts)
00569 {
00570     // Randomizer
00571   int seed;
00572   struct timeval tv;
00573   struct timezone tzp;
00574   gettimeofday(&tv,&tzp);
00575   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00576   srandom(seed);
00577   
00578   int unSatCost = 0;
00579   Clause *clause;
00580   IntClause *intClause;
00581   IntClauseHashArray clauseHashArray;
00582 
00583   Array<IntClause *>* intClauses = new Array<IntClause *>; 
00584 
00585   const Array<IndexClause*>* indclauses;
00586   int predid = pred->getId(); 
00587   indclauses = mln_->getClausesContainingPred(predid);
00588   
00589   for(int j = 0; j < indclauses->size(); j++)
00590   {      
00591     clause = (Clause *) (*indclauses)[j]->clause;
00592 /*    
00593     double weight = clause->getWt()*LWInfo::WSCALE;
00594     //int wt = abs((int)(clause->getWt()*LWInfo::WSCALE+0.5));
00595     int wt;
00596 
00597       // Samplesat: all weights are 1 or -1
00598     if (sampleSat_)
00599     {
00600       if (weight >= 0)
00601         wt = 1;
00602       else
00603         wt = -1;
00604     }
00605     else
00606     {
00607       if (weight >= 0)
00608         wt = (int)(weight + 0.5);
00609       else
00610         wt = (int)(weight - 0.5);
00611     }
00612 */
00613     double wt = clause->getWt();
00614     intClauses->clear();
00615 
00616     if(abs(wt) < WEIGHT_EPSILON)
00617       continue;
00618       
00619     bool ignoreActivePreds = true;
00620     clause->getActiveClauses(pred, domain_, intClauses,
00621                              &predHashArray_, ignoreActivePreds);
00622     updatePredArray();
00623     
00624     for (int i = 0; i < intClauses->size(); i++)
00625     {
00626       intClause = (*intClauses)[i];
00627 
00628         // If using samplesat, then do clause selection
00629       if (sampleSat_)
00630       {
00631         assert(prevDB_);
00632 
00633 //cout << i << " wt " << wt << endl;
00634 
00635           // Pos. clause not satisfied in prev. iteration: don't activate
00636           // Neg. clause satisfied in prev. iteration: don't activate
00637         bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00638         if ((wt >= 0 && !sat) ||
00639             (wt < 0 && sat))
00640         {
00641           intClause->deleteIntPredicates();
00642           delete intClause;
00643           continue;
00644         }
00645 
00646           // In dead clauses: don't activate
00647         if (deadClauses_.contains(intClause))
00648         {
00649           intClause->deleteIntPredicates();
00650           delete intClause;
00651           continue;
00652         }
00653 
00654           // With prob. exp(-wt) don't ever activate it
00655         double threshold =
00656           clause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00657 
00658         if (random() > threshold)
00659         {
00660           deadClauses_.append(intClause);
00661           continue;
00662         }
00663       }
00664 
00665       int pos = clauseHashArray.find(intClause);
00666       if(pos >= 0)
00667       {
00668         clauseHashArray[pos]->addWt(wt);
00669         intClause->deleteIntPredicates();
00670         delete intClause;
00671         continue;
00672       }
00673 
00674 //cout << "Hard wt. " << LWInfo::HARD_WT << endl;           
00675 //cout << i << " wt2 " << wt << endl;
00676       if (wt == LWInfo::HARD_WT) intClause->setWtToHardWt();
00677       else intClause->setWt(wt);
00678       clauseHashArray.append(intClause);
00679 //cout << i << " intClause->setWt " << intClause->getWt() << endl;
00680     }
00681     
00682   }
00683   
00684   Array<int>* litClause;
00685   for(int i = 0; i < clauseHashArray.size(); i++)
00686   {
00687     intClause = clauseHashArray[i];
00688 //cout << i << " intClause->getWt() " << intClause->getWt() << endl;
00689     int weight = (int)(intClause->getWt());
00690     litClause = (Array<int> *)intClause->getIntPredicates();
00691     walksatClauses.append(litClause);
00692     if (sampleSat_)
00693     {
00694       if (weight >= 0) walksatClauseWts.append(1);
00695       else walksatClauseWts.append(-1);
00696       unSatCost += 1;
00697     }
00698     else
00699     {
00700       if (weight >= 0)
00701       {
00702         walksatClauseWts.append((int)(weight*LWInfo::WSCALE + 0.5));
00703         unSatCost += (int)(weight*LWInfo::WSCALE + 0.5);
00704       }
00705       else
00706       {
00707         walksatClauseWts.append((int)(weight*LWInfo::WSCALE - 0.5));
00708         unSatCost += (int)(weight*LWInfo::WSCALE - 0.5);
00709       }
00710     }
00711 
00712     //walksatClauseWts.append(weight);
00713 //cout << i << " Weight " << weight << endl;
00714     //unSatCost += abs(weight);
00715     delete intClause;
00716   }
00717   delete intClauses;
00718   
00719   return unSatCost;
00720 }
00721 
00722 /* 
00723  * Get the cost of all the clauses which become unsatisfied by flipping this atom,
00724  * taking blocks into account
00725  */
00726 int getUnSatCostWhenFlipped(int atom,
00727                             Array<Array<int> *> &walksatClauses,
00728                             Array<int> &walksatClauseWts)
00729 {        
00730   int unSatCost = 0;
00731   TruthValue val;
00732   Predicate *pred = predArray_[atom];
00733   TruthValue oldval = domain_->getDB()->getValue(pred);
00734   (oldval == TRUE) ? val = FALSE: val = TRUE;
00735 
00736     // Used to store other pred in block and its previous value
00737   Predicate* otherPred = NULL;
00738   bool inBlock = false;
00739   
00740   int blockIdx = domain_->getBlock(pred);
00741   if (blockIdx >= 0)
00742   {
00743       // Dealing with pred in a block
00744     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00745     if (block->size() > 1)
00746     {
00747       inBlock = true;
00748       int chosen = -1;
00749         // 1->0: Pick one at random to flip
00750       if (oldval == TRUE)
00751       {
00752         bool ok = false;
00753         while(!ok)
00754         {
00755           chosen = random() % block->size();
00756           if (!pred->same((*block)[chosen]))
00757             ok = true;
00758         }
00759         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00760       }
00761         // 0->1: Flip the pred with value 1
00762       else
00763       {
00764         for (int i = 0; i < block->size(); i++)
00765         {
00766           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00767           {
00768             chosen = i;
00769             assert(!pred->same((*block)[chosen]));
00770             break;
00771           }
00772         }
00773       }
00774       assert(chosen >= 0);
00775       otherPred = (*block)[chosen];
00776         // Set new value of other atom in block
00777       domain_->getDB()->setValue(otherPred, oldval);  
00778     }
00779   }
00780     // Set new value of actual atom
00781   domain_->getDB()->setValue(pred, val);
00782 
00783     // Get unsat cost of pred and other pred
00784   unSatCost += getUnSatCostPerPred(pred, walksatClauses, walksatClauseWts);
00785   if (inBlock)
00786     unSatCost += getUnSatCostPerPred(otherPred, walksatClauses, walksatClauseWts);
00787     
00788     // Set old value of atom and other in block
00789   domain_->getDB()->setValue(pred,oldval);
00790   if (inBlock)
00791     domain_->getDB()->setValue(otherPred, val);  
00792   return unSatCost;
00793 }
00794 
00795 /* set Database Values to the given array of values */
00796 void setVarVals(int newVals[])
00797 {
00798   for(int atom = 1; atom <= getVarCount(); atom++)
00799   {
00800     Predicate *ipred = predArray_[atom];
00801         if(newVals[atom] == 1)
00802           domain_->getDB()->setValue(ipred,TRUE);
00803         else
00804           domain_->getDB()->setValue(ipred,FALSE);
00805   }
00806 }
00807 
00808 /* flip the predicate corresponding to given atom */
00809 void flipVar(int atom)
00810 {
00811   Predicate *ipred = predArray_[atom];
00812   TruthValue val;
00813   val = domain_->getDB()->getValue(ipred);
00814   if(val == TRUE)
00815     domain_->getDB()->setValue(ipred,FALSE);
00816   else
00817         domain_->getDB()->setValue(ipred,TRUE);
00818 }
00819 
00820 /* get the value (true/false) of the particular variable */
00821 bool getVarVal(int atom)
00822 {
00823   Predicate *ipred = predArray_[atom];
00824   bool val;
00825   (domain_->getDB()->getValue(ipred) == TRUE)? val = true : val = false;
00826   return val;
00827 }
00828 
00829 /* sets the value (true/false) of the particular variable */
00830 void setVarVal(int atom, bool val)
00831 {
00832   Predicate *ipred = predArray_[atom];
00833   if (val)
00834     domain_->getDB()->setValue(ipred, TRUE);
00835   else
00836     domain_->getDB()->setValue(ipred, FALSE);
00837 }
00838 
00839 Predicate* getVar(int atom)
00840 {
00841   return predArray_[atom];      
00842 }
00843 
00844 void setAllActive()
00845 {
00846   LWUtil::setAllActive(domain_);
00847 }
00848 
00849 void setAllInactive()
00850 {
00851   for(int atom = 1; atom <= getVarCount(); atom++)
00852   {
00853         Predicate* ipred = predArray_[atom];
00854         domain_->getDB()->setActiveStatus(ipred, false);
00855   }
00856 }
00857 
00858 void setAllFalse()
00859 {
00860   for(int atom = 1; atom <= getVarCount(); atom++)
00861   {
00862         Predicate* ipred = predArray_[atom];
00863         domain_->getDB()->setValue(ipred, FALSE);
00864   }
00865 }
00866 
00867 /* Deactivates and removes atoms at indices indicated */
00868 void removeVars(Array<int> indices)
00869 {
00870   for (int i = 0; i < indices.size(); i++)
00871   {
00872         Predicate* pred = predArray_[indices[i]];
00873         domain_->getDB()->setActiveStatus(pred, false);
00874         domain_->getDB()->setDeactivatedStatus(pred, true);
00875   }
00876   reset();
00877   //cout << "reset " << endl;
00878 }
00879 
00880 void setSampleSat(bool s)
00881 {
00882   sampleSat_ = s;
00883 }
00884 
00885 bool getSampleSat()
00886 {
00887   return sampleSat_;
00888 }
00889 
00890   //Make a deep copy of the assignments from last iteration
00891 void setPrevDB()
00892 {
00893   if (prevDB_) { delete prevDB_; prevDB_ = NULL; }
00894   prevDB_ = new Database((*domain_->getDB()));
00895 }
00896 
00897 int getNumDBAtoms()
00898 {
00899   return numDBAtoms_;
00900 }
00901 
00902 /*
00903  * Attempts to activate a random atom not already in memory and gets the
00904  * clauses activated by doing this. Returns true if successful, otherwise false.
00905  */
00906 bool activateRandomAtom(Array<Array<int> *> &walksatClauses,
00907                         Array<int> &walksatClauseWts,
00908                         int& toflip)
00909 {
00910     // Randomizer
00911   int seed;
00912   struct timeval tv;
00913   struct timezone tzp;
00914   gettimeofday(&tv,&tzp);
00915   seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00916   srandom(seed);
00917 
00918   Predicate* pred;
00919   pred = domain_->getNonEvidenceAtom(random() % numDBAtoms_);
00920 
00921   int pos = -1;
00922   if ((pos = predHashArray_.find(pred)) == -1)
00923   {
00924     domain_->getDB()->setActiveStatus(pred, true);
00925     predHashArray_.append(pred);
00926     updatePredArray();
00927     toflip = predHashArray_.size();
00928     getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00929     return true;
00930   }
00931   else
00932   {
00933     delete pred;
00934     toflip = pos + 1;
00935     return false;
00936   }
00937 }
00938 
00939 /* If atom is in block, then another atom is chosen to flip and activated */
00940 void chooseOtherToFlip(int atom,
00941                        Array<Array<int> *> &walksatClauses,
00942                        Array<int> &walksatClauseWts)
00943 {
00944     // atom is assumed to be active
00945   assert(isActive(atom));
00946   Predicate *pred = predArray_[atom];
00947   TruthValue oldval = domain_->getDB()->getValue(pred);
00948   TruthValue val;
00949   (oldval == TRUE)? val = FALSE : val = TRUE;
00950   
00951     // Used to store other pred flipped in block
00952   Predicate* otherPred = NULL;
00953   inBlock_ = false;
00954   int blockIdx = domain_->getBlock(pred);
00955 
00956   if (blockIdx < 0)
00957   {
00958     return;
00959   }
00960   else
00961   {
00962       // Dealing with pred in a block
00963     const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00964     if (block->size() > 1)
00965     {
00966       inBlock_ = true;
00967       int chosen = -1;
00968         // 1->0: Pick one at random to flip
00969       if (oldval == TRUE)
00970       {
00971         bool ok = false;
00972         while(!ok)
00973         {
00974           chosen = random() % block->size();
00975           if (!pred->same((*block)[chosen]))
00976             ok = true;
00977         }
00978         assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00979       }
00980         // 0->1: Flip the pred with value 1
00981       else
00982       {
00983         for (int i = 0; i < block->size(); i++)
00984         {
00985           if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00986           {
00987             chosen = i;
00988             assert(!pred->same((*block)[chosen]));
00989             break;
00990           }
00991         }
00992       }
00993       assert(chosen >= 0);
00994       otherPred = (*block)[chosen];
00995         // Set new value of other atom in block
00996       domain_->getDB()->setValue(otherPred, oldval);  
00997     }
00998   }
00999     // Set new value of actual atom
01000   domain_->getDB()->setValue(pred, val);
01001   if (inBlock_)
01002     getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
01003     
01004     // Set old value of atom and other in block
01005   domain_->getDB()->setValue(pred,oldval);
01006   if (inBlock_)
01007   {
01008     domain_->getDB()->setValue(otherPred, val);
01009     
01010       // Set index of other pred
01011     for(int atom = 1; atom <= getVarCount(); atom++)
01012     {
01013       if (otherPred->same(predArray_[atom]))
01014       {
01015         otherAtom_ = atom;
01016         break;
01017       }
01018     }
01019   }
01020 }
01021 
01022   // Sets the truth values of all atoms in the
01023   // block blockIdx except for the one with index atomIdx
01024 void setOthersInBlockToFalse(const int& atomIdx,
01025                              const int& blockIdx)
01026 {
01027   const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
01028   for (int i = 0; i < block->size(); i++)
01029   {
01030     if (i != atomIdx)
01031       domain_->getDB()->setValue((*block)[i], FALSE);
01032   }
01033 }
01034 
01035   // Makes an assignment to all preds in blocks which does not
01036   // violate the blocks. One randomly chosen pred in each box is set to true.
01037 void initBlocks()
01038 {
01039   const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01040   const Array<bool>* blockEvidence = domain_->getBlockEvidenceArray();
01041   
01042   for (int i = 0; i < blocks->size(); i++)
01043   {
01044     Array<Predicate*>* block = (*blocks)[i];
01045     int chosen = -1;
01046       // If evidence atom exists, then all others are false
01047     if ((*blockEvidence)[i])
01048     {
01049       chosen = domain_->getEvidenceIdxInBlock(i);
01050       setOthersInBlockToFalse(chosen, i);
01051       continue;
01052     }
01053     else
01054     {
01055         // Set one at random to true
01056       chosen = random() % block->size();
01057     }
01058     assert(chosen >= 0);
01059     domain_->getDB()->setValue((*block)[chosen], TRUE);
01060     setOthersInBlockToFalse(chosen, i);
01061   }
01062 }
01063 
01064 int getBlock(int atom)
01065 {
01066   const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01067   Predicate* pred = getVar(atom);
01068   
01069   for (int i = 0; i < blocks->size(); i++)
01070   {
01071     Array<Predicate*>* block = (*blocks)[i];
01072     for (int j = 0; j < block->size(); j++)
01073     {
01074       if ((*block)[j]->same(pred))
01075         return i;
01076     }
01077   }
01078   return -1;
01079 }
01080 
01081 bool inBlock(int atom)
01082 {
01083   return (getBlock(atom) >= 0);
01084 }
01085 
01086 bool inBlockWithEvidence(int atom)
01087 {
01088   int blockIdx = getBlock(atom);
01089   if (blockIdx >= 0 && domain_->getBlockEvidence(blockIdx))
01090     return true;
01091   return false;
01092 }
01093 
01094 bool getInBlock()
01095 {
01096   return inBlock_;
01097 }
01098 
01099 void setInBlock(const bool val)
01100 {
01101   inBlock_ = val;
01102 }
01103 
01104 int getOtherAtom()
01105 {
01106   assert(inBlock_ && otherAtom_ >= 0);
01107   return otherAtom_;
01108 }
01109 
01110 void setOtherAtom(const int val)
01111 {
01112   otherAtom_ = val;
01113 }
01114 
01115   // Set evidence status of given atom to given value
01116 void setEvidence(const int atom, const bool val)
01117 {
01118   domain_->getDB()->setEvidenceStatus(predArray_[atom], val);
01119 }
01120 
01121   // Get evidence status of given atom
01122 bool getEvidence(const int atom)
01123 {
01124   return domain_->getDB()->getEvidenceStatus(predArray_[atom]);
01125 }
01126 
01127 void incrementNumDBAtoms()
01128 {
01129   numDBAtoms_++;
01130 }
01131 
01132 void decrementNumDBAtoms()
01133 {
01134   numDBAtoms_--;
01135 }
01136 
01137 void printIntClauses(Array<IntClause *> clauses)
01138 {
01139   for (int i = 0; i < clauses.size(); i++)
01140   {
01141     clauses[i]->printWithWtAndStrVar(cout, domain_, &predHashArray_);
01142     cout << endl;
01143   }
01144 }
01145 
01146 
01147 /* 
01148  * Fixes atoms using unit propagation on non-active clauses.
01149  * Fixed atoms are stored in fixedAtoms_ and set as evidence in the DB.
01150  */
01151 void propagateFixedAtoms(Array<Array<int> *> &clauses,
01152                          Array<int> &clauseWeights,
01153                          bool* fixedAtoms,
01154                          int maxFixedAtoms)
01155 {
01156   Array<Array<int> *> tmpClauses;
01157   Array<int> tmpClauseWeights;
01158     
01159     //TEMP: count fixed atoms
01160   int count = 0;
01161   for (int i = 0; i < maxFixedAtoms; i++)
01162   {
01163     if (fixedAtoms[i]) count++;
01164   }
01165   cout << "Fixed atoms before propagating: " << count << endl;
01166   
01167     // Look at all non-active clauses containing a fixed atom
01168   for (int i = 0; i < maxFixedAtoms; i++)
01169   {
01170     if (fixedAtoms[i])
01171     {
01172         // Get all non-active clauses containing it,
01173         // filtering out dead clauses
01174       getWalksatClauses(predArray_[i], tmpClauses, tmpClauseWeights, false);
01175 cout << "Atom " << i << endl;
01176 cout << "Clauses " << tmpClauses.size() << endl;
01177 
01178         // clauses now contains potential clauses of all length
01179         // want to look at only unit clauses
01180       for (int j = 0; j < tmpClauses.size(); j++)
01181       {
01182           // Unit clause
01183         if (tmpClauses[j]->size() == 1)
01184         {
01185           clauses.append(tmpClauses[j]);
01186           clauseWeights.append(clauseWeights[j]);
01187           fixedAtoms[(*tmpClauses[j])[0]] = true;
01188         }
01189         else // Not unit clause
01190         {
01191           delete tmpClauses[j];
01192         }
01193       }
01194     }
01195     tmpClauses.clear();
01196     tmpClauseWeights.clear();
01197   }
01198   
01199       //TEMP: count fixed atoms
01200   count = 0;
01201   for (int i = 0; i < maxFixedAtoms; i++)
01202   {
01203     if (fixedAtoms[i]) count++;
01204   }
01205   cout << "Fixed atoms after propagating: " << count << endl;
01206   
01207   exit(0);
01208 }
01209 
01210 
public:
        //Weight of hard clauses
  static double HARD_WT;
        //Scale to use on weights
  static double WSCALE;

private:

    // MLN deleted in ~LWInfo(). The destructor notes that it is the copy
    // made in copyMLN().
  MLN *mln_;
    // Domain supplying the DB and the predicate blocks. It is not deleted
    // by ~LWInfo().
  Domain *domain_;
  
    // Ground atoms indexed by atom number. chooseOtherToFlip scans indices
    // 1..getVarCount(). NOTE(review): confirm what occupies index 0.
  Array<Predicate *> predArray_;
    // Ground atoms; its size is the variable count (getVarCount()).
    // Elements are deleted in ~LWInfo().
  PredicateHashArray predHashArray_;

    // NOTE(review): presumably maps each atom to the ids of the clauses
    // containing it; not used in this part of the file — confirm.
  Array<Array<int> *> predToClauseIds_;

        // Set to true when performing sample sat
  bool sampleSat_;
        // Dead clauses not to be considered for activation
        // (entries are deleted in ~LWInfo())
  IntClauseHashArray deadClauses_;
        // DB holding assignment from previous iteration
        // (NULL until set; deleted in ~LWInfo() when non-NULL)
  Database* prevDB_;
    // Total number of non-evidence ground atoms in the DB. It is initialised
    // from domain_->getNumNonEvidenceAtoms() and adjusted by
    // incrementNumDBAtoms()/decrementNumDBAtoms().
  int numDBAtoms_;
    // Flag indicating if an atom in a block is being flipped
    // (written by chooseOtherToFlip() and setInBlock())
  bool inBlock_;
    // Index of other atom to be flipped in block (written by
    // chooseOtherToFlip() and setOtherAtom(); getOtherAtom() asserts >= 0).
    // NOTE(review): the constructor does not initialise it.
  int otherAtom_;

};
01241 
01242 #endif
01243 

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1