variablestate.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef VARIABLESTATE_H_
00067 #define VARIABLESTATE_H_
00068 
00069 #include "mrf.h"
00070 
// Sentinel returned by VariableState's random-selection helpers when there is
// nothing to pick from (no atoms, no false clauses).
const int NOVALUE(-1);
// Compile-time switch for VariableState's verbose trace output.
const bool vsdebug(false);
00073 
00093 class VariableState
00094 {
00095  public:
00096 
00115   VariableState(GroundPredicateHashArray* const& unknownQueries,
00116                 GroundPredicateHashArray* const& knownQueries,
00117                 Array<TruthValue>* const & knownQueryValues,
00118                 const Array<int>* const & allPredGndingsAreQueries,
00119                 const bool& markHardGndClauses,
00120                 const bool& trackParentClauseWts,
00121                 const MLN* const & mln, const Domain* const & domain,
00122                 const bool& lazy)
00123   {
00124     this->mln_ = (MLN*)mln;
00125     this->domain_ = (Domain*)domain;
00126     this->lazy_ = lazy;
00127 
00128       // Instantiate information
00129     baseNumAtoms_ = 0;
00130     activeAtoms_ = 0;
00131     numFalseClauses_ = 0;
00132     costOfFalseClauses_ = 0.0;
00133     lowCost_ = LDBL_MAX;
00134     lowBad_ = INT_MAX;
00135 
00136       // Clauses and preds are stored in gndClauses_ and gndPreds_
00137     gndClauses_ = new Array<GroundClause*>;
00138     gndPreds_ = new Array<GroundPredicate*>;
00139 
00140       // Set the hard clause weight
00141     setHardClauseWeight();
00142 
00143       // Lazy version: Produce state with initial active atoms and clauses
00144     if (lazy_)
00145     {
00146         // Unknown preds are treated as false
00147       domain_->getDB()->setPerformingInference(true);
00148 
00149         // Blocks are copied from the domain
00150       initLazyBlocks();
00151 
00152       clauseLimit_ = INT_MAX;
00153       noApprox_ = false;
00154       haveDeactivated_ = false;
00155 
00157         // Get initial set of active atoms (atoms in unsat. clauses)
00158         // Assumption is: all atoms are initially false except those in blocks
00159 
00160         // One atom in each block is set to true and activated
00161       addOneAtomToEachBlock();
00162 
00163 //cout << "After addOneAtomToEachBlock" << endl;
00164 //for (int i = 1; i < atom_.size(); i++)
00165 //  cout << atom_[i] << endl;
00166 
00167       bool ignoreActivePreds = false;
00168       getActiveClauses(newClauses_, ignoreActivePreds);
00169       int defaultCnt = newClauses_.size();
00170       long double defaultCost = 0;
00171 
00172       for (int i = 0; i < defaultCnt; i++)
00173       {
00174         if (newClauses_[i]->isHardClause())
00175           defaultCost += hardWt_;
00176         else
00177           defaultCost += abs(newClauses_[i]->getWt());
00178       }
00179 
00180         // Clear ground clauses in the ground preds
00181       for (int i = 0; i < gndPredHashArray_.size(); i++)
00182         gndPredHashArray_[i]->removeGndClauses();
00183 
00184         // Delete new clauses
00185       for (int i = 0; i < newClauses_.size(); i++)
00186         delete newClauses_[i];
00187       newClauses_.clear();
00188 
00189       baseNumAtoms_ = gndPredHashArray_.size();
00190       cout << "Number of Baseatoms = " << baseNumAtoms_ << endl;
00191       cout << "Default => Cost\t" << "******\t" << " Clause Cnt\t" << endl;
00192       cout << "           " << defaultCost << "\t" << "******\t" << defaultCnt
00193            << "\t" << endl << endl;
00194 
00195         // Set base atoms as active in DB
00196       for (int i = 0; i < baseNumAtoms_; i++)
00197       {
00198         domain_->getDB()->setActiveStatus(gndPredHashArray_[i], true);
00199         activeAtoms_++;        
00200       }
00201 
00202         // Add the rest of the atoms in the blocks, but don't activate
00203       fillLazyBlocks();
00204 
00205         // Get the initial set of active clauses
00206       ignoreActivePreds = false;
00207       getActiveClauses(newClauses_, ignoreActivePreds);      
00208     } // End lazy version
00209       // Eager version: Use KBMC to produce the state
00210     else
00211     {
00212       unePreds_ = unknownQueries;
00213       knePreds_ = knownQueries;
00214       knePredValues_ = knownQueryValues;
00215 
00216         // MRF is built on known and unknown queries
00217       int size = 0;
00218       if (unknownQueries) size += unknownQueries->size();
00219       if (knownQueries) size += knownQueries->size();
00220       GroundPredicateHashArray* queries = new GroundPredicateHashArray(size);
00221       if (unknownQueries) queries->append(unknownQueries);
00222       if (knownQueries) queries->append(knownQueries);
00223       mrf_ = new MRF(queries, allPredGndingsAreQueries, domain_,
00224                      domain_->getDB(), mln_, markHardGndClauses,
00225                      trackParentClauseWts, -1);
00226         //delete to save space. Can be deleted because no more gndClauses are
00227         //appended to gndPreds beyond this point
00228       mrf_->deleteGndPredsGndClauseSets();
00229         //do not delete the intArrRep in gndPreds_;
00230       delete queries;
00231 
00232         // Blocks built in MRF
00233       blocks_ = mrf_->getBlocks();
00234       blockEvidence_ = mrf_->getBlockEvidence();
00235         
00236         // Put ground clauses in newClauses_
00237       newClauses_ = *(Array<GroundClause*>*)mrf_->getGndClauses();
00238         // Put ground preds in the hash array
00239       //const Array<GroundPredicate*>* gndPreds = mrf_->getGndPreds();
00240       const GroundPredicateHashArray* gndPreds = mrf_->getGndPreds();
00241       for (int i = 0; i < gndPreds->size(); i++)
00242         gndPredHashArray_.append((*gndPreds)[i]);
00243     
00244         // baseNumAtoms_ are all atoms in eager version
00245       baseNumAtoms_ = gndPredHashArray_.size();        
00246     } // End eager version
00247     
00248       // At this point, ground clauses are held in newClauses_
00249       // and ground predicates are held in gndPredHashArray_
00250       // for both versions
00251     
00252       // Add the clauses and preds and fill info arrays
00253     bool initial = true;
00254     addNewClauses(initial);
00255     
00256     cout << "Initial num. of clauses: " << getNumClauses() << endl;
00257   }
00258 
00263   ~VariableState()
00264   {
00265     if (lazy_)
00266     {
00267         // Block information from lazy version is deleted
00268       for (int i = 0; i < blocks_->size(); i++)
00269         (*blocks_)[i].clearAndCompress();
00270       delete blocks_;
00271     
00272       delete blockEvidence_;  
00273     }
00274     else
00275     {
00276         // MRF from eager version is deleted
00277       if (mrf_) delete mrf_;
00278       //if (unePreds_) delete unePreds_;
00279       //if (knePreds_) delete knePreds_;
00280       //if (knePredValues_) delete knePredValues_;
00281     }    
00282   }
00283 
00284 
00292   void addNewClauses(bool initial)
00293   {
00294     if (vsdebug)
00295       cout << "Adding " << newClauses_.size() << " new clauses.." << endl;
00296 
00297       // Store the old number of clauses and atoms
00298     int oldNumClauses = getNumClauses();
00299     int oldNumAtoms = getNumAtoms();
00300 
00301     gndClauses_->append(newClauses_);
00302     gndPreds_->growToSize(gndPredHashArray_.size());
00303 
00304     int numAtoms = getNumAtoms();
00305     int numClauses = getNumClauses();
00306       // If no new atoms or clauses have been added, then do nothing
00307     if (numAtoms == oldNumAtoms && numClauses == oldNumClauses) return;
00308 
00309     if (vsdebug) cout << "Clauses: " << numClauses << endl;
00310 
00311       // atomIdx starts at 1
00312     atom_.growToSize(numAtoms + 1, false);
00313     makeCost_.growToSize(numAtoms + 1, 0.0);
00314     breakCost_.growToSize(numAtoms + 1, 0.0);
00315     lowAtom_.growToSize(numAtoms + 1, false);
00316     fixedAtom_.growToSize(numAtoms + 1, 0);
00317 
00318       // Copy ground preds to gndPreds_ and set values in atom and lowAtom
00319     for (int i = oldNumAtoms; i < gndPredHashArray_.size(); i++)
00320     {
00321       (*gndPreds_)[i] = gndPredHashArray_[i];
00322 
00323       if (vsdebug)
00324       {
00325         cout << "New pred: ";
00326         (*gndPreds_)[i]->print(cout, domain_);
00327         cout << endl;
00328       }
00329 
00330       lowAtom_[i + 1] = atom_[i + 1] = 
00331         (domain_->getDB()->getValue((*gndPreds_)[i]) == TRUE) ? true : false;
00332     }
00333     newClauses_.clear();
00334 
00335     clause_.growToSize(numClauses);
00336     clauseCost_.growToSize(numClauses);
00337     falseClause_.growToSize(numClauses);
00338     whereFalse_.growToSize(numClauses);
00339     numTrueLits_.growToSize(numClauses);
00340     watch1_.growToSize(numClauses);
00341     watch2_.growToSize(numClauses);
00342     isSatisfied_.growToSize(numClauses, false);
00343     deadClause_.growToSize(numClauses, false);
00344     threshold_.growToSize(numClauses, false);
00345 
00346     occurence_.growToSize(2*numAtoms + 1);
00347 
00348     for (int i = oldNumClauses; i < numClauses; i++)
00349     {
00350       GroundClause* gndClause = (*gndClauses_)[i];
00351 
00352       if (vsdebug)
00353       {
00354         cout << "New clause: ";
00355         gndClause->print(cout, domain_, &gndPredHashArray_);
00356         cout << endl;
00357       }
00358       
00359         // Set thresholds for clause selection
00360       if (gndClause->isHardClause()) threshold_[i] = RAND_MAX;
00361       else
00362       {
00363         double w = gndClause->getWt();
00364         threshold_[i] = RAND_MAX*(1 - exp(-abs(w)));
00365         if (vsdebug)
00366         {
00367           cout << "Weight: " << w << endl;            
00368         }
00369       }
00370       if (vsdebug)
00371         cout << "Threshold: " << threshold_[i] << endl;            
00372       
00373       int numGndPreds = gndClause->getNumGroundPredicates();
00374       clause_[i].growToSize(numGndPreds);
00375 
00376       for (int j = 0; j < numGndPreds; j++) 
00377       {
00378           // idx in gndClause + 1 (negated if neg. literal)
00379         //int idx = gndClause->getGroundPredicateIndex(j);
00380         //assert(idx >= 0);
00381         //int lit = (gndClause->getGroundPredicateSense(j)) ?
00382         //          idx + 1 : -(idx + 1);
00383         int lit = gndClause->getGroundPredicateIndex(j);
00384         clause_[i][j] = lit;
00385         int litIdx = 2*abs(lit) - (lit > 0);
00386         occurence_[litIdx].append(i);
00387       }
00388 
00389         // Hard clause weight has been previously determined
00390       if (gndClause->isHardClause())
00391         clauseCost_[i] = hardWt_;
00392       else
00393         clauseCost_[i] = gndClause->getWt();
00394     }
00395 
00396     if (!initial)
00397     {
00398       //initNumSatLiterals(1, oldNumClauses);
00399       if (useThreshold_)
00400       {
00401         killClauses(oldNumClauses);
00402       }
00403       else
00404       {
00405         initMakeBreakCostWatch(oldNumClauses);
00406       }
00407     }
00408     if (vsdebug) cout << "Done adding new clauses.." << endl;
00409   }
00410 
00414   void init()
00415   {
00416       // Reset info concerning true lits, false clauses, etc.
00417     for (int i = 0; i < getNumClauses(); i++) numTrueLits_[i] = 0;
00418     numFalseClauses_ = 0;
00419     costOfFalseClauses_ = 0.0;
00420     lowCost_ = LDBL_MAX;
00421     lowBad_ = INT_MAX;
00422 
00423       // Initialize info arrays
00424     initMakeBreakCostWatch(0);
00425   }
00426   
00432   void initRandom()
00433   {
00434       // Set one in each block to true randomly
00435     initBlocksRandom();
00436 
00437       // Random truth value for all not in blocks
00438     for (int i = 1; i <= baseNumAtoms_; i++)
00439     {
00440         // fixedAtom_[i] = -1: false, fixedAtom_[i] = 1: true
00441       if (fixedAtom_[i] != 0) setValueOfAtom(i, (fixedAtom_[i] == 1));
00442         // Blocks are initialized above
00443       if (getBlockIndex(i - 1) >= 0)
00444       {
00445         if (vsdebug) cout << "Atom " << i << " in block" << endl;
00446         continue;
00447       }
00448         // Not fixed and not in block
00449       else
00450       {
00451         if (vsdebug) cout << "Atom " << i << " not in block" << endl;
00452         setValueOfAtom(i, random() % 2);
00453       }
00454     }
00455     init();
00456   }
00457 
00461   void initBlocksRandom()
00462   {
00463     if (vsdebug)
00464     {
00465       cout << "Initializing blocks randomly" << endl;
00466       cout << "Num. of blocks: " << blocks_->size() << endl;
00467     }
00468     
00469       // For each block: select one to set to true
00470     for (int i = 0; i < blocks_->size(); i++)
00471     {
00472         // True fixed atom in the block: set others to false
00473       if (int trueFixedAtomInBlock = getTrueFixedAtomInBlock(i) >= 0)
00474       {
00475         if (vsdebug) cout << "True fixed atom in block " << i << endl;
00476         setOthersInBlockToFalse(trueFixedAtomInBlock, i);
00477         continue;
00478       }
00479 
00480         // If evidence atom exists, then all others are false
00481       if ((*blockEvidence_)[i])
00482       {
00483           // If first argument is -1, then all are set to false
00484         setOthersInBlockToFalse(-1, i);
00485         continue;
00486       }
00487 
00488         // Eager version: pick one at random
00489       Array<int>& block = (*blocks_)[i];
00490       bool ok = false;
00491       while (!ok)
00492       {
00493         int chosen = random() % block.size();
00494           // Atom not fixed
00495         if (fixedAtom_[block[chosen] + 1] == 0)
00496         {
00497           if (vsdebug) cout << "Atom " << block[chosen] + 1 
00498                             << " chosen in block" << endl;
00499           setValueOfAtom(block[chosen] + 1, true);
00500           setOthersInBlockToFalse(chosen, i);
00501           ok = true;
00502         }
00503       }
00504     }
00505     if (vsdebug) cout << "Done initializing blocks randomly" << endl;
00506   }      
00507 
00515   void initMakeBreakCostWatch(const int& startClause)
00516   {
00517     int theTrueLit = -1;
00518       // Initialize breakCost and makeCost in the following:
00519     for (int i = startClause; i < getNumClauses(); i++)
00520     {
00521         // Don't look at dead clauses
00522       if (deadClause_[i]) continue;
00523       int trueLit1 = 0;
00524       int trueLit2 = 0;
00525       long double cost = clauseCost_[i];
00526       numTrueLits_[i] = 0;
00527       for (int j = 0; j < getClauseSize(i); j++)
00528       {
00529         if (isTrueLiteral(clause_[i][j]))
00530         { // ij is true lit
00531           numTrueLits_[i]++;
00532           theTrueLit = abs(clause_[i][j]);
00533           if (!trueLit1) trueLit1 = theTrueLit;
00534           else if (trueLit1 && !trueLit2) trueLit2 = theTrueLit;
00535         }
00536       }
00537 
00538         // Unsatisfied positive-weighted clauses or
00539         // Satisfied negative-weighted clauses
00540       if ((numTrueLits_[i] == 0 && cost > 0) ||
00541           (numTrueLits_[i] > 0 && cost < 0))
00542       {
00543         whereFalse_[i] = numFalseClauses_;
00544         falseClause_[numFalseClauses_] = i;
00545         numFalseClauses_++;
00546         costOfFalseClauses_ += abs(cost);
00547         if (highestCost_ == abs(cost)) {eqHighest_ = true; numHighest_++;}
00548 
00549           // Unsat. pos. clause: increase makeCost_ of all atoms
00550         if (numTrueLits_[i] == 0)
00551           for (int j = 0; j < getClauseSize(i); j++)
00552           {
00553             makeCost_[abs(clause_[i][j])] += cost;
00554           }
00555 
00556           // Sat. neg. clause: increase makeCost_ if one true literal
00557         if (numTrueLits_[i] == 1)
00558         {
00559             // Subtract neg. cost
00560           makeCost_[theTrueLit] -= cost;
00561           watch1_[i] = theTrueLit;
00562         }
00563         else if (numTrueLits_[i] > 1)
00564         {
00565           watch1_[i] = trueLit1;
00566           watch2_[i] = trueLit2;
00567         }
00568       }
00569         // Pos. clauses satisfied by one true literal
00570       else if (numTrueLits_[i] == 1 && cost > 0)
00571       {
00572         breakCost_[theTrueLit] += cost;
00573         watch1_[i] = theTrueLit;
00574       }
00575         // Pos. clauses satisfied by 2 or more true literals
00576       else if (cost > 0)
00577       { /*if (numtruelit[i] == 2)*/
00578         watch1_[i] = trueLit1;
00579         watch2_[i] = trueLit2;
00580       }
00581         // Unsat. neg. clauses: increase breakCost of all atoms
00582       else if (numTrueLits_[i] == 0 && cost < 0)
00583       {
00584         for (int j = 0; j < getClauseSize(i); j++)
00585           breakCost_[abs(clause_[i][j])] -= cost;
00586       }
00587     } // for all clauses
00588   }
00589 
00590   int getNumAtoms() { return gndPreds_->size(); }
00591   
00592   int getNumClauses() { return gndClauses_->size(); }
00593   
00594   int getNumDeadClauses()
00595   { 
00596     int count = 0;
00597     for (int i = 0; i < deadClause_.size(); i++)
00598       if (deadClause_[i]) count++;
00599     return count;
00600   }
00601 
00605   int getIndexOfRandomAtom()
00606   {
00607     int numAtoms = getNumAtoms();
00608     if (numAtoms == 0) return NOVALUE;
00609     return random()%numAtoms + 1;
00610   }
00611 
00617   int getIndexOfAtomInRandomFalseClause()
00618   {
00619     if (numFalseClauses_ == 0) return NOVALUE;
00620     int clauseIdx = falseClause_[random()%numFalseClauses_];
00621       // Pos. clause: return index of random atom
00622     if (clauseCost_[clauseIdx] > 0)
00623       return abs(clause_[clauseIdx][random()%getClauseSize(clauseIdx)]);
00624       // Neg. clause: find random true lit
00625     else
00626       return getRandomTrueLitInClause(clauseIdx);
00627   }
00628   
00633   int getRandomFalseClauseIndex()
00634   {
00635     if (numFalseClauses_ == 0) return NOVALUE;
00636     return falseClause_[random()%numFalseClauses_];
00637   }
00638   
00643   long double getCostOfFalseClauses()
00644   {
00645     return costOfFalseClauses_;
00646   }
00647   
00652   int getNumFalseClauses()
00653   {
00654     return numFalseClauses_;
00655   }
00656 
00663   bool getValueOfAtom(const int& atomIdx)
00664   {
00665     return atom_[atomIdx];
00666   }
00667 
00674   bool getValueOfLowAtom(const int& atomIdx)
00675   {
00676     return lowAtom_[atomIdx];
00677   }
00678 
00686   void setValueOfAtom(const int& atomIdx, const bool& value)
00687   {
00688     if (vsdebug) cout << "Setting value of atom " << atomIdx 
00689                       << " to " << value << endl;
00690       // If atom already has this value, then do nothing
00691     if (atom_[atomIdx] == value) return;
00692       // Propagate assigment to DB
00693     GroundPredicate* p = gndPredHashArray_[atomIdx - 1];
00694     if (value)
00695     {
00696       domain_->getDB()->setValue(p, TRUE);
00697     }
00698     else
00699     {
00700       domain_->getDB()->setValue(p, FALSE);
00701     }
00702       // If not active, then activate it
00703     if (lazy_ && !isActive(atomIdx))
00704     {
00705       bool ignoreActivePreds = false;
00706       activateAtom(atomIdx, ignoreActivePreds);
00707     }
00708     atom_[atomIdx] = value;
00709   }
00710 
00714   Array<int>& getNegOccurenceArray(const int& atomIdx)
00715   {
00716     int litIdx = 2*atomIdx;
00717     return getOccurenceArray(litIdx);
00718   }
00719 
00723   Array<int>& getPosOccurenceArray(const int& atomIdx)
00724   {
00725     int litIdx = 2*atomIdx - 1;
00726     return getOccurenceArray(litIdx);
00727   }
00728 
00734   void flipAtom(const int& toFlip)
00735   {
00736     bool toFlipValue = getValueOfAtom(toFlip);
00737     register int clauseIdx;
00738     int sign;
00739     int oppSign;
00740     int litIdx;
00741     if (toFlipValue)
00742       sign = 1;
00743     else
00744       sign = 0;
00745     oppSign = sign ^ 1;
00746 
00747     flipAtomValue(toFlip);
00748     
00749       // Update all clauses in which the atom occurs as a true literal
00750     litIdx = 2*toFlip - sign;
00751     Array<int>& posOccArray = getOccurenceArray(litIdx);
00752     for (int i = 0; i < posOccArray.size(); i++)
00753     {
00754       clauseIdx = posOccArray[i];
00755         // Don't look at dead clauses
00756       if (deadClause_[clauseIdx]) continue;
00757         // The true lit became a false lit
00758       int numTrueLits = decrementNumTrueLits(clauseIdx);
00759       long double cost = getClauseCost(clauseIdx);
00760       int watch1 = getWatch1(clauseIdx);
00761       int watch2 = getWatch2(clauseIdx);
00762 
00763         // 1. If last true lit was flipped, then we have to update
00764         // the makecost / breakcost info accordingly        
00765       if (numTrueLits == 0)
00766       {
00767           // Pos. clause
00768         if (cost > 0)
00769         {
00770             // Add this clause as false in the state
00771           addFalseClause(clauseIdx);
00772             // Decrease toFlip's breakcost (add neg. cost)
00773           addBreakCost(toFlip, -cost);
00774             // Increase makecost of all vars in clause (add pos. cost)
00775           addMakeCostToAtomsInClause(clauseIdx, cost);
00776         }
00777           // Neg. clause
00778         else
00779         {
00780           assert(cost < 0);
00781             // Remove this clause as false in the state
00782           removeFalseClause(clauseIdx);
00783             // Increase breakcost of all vars in clause (add pos. cost)
00784           addBreakCostToAtomsInClause(clauseIdx, -cost);        
00785             // Decrease toFlip's makecost (add neg. cost)
00786           addMakeCost(toFlip, cost);
00787         }
00788       }
00789         // 2. If there is now one true lit left, then move watch2
00790         // up to watch1 and increase the breakcost / makecost of watch1
00791       else if (numTrueLits == 1)
00792       {
00793         if (watch1 == toFlip)
00794         {
00795           assert(watch1 != watch2);
00796           setWatch1(clauseIdx, watch2);
00797           watch1 = getWatch1(clauseIdx);
00798         }
00799 
00800           // Pos. clause: Increase toFlip's breakcost (add pos. cost)
00801         if (cost > 0)
00802         {
00803           addBreakCost(watch1, cost);
00804         }
00805           // Neg. clause: Increase toFlip's makecost (add pos. cost)
00806         else
00807         {
00808           assert(cost < 0);
00809           addMakeCost(watch1, -cost);
00810         }
00811       }
00812         // 3. If there are 2 or more true lits left, then we have to
00813         // find a new true lit to watch if one was flipped
00814       else
00815       { /* numtruelit[clauseIdx] >= 2 */
00816           // If watch1[clauseIdx] has been flipped
00817         if (watch1 == toFlip)
00818         {
00819             // find a different true literal to watch
00820           int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00821           setWatch1(clauseIdx, diffTrueLit);
00822         }
00823           // If watch2[clauseIdx] has been flipped
00824         else if (watch2 == toFlip)
00825         {
00826             // find a different true literal to watch
00827           int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00828           setWatch2(clauseIdx, diffTrueLit);
00829         }
00830       }
00831     }
00832         
00833       // Update all clauses in which the atom occurs as a false literal
00834     litIdx = 2*toFlip - oppSign;
00835     Array<int>& negOccArray = getOccurenceArray(litIdx);
00836     for (int i = 0; i < negOccArray.size(); i++)
00837     {
00838       clauseIdx = negOccArray[i];
00839         // Don't look at dead clauses
00840       if (deadClause_[clauseIdx]) continue;
00841         // The false lit became a true lit
00842       int numTrueLits = incrementNumTrueLits(clauseIdx);
00843       long double cost = getClauseCost(clauseIdx);
00844       int watch1 = getWatch1(clauseIdx);
00845 
00846         // 1. If this is the only true lit, then we have to update
00847         // the makecost / breakcost info accordingly        
00848       if (numTrueLits == 1)
00849       {
00850           // Pos. clause
00851         if (cost > 0)
00852         {
00853             // Remove this clause as false in the state
00854           removeFalseClause(clauseIdx);
00855             // Increase toFlip's breakcost (add pos. cost)
00856           addBreakCost(toFlip, cost);        
00857             // Decrease makecost of all vars in clause (add neg. cost)
00858           addMakeCostToAtomsInClause(clauseIdx, -cost);
00859         }
00860           // Neg. clause
00861         else
00862         {
00863           assert(cost < 0);
00864             // Add this clause as false in the state
00865           addFalseClause(clauseIdx);
00866             // Decrease breakcost of all vars in clause (add neg. cost)
00867           addBreakCostToAtomsInClause(clauseIdx, cost);
00868             // Increase toFlip's makecost (add pos. cost)
00869           addMakeCost(toFlip, -cost);
00870         }
00871           // Watch this atom
00872         setWatch1(clauseIdx, toFlip);
00873       }
00874         // 2. If there are now exactly 2 true lits, then watch second atom
00875         // and update breakcost
00876       else
00877       if (numTrueLits == 2)
00878       {
00879         if (cost > 0)
00880         {
00881             // Pos. clause
00882             // Decrease breakcost of first atom being watched (add neg. cost)
00883           addBreakCost(watch1, -cost);
00884         }
00885         else
00886         {
00887             // Neg. clause
00888           assert(cost < 0);
00889             // Decrease makecost of first atom being watched (add neg. cost)
00890           addMakeCost(watch1, cost);
00891         }
00892         
00893           // Watch second atom
00894         setWatch2(clauseIdx, toFlip);
00895       }
00896     }
00897   }
00898 
00903   void flipAtomValue(const int& atomIdx)
00904   {
00905     bool opposite = !atom_[atomIdx];
00906     setValueOfAtom(atomIdx, opposite);
00907   }
00908 
00920   long double getImprovementByFlipping(const int& atomIdx)
00921   {
00922     if (lazy_ && !isActive(atomIdx))
00923     {
00924         // First flip the atom to activate it, then flip it back
00925       flipAtom(atomIdx);
00926       flipAtom(atomIdx);
00927     }
00928     long double improvement = makeCost_[atomIdx] - breakCost_[atomIdx];
00929     return improvement;
00930   }
00931   
00938   void activateAtom(const int& atomIdx, const bool& ignoreActivePreds)
00939   {
00940       // Lazy version: if atom is not active, we need to activate clauses
00941       // and take their cost into account
00942     if (lazy_ && !isActive(atomIdx))
00943     {
00944       Predicate* p =
00945         gndPredHashArray_[atomIdx - 1]->createEquivalentPredicate(domain_);
00946       getActiveClauses(p, newClauses_, true, ignoreActivePreds);
00947         // Add the clauses and preds and fill info arrays
00948       bool initial = false;
00949       addNewClauses(initial);
00950         // Set active status in db
00951       domain_->getDB()->setActiveStatus(p, true);
00952       activeAtoms_++;
00953       delete p;
00954     }        
00955   }
00956 
00963   bool isActive(const int& atomIdx)
00964   {
00965     return domain_->getDB()->getActiveStatus(gndPredHashArray_[atomIdx-1]);
00966   }
00967   
00974   bool isActive(const Predicate* pred)
00975   {
00976     return domain_->getDB()->getActiveStatus(pred);
00977   }
00978 
00982   Array<int>& getOccurenceArray(const int& idx)
00983   {
00984     return occurence_[idx];
00985   }
00986   
00990   int incrementNumTrueLits(const int& clauseIdx)
00991   {
00992     return ++numTrueLits_[clauseIdx];
00993   }
00994 
00998   int decrementNumTrueLits(const int& clauseIdx)
00999   {
01000     return --numTrueLits_[clauseIdx];
01001   }
01002 
01006   int getNumTrueLits(const int& clauseIdx)
01007   {
01008     return numTrueLits_[clauseIdx];
01009   }
01010 
01014   long double getClauseCost(const int& clauseIdx)
01015   {
01016     return clauseCost_[clauseIdx];
01017   }
01018   
01022   Array<int>& getAtomsInClause(const int& clauseIdx)
01023   {
01024     return clause_[clauseIdx];
01025   }
01026 
01030   void addFalseClause(const int& clauseIdx)
01031   {
01032     falseClause_[numFalseClauses_] = clauseIdx;
01033     whereFalse_[clauseIdx] = numFalseClauses_;
01034     numFalseClauses_++;
01035     costOfFalseClauses_ += abs(clauseCost_[clauseIdx]);
01036   }
01037   
01041   void removeFalseClause(const int& clauseIdx)
01042   {
01043     numFalseClauses_--;
01044     falseClause_[whereFalse_[clauseIdx]] = falseClause_[numFalseClauses_];
01045     whereFalse_[falseClause_[numFalseClauses_]] = whereFalse_[clauseIdx];
01046     costOfFalseClauses_ -= abs(clauseCost_[clauseIdx]);
01047   }
01048 
01052   void addBreakCost(const int& atomIdx, const long double& cost)
01053   {
01054     breakCost_[atomIdx] += cost;
01055   }
01056 
01060   void subtractBreakCost(const int& atomIdx, const long double& cost)
01061   {
01062     breakCost_[atomIdx] -= cost;
01063   }
01064 
01071   void addBreakCostToAtomsInClause(const int& clauseIdx,
01072                                    const long double& cost)
01073   {
01074     register int size = getClauseSize(clauseIdx);
01075     for (int i = 0; i < size; i++)
01076     {
01077       register int lit = clause_[clauseIdx][i];
01078       breakCost_[abs(lit)] += cost;
01079     }
01080   }
01081 
01088   void subtractBreakCostFromAtomsInClause(const int& clauseIdx,
01089                                           const long double& cost)
01090   {
01091     register int size = getClauseSize(clauseIdx);
01092     for (int i = 0; i < size; i++)
01093     {
01094       register int lit = clause_[clauseIdx][i];
01095       breakCost_[abs(lit)] -= cost;
01096     }
01097   }
01098 
01105   void addMakeCost(const int& atomIdx, const long double& cost)
01106   {
01107     makeCost_[atomIdx] += cost;
01108   }
01109 
01116   void subtractMakeCost(const int& atomIdx, const long double& cost)
01117   {
01118     makeCost_[atomIdx] -= cost;
01119   }
01120 
01127   void addMakeCostToAtomsInClause(const int& clauseIdx,
01128                                   const long double& cost)
01129   {
01130     register int size = getClauseSize(clauseIdx);
01131     for (int i = 0; i < size; i++)
01132     {
01133       register int lit = clause_[clauseIdx][i];
01134       makeCost_[abs(lit)] += cost;
01135     }
01136   }
01137 
01144   void subtractMakeCostFromAtomsInClause(const int& clauseIdx,
01145                                          const long double& cost)
01146   {
01147     register int size = getClauseSize(clauseIdx);
01148     for (int i = 0; i < size; i++)
01149     {
01150       register int lit = clause_[clauseIdx][i];
01151       makeCost_[abs(lit)] -= cost;
01152     }
01153   }
01154 
01164   const int getTrueLiteralOtherThan(const int& clauseIdx,
01165                                     const int& atomIdx1,
01166                                     const int& atomIdx2)
01167   {
01168     register int size = getClauseSize(clauseIdx);
01169     for (int i = 0; i < size; i++)
01170     {
01171       register int lit = clause_[clauseIdx][i];
01172       register int v = abs(lit);
01173       if (isTrueLiteral(lit) && v != atomIdx1 && v != atomIdx2)
01174         return v;
01175     }
01176       // If we're here, then no other true lit exists
01177     assert(false);
01178     return -1;
01179   }
01180   
01184   const bool isTrueLiteral(const int& literal)
01185   {
01186     return ((literal > 0) == atom_[abs(literal)]);
01187   }
01188 
01192   const int getAtomInClause(const int& atomIdxInClause, const int& clauseIdx)
01193   {
01194     return clause_[clauseIdx][atomIdxInClause];
01195   }
01196 
01200   const int getRandomAtomInClause(const int& clauseIdx)
01201   {
01202     return clause_[clauseIdx][random()%getClauseSize(clauseIdx)];
01203   }
01204 
01211   const int getRandomTrueLitInClause(const int& clauseIdx)
01212   {
01213     assert(numTrueLits_[clauseIdx] > 0);
01214     int trueLit = random()%numTrueLits_[clauseIdx];
01215     int whichTrueLit = 0;
01216     for (int i = 0; i < getClauseSize(clauseIdx); i++)
01217     {
01218       int lit = clause_[clauseIdx][i];
01219       int atm = abs(lit);
01220         // True literal
01221       if (isTrueLiteral(lit))
01222         if (trueLit == whichTrueLit++)
01223           return atm;
01224     }
01225       // If we're here, then no other true lit exists
01226     assert(false);
01227     return -1;
01228   }
01229 
01230   const double getMaxClauseWeight()
01231   {
01232     double maxWeight = 0.0;
01233     for (int i = 0; i < getNumClauses(); i++)
01234     {
01235       double weight = abs(clauseCost_[i]);
01236       if (weight > maxWeight) maxWeight = weight;
01237     }
01238     return maxWeight;
01239   }
01240   
01241   const long double getMakeCost(const int& atomIdx)
01242   {
01243     return makeCost_[atomIdx];
01244   }
01245    
01246   const long double getBreakCost(const int& atomIdx)
01247   {
01248     return breakCost_[atomIdx];
01249   }
01250    
01251   const int getClauseSize(const int& clauseIdx)
01252   {
01253     return clause_[clauseIdx].size();
01254   }
01255 
01256   const int getWatch1(const int& clauseIdx)
01257   {
01258     return watch1_[clauseIdx];
01259   }
01260 
01261   void setWatch1(const int& clauseIdx, const int& atomIdx)
01262   {
01263     watch1_[clauseIdx] = atomIdx;
01264   }
01265   
01266   const int getWatch2(const int& clauseIdx)
01267   {
01268     return watch2_[clauseIdx];
01269   }
01270 
01271   void setWatch2(const int& clauseIdx, const int& atomIdx)
01272   {
01273     watch2_[clauseIdx] = atomIdx;
01274   }
01275   
01276   const bool isBlockEvidence(const int& blockIdx)
01277   {
01278     return (*blockEvidence_)[blockIdx];
01279   }
01280 
01281   const int getBlockSize(const int& blockIdx)
01282   {
01283       // Lazy: blocks are in domain
01284     if (lazy_)
01285       return domain_->getPredBlock(blockIdx)->size();
01286     else
01287       return (*blocks_)[blockIdx].size();
01288   }
01289   
01294   const int getBlockIndex(const int& atomIdx)
01295   {
01296     for (int i = 0; i < blocks_->size(); i++)
01297     {
01298       int blockIdx = (*blocks_)[i].find(atomIdx);
01299       if (blockIdx >= 0)
01300         return i;
01301     }
01302     return -1;
01303   }
01304   
01305   Array<int>& getBlockArray(const int& blockIdx)
01306   {
01307     return (*blocks_)[blockIdx];
01308   }
01309 
01310   bool getBlockEvidence(const int& blockIdx)
01311   {
01312     return (*blockEvidence_)[blockIdx];
01313   }
01314   
01315   int getNumBlocks()
01316   {
01317     return blocks_->size();
01318   }
01319   
01323   const long double getLowCost()
01324   {
01325     return lowCost_; 
01326   }
01327 
01331   const int getLowBad()
01332   {
01333     return lowBad_;
01334   }
01335 
01340   void makeUnitCosts()
01341   {
01342     for (int i = 0; i < clauseCost_.size(); i++)
01343     {
01344       if (clauseCost_[i] > 0) clauseCost_[i] = 1.0;
01345       else
01346       {
01347         assert(clauseCost_[i] < 0);
01348         clauseCost_[i] = -1.0;
01349       }
01350     }
01351   }
01352 
01356   void saveLowState()
01357   {
01358     if (vsdebug) cout << "Saving low state: " << endl;
01359     for (int i = 1; i <= getNumAtoms(); i++)
01360     {
01361       lowAtom_[i] = atom_[i];
01362       if (vsdebug) cout << lowAtom_[i] << endl;
01363     }
01364     lowCost_ = costOfFalseClauses_;
01365     lowBad_ = numFalseClauses_;
01366   }
01367 
01371   int getTrueFixedAtomInBlock(const int blockIdx)
01372   {
01373     Array<int>& block = (*blocks_)[blockIdx];
01374     for (int i = 0; i < block.size(); i++)
01375       if (fixedAtom_[block[i] + 1] > 0) return i;
01376     return -1;
01377   }
01378 
01379   const GroundPredicateHashArray* getGndPredHashArrayPtr() const
01380   {
01381     return &gndPredHashArray_;
01382   }
01383 
01384   const GroundPredicateHashArray* getUnePreds() const
01385   {
01386     return unePreds_;
01387   }
01388 
01389   const GroundPredicateHashArray* getKnePreds() const
01390   {
01391     return knePreds_;
01392   }
01393 
01394   const Array<TruthValue>* getKnePredValues() const
01395   {
01396     return knePredValues_;
01397   }
01398 
01402   void setGndClausesWtsToSumOfParentWts()
01403   {
01404     for (int i = 0; i < gndClauses_->size(); i++)
01405     {
01406       (*gndClauses_)[i]->setWtToSumOfParentWts();
01407       if (vsdebug) cout << "Setting cost of clause " << i << " to "
01408                         << (*gndClauses_)[i]->getWt() << endl;
01409       clauseCost_[i] = (*gndClauses_)[i]->getWt();
01410     }
01411   }
01412 
01421   void getNumClauseGndings(Array<double>* const & numGndings, bool tv)
01422   {
01423     // TODO: lazy version
01424     IntPairItr itr;
01425     IntPair *clauseFrequencies;
01426     
01427       // numGndings should have been initialized with non-negative values
01428     int clauseCnt = numGndings->size();
01429     assert(clauseCnt == mln_->getNumClauses());
01430     for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01431       assert ((*numGndings)[clauseno] >= 0);
01432     
01433     for (int i = 0; i < gndClauses_->size(); i++)
01434     {
01435       GroundClause *gndClause = (*gndClauses_)[i];
01436       int satLitcnt = getNumTrueLits(i);
01437       if (tv && satLitcnt == 0)
01438         continue;
01439       if (!tv && satLitcnt > 0)
01440         continue;
01441 
01442       clauseFrequencies = gndClause->getClauseFrequencies();
01443       for (itr = clauseFrequencies->begin();
01444            itr != clauseFrequencies->end(); itr++)
01445       {
01446         int clauseno = itr->first;
01447         int frequency = itr->second;
01448         (*numGndings)[clauseno] += frequency;
01449       }
01450     }
01451   }
01452 
01461   void getNumClauseGndings(double numGndings[], int clauseCnt, bool tv)
01462   {
01463     // TODO: lazy version
01464     IntPairItr itr;
01465     IntPair *clauseFrequencies;
01466     
01467     for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01468       numGndings[clauseno] = 0;
01469     
01470     for (int i = 0; i < gndClauses_->size(); i++)
01471     {
01472       GroundClause *gndClause = (*gndClauses_)[i];
01473       int satLitcnt = getNumTrueLits(i);
01474       if (tv && satLitcnt == 0)
01475         continue;
01476       if (!tv && satLitcnt > 0)
01477         continue;
01478 
01479       clauseFrequencies = gndClause->getClauseFrequencies();
01480       for (itr = clauseFrequencies->begin();
01481            itr != clauseFrequencies->end(); itr++)
01482       {
01483         int clauseno = itr->first;
01484         int frequency = itr->second;
01485         numGndings[clauseno] += frequency;
01486       }
01487     }
01488   }
01489 
01501   void getNumClauseGndingsWithUnknown(double numGndings[], int clauseCnt,
01502                                       bool tv,
01503                                       const Array<bool>* const& unknownPred)
01504   {
01505     assert(unknownPred->size() == getNumAtoms());
01506     IntPairItr itr;
01507     IntPair *clauseFrequencies;
01508     
01509     for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01510       numGndings[clauseno] = 0;
01511     
01512     for (int i = 0; i < gndClauses_->size(); i++)
01513     {
01514       GroundClause *gndClause = (*gndClauses_)[i];
01515       int satLitcnt = 0;
01516       bool unknown = false;
01517       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01518       {
01519         int lit = gndClause->getGroundPredicateIndex(j);
01520         if ((*unknownPred)[abs(lit) - 1])
01521         {
01522           unknown = true;
01523           continue;
01524         }
01525         if (isTrueLiteral(lit)) satLitcnt++;
01526       }
01527       
01528       if (tv && satLitcnt == 0)
01529         continue;
01530       if (!tv && (satLitcnt > 0 || unknown))
01531         continue;
01532 
01533       clauseFrequencies = gndClause->getClauseFrequencies();
01534       for (itr = clauseFrequencies->begin();
01535            itr != clauseFrequencies->end(); itr++)
01536       {
01537         int clauseno = itr->first;
01538         int frequency = itr->second;
01539         numGndings[clauseno] += frequency;
01540       }
01541     }
01542   }
01543 
01550   void setOthersInBlockToFalse(const int& atomIdx,
01551                                const int& blockIdx)
01552   {
01553     Array<int>& block = (*blocks_)[blockIdx];
01554     for (int i = 0; i < block.size(); i++)
01555     {
01556         // Atom not the one specified and not fixed
01557       if (i != atomIdx && fixedAtom_[block[i] + 1] == 0)
01558         setValueOfAtom(block[i] + 1, false);
01559     }
01560   }
01561 
01562 
01571   void fixAtom(const int& atomIdx, const bool& value)
01572   {
01573     assert(atomIdx > 0);
01574       // If already fixed to opp. sense, then contradiction
01575     if ((fixedAtom_[atomIdx] == 1 && value == false) ||
01576         (fixedAtom_[atomIdx] == -1 && value == true))
01577     {
01578       cout << "Contradiction: Tried to fix atom " << atomIdx <<
01579       " to true and false ... exiting." << endl;
01580       exit(0);
01581     }
01582 
01583     if (vsdebug)
01584     {
01585       cout << "Fixing ";
01586       (*gndPreds_)[atomIdx - 1]->print(cout, domain_);
01587       cout << " to " << value << endl;
01588     }
01589     
01590     setValueOfAtom(atomIdx, value);
01591     fixedAtom_[atomIdx] = (value) ? 1 : -1;
01592   }
01593 
01605   Array<int>* simplifyClauseFromFixedAtoms(const int& clauseIdx)
01606   {
01607     Array<int>* returnArray = new Array<int>;
01608       // If already satisfied from fixed atoms, then return empty array
01609     if (isSatisfied_[clauseIdx]) return returnArray;
01610 
01611       // Keeps track of pos. clause being satisfied or 
01612       // neg. clause being unsatisfied due to fixed atoms
01613     bool isGood = (clauseCost_[clauseIdx] > 0) ? false : true;
01614       // Keeps track of all atoms being fixed to false in a pos. clause
01615     bool allFalseAtoms = (clauseCost_[clauseIdx] > 0) ? true : false;
01616       // Check each literal in clause
01617     for (int i = 0; i < getClauseSize(clauseIdx); i++)
01618     {
01619       int lit = clause_[clauseIdx][i];
01620       int fixedValue = fixedAtom_[abs(lit)];
01621 
01622       if (clauseCost_[clauseIdx] > 0)
01623       { // Pos. clause: check if clause is satisfied
01624         if ((fixedValue == 1 && lit > 0) ||
01625             (fixedValue == -1 && lit < 0))
01626         { // True fixed lit
01627           isGood = true;
01628           allFalseAtoms = false;
01629           returnArray->clear();
01630           break;
01631         }
01632         else if (fixedValue == 0)
01633         { // Lit not fixed
01634           allFalseAtoms = false;
01635           returnArray->append(lit);
01636         }
01637       }
01638       else
01639       { // Neg. clause:
01640         assert(clauseCost_[clauseIdx] < 0);
01641         if ((fixedValue == 1 && lit > 0) ||
01642             (fixedValue == -1 && lit < 0))
01643         { // True fixed lit
01644           cout << "Contradiction: Tried to fix atom " << abs(lit) <<
01645           " to true in a negative clause ... exiting." << endl;
01646           exit(0);
01647         }
01648         else
01649         { // False fixed lit or non-fixed lit
01650           returnArray->append(lit);
01651             // Non-fixed lit
01652           if (fixedValue == 0) isGood = false;          
01653         }
01654       }
01655     }
01656     if (allFalseAtoms)
01657     {
01658       cout << "Contradiction: All atoms in clause " << clauseIdx <<
01659       " fixed to false ... exiting." << endl;
01660       exit(0);
01661     }
01662     if (isGood) isSatisfied_[clauseIdx] = true;
01663     return returnArray;
01664   }
01665 
01672   const bool isDeadClause(const int& clauseIdx)
01673   {
01674     return deadClause_[clauseIdx];
01675   }
01676 
01680   void eliminateSoftClauses()
01681   {
01682     bool atLeastOneDead = false;
01683     for (int i = 0; i < getNumClauses(); i++)
01684     {
01685       if (!(*gndClauses_)[i]->isHardClause())
01686       {
01687         atLeastOneDead = true;
01688         deadClause_[i] = true;
01689       }
01690     }
01691     if (atLeastOneDead) initMakeBreakCostWatch(0);
01692   }
01693  
01701   void killClauses(const int& startClause)
01702   {
01703     for (int i = startClause; i < getNumClauses(); i++)
01704     {
01705       GroundClause* clause = (*gndClauses_)[i];
01706       if ((clauseGoodInPrevious(i)) &&
01707           (clause->isHardClause() || random() <= threshold_[i]))
01708       {
01709         if (vsdebug)
01710         {
01711           cout << "Keeping clause "<< i << " ";
01712           clause->print(cout, domain_, &gndPredHashArray_);
01713           cout << endl;
01714         }
01715         deadClause_[i] = false;
01716       }
01717       else
01718       {
01719         deadClause_[i] = true;
01720       }
01721     }
01722     initMakeBreakCostWatch(startClause);
01723   }
01724 
01725   
01733   const bool clauseGoodInPrevious(const int& clauseIdx)
01734   {
01735     //GroundClause* clause = (*gndClauses_)[clauseIdx];
01736     int numSatLits = numTrueLits_[clauseIdx];
01737       // Num. of satisfied lits in previous iteration is stored in clause
01738     if ((numSatLits > 0 && clauseCost_[clauseIdx] > 0.0) ||
01739         (numSatLits == 0 && clauseCost_[clauseIdx] < 0.0))
01740       return true;
01741     else
01742       return false;
01743   }
01744 
01748   void resetDeadClauses()
01749   {
01750     for (int i = 0; i < deadClause_.size(); i++)
01751       deadClause_[i] = false;
01752     initMakeBreakCostWatch(0);
01753   }
01754    
01758   void resetFixedAtoms()
01759   {
01760     for (int i = 0; i < fixedAtom_.size(); i++)
01761       fixedAtom_[i] = 0;
01762     for (int i = 0; i < isSatisfied_.size(); i++)
01763       isSatisfied_[i] = false;
01764   }
01765 
01766   void setLazy(const bool& l) { lazy_ = l; }
01767   const bool getLazy() { return lazy_; }
01768 
01769   void setUseThreshold(const bool& t) { useThreshold_ = t;}
01770   const bool getUseThreshold() { return useThreshold_; }
01771   
01772   long double getHardWt() { return hardWt_; }
01773   
01774   const Domain* getDomain() { return domain_; }
01775 
01776   const MLN* getMLN() { return mln_; }
01777 
01783   void printLowState(ostream& out)
01784   {
01785     for (int i = 0; i < getNumAtoms(); i++)
01786     {
01787       (*gndPreds_)[i]->print(out, domain_);
01788       out << " " << lowAtom_[i + 1] << endl;
01789     }
01790   }
01791 
01798   void printGndPred(const int& predIndex, ostream& out)
01799   {
01800     (*gndPreds_)[predIndex]->print(out, domain_);
01801   }
01802 
01804   
01811   GroundPredicate* getGndPred(const int& index)
01812   {
01813     return (*gndPreds_)[index];
01814   }
01815 
01822   GroundClause* getGndClause(const int& index)
01823   {
01824     return (*gndClauses_)[index];
01825   }
01826 
01830   void saveLowStateToGndPreds()
01831   {
01832     for (int i = 0; i < getNumAtoms(); i++)
01833       (*gndPreds_)[i]->setTruthValue(lowAtom_[i + 1]);
01834   }
01835 
01839   void saveLowStateToDB()
01840   {
01841     for (int i = 0; i < getNumAtoms(); i++)
01842     {
01843       GroundPredicate* p = gndPredHashArray_[i];
01844       bool value = lowAtom_[i + 1];
01845       if (value)
01846       {
01847         domain_->getDB()->setValue(p, TRUE);
01848       }
01849       else
01850       {
01851         domain_->getDB()->setValue(p, FALSE);
01852       }
01853     }
01854   }
01855 
01862   const int getGndPredIndex(GroundPredicate* const& gndPred)
01863   {
01864     return gndPreds_->find(gndPred);
01865   }
01866 
01867      
01869 
01870 
01872  
01886   void getActiveClauses(Predicate *inputPred,
01887                         Array<GroundClause*>& activeClauses,
01888                         bool const & active,
01889                         bool const & ignoreActivePreds)
01890   {
01891     Clause *fclause;
01892     GroundClause* newClause;
01893     int clauseCnt;
01894     GroundClauseHashArray clauseHashArray;
01895 
01896     Array<GroundClause*>* newClauses = new Array<GroundClause*>; 
01897   
01898     const Array<IndexClause*>* indexClauses = NULL;
01899       
01900       // inputPred is null: all active clauses should be retrieved
01901     if (inputPred == NULL)
01902     {
01903       clauseCnt = mln_->getNumClauses();
01904     }
01905       // Otherwise, look at all first order clauses containing the pred
01906     else
01907     {
01908       if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
01909       int predId = inputPred->getId();
01910       indexClauses = mln_->getClausesContainingPred(predId);
01911       clauseCnt = indexClauses->size();
01912     }
01913 
01914       // Look at each first-order clause and get active groundings
01915     int clauseno = 0;
01916     while (clauseno < clauseCnt)
01917     {
01918       if (inputPred)
01919         fclause = (Clause *) (*indexClauses)[clauseno]->clause;           
01920       else
01921         fclause = (Clause *) mln_->getClause(clauseno);
01922 
01923       if (vsdebug)
01924       {
01925         cout << "Getting active clauses for FO clause: ";
01926         fclause->print(cout, domain_);
01927         cout << endl;
01928       }
01929       
01930       long double wt = fclause->getWt();
01931       const double* parentWtPtr = NULL;
01932       if (!fclause->isHardClause()) parentWtPtr = fclause->getWtPtr();
01933       const int clauseId = mln_->findClauseIdx(fclause);
01934       newClauses->clear();
01935 
01936       fclause->getActiveClauses(inputPred, domain_, newClauses,
01937                                 &gndPredHashArray_, ignoreActivePreds);
01938 
01939       for (int i = 0; i < newClauses->size(); i++)
01940       {
01941         newClause = (*newClauses)[i];
01942         int pos = clauseHashArray.find(newClause);
01943           // If clause already present, then just add weight
01944         if (pos >= 0)
01945         {
01946           if (vsdebug)
01947           {
01948             cout << "Adding weight " << wt << " to clause ";
01949             clauseHashArray[pos]->print(cout, domain_, &gndPredHashArray_);
01950             cout << endl;
01951           }
01952           clauseHashArray[pos]->addWt(wt);
01953           if (parentWtPtr)
01954           {
01955             clauseHashArray[pos]->appendParentWtPtr(parentWtPtr);
01956             clauseHashArray[pos]->incrementClauseFrequency(clauseId, 1);
01957           }
01958           delete newClause;
01959           continue;
01960         }
01961 
01962           // If here, then clause is not yet present        
01963         newClause->setWt(wt);
01964         newClause->appendToGndPreds(&gndPredHashArray_);
01965         if (parentWtPtr)
01966         {
01967           newClause->appendParentWtPtr(parentWtPtr);
01968           newClause->incrementClauseFrequency(clauseId, 1);
01969           assert(newClause->getWt() == *parentWtPtr);
01970         }      
01971 
01972         if (vsdebug)
01973         {
01974           cout << "Appending clause ";
01975           newClause->print(cout, domain_, &gndPredHashArray_);
01976           cout << endl;
01977         }
01978         clauseHashArray.append(newClause);
01979       }
01980       clauseno++; 
01981     } //while (clauseno < clauseCnt)
01982 
01983     for (int i = 0; i < clauseHashArray.size(); i++)
01984     {
01985       newClause = clauseHashArray[i];
01986       activeClauses.append(newClause);
01987     }
01988     delete newClauses;
01989   }
01990 
01998   void getActiveClauses(Array<GroundClause*> &allClauses,
01999                         bool const & ignoreActivePreds)
02000   {
02001     getActiveClauses(NULL, allClauses, true, ignoreActivePreds);
02002   }
02003   
02004   int getNumActiveAtoms()
02005   {
02006     return activeAtoms_; 
02007   }
02008 
02013   void addOneAtomToEachBlock()
02014   {
02015     assert(lazy_);
02016       // For each block: select one to set to true
02017     for (int i = 0; i < blocks_->size(); i++)
02018     {
02019         // If evidence atom exists, then all others are false
02020       if ((*blockEvidence_)[i])
02021       {
02022           // If first argument is -1, then all are set to false
02023         setOthersInBlockToFalse(-1, i);
02024         continue;
02025       }
02026 
02027         // Assumption is initLazyBlocks has been called
02028         // Pick one ground pred from the block in the domain
02029       const Array<Predicate*>* block = domain_->getPredBlock(i);
02030 
02031       int chosen = random() % block->size();
02032       Predicate* pred = (*block)[chosen];
02033       GroundPredicate* groundPred = new GroundPredicate(pred);
02034 
02035         // If chosen pred is not yet present, then add it, otherwise delete it
02036       int index = gndPredHashArray_.find(groundPred);
02037       if (index < 0)
02038       {
02039           // Pred not yet present
02040         index = gndPredHashArray_.append(groundPred);
02041         (*blocks_)[i].append(index);
02042         chosen = (*blocks_)[i].size() - 1;
02043           // addNewClauses adds the predicate to the state and updates
02044           // info arrays
02045         bool initial = false;
02046         addNewClauses(initial);
02047       }
02048       else
02049       {
02050         delete groundPred;
02051         chosen = (*blocks_)[i].find(index);
02052       }
02053       setValueOfAtom(index + 1, true);
02054       setOthersInBlockToFalse(chosen, i);
02055     }
02056   }
02057 
02061   void initLazyBlocks()
02062   {
02063     assert(lazy_);
02064     blocks_ = new Array<Array<int> >;
02065     blocks_->growToSize(domain_->getNumPredBlocks());
02066     blockEvidence_ = new Array<bool>(*(domain_->getBlockEvidenceArray()));
02067   }
02068 
02072   void fillLazyBlocks()
02073   {
02074     assert(lazy_);
02075     const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
02076     for (int i = 0; i < blocks->size(); i++)
02077     {
02078       if (vsdebug) cout << "Block " << i << endl;
02079       Array<Predicate*>* block = (*blocks)[i];
02080       for (int j = 0; j < block->size(); j++)
02081       {
02082         Predicate* pred = (*block)[j];
02083         if (vsdebug)
02084         {
02085           cout << "\tPred: ";
02086           pred->printWithStrVar(cout, domain_);
02087           cout << endl;
02088         }
02089           // Add all non-evid preds in blocks to the state
02090         if (domain_->getDB()->getEvidenceStatus(pred))
02091           continue;
02092         GroundPredicate* groundPred = new GroundPredicate(pred);
02093 
02094           // Add pred if not yet present, otherwise delete it
02095         int index = gndPredHashArray_.find(groundPred);
02096         if (index < 0)
02097           index = gndPredHashArray_.append(groundPred);
02098         else
02099           delete groundPred;
02100 
02101           // Append the atom to the block if not yet there
02102         if (!(*blocks_)[i].contains(index))
02103           (*blocks_)[i].append(index);
02104       }
02105     }
02106       // addNewClauses adds the predicates to the state and updates info arrays
02107     bool initial = true;
02108     addNewClauses(initial);
02109   }
02110 
02112 
02113   
02114  private:
02115  
02121   void setHardClauseWeight()
02122   {
02123       // Soft weights are summed up to determine hard weight
02124     long double sumSoftWts = 0.0;
02125       // Determine hard clause weight
02126     int clauseCnt = mln_->getNumClauses();    
02127       // Sum up the soft weights of all grounded clauses
02128     for (int i = 0; i < clauseCnt; i++)
02129     {
02130       Clause* fclause = (Clause *) mln_->getClause(i);
02131         // Skip hard clauses
02132       if (fclause->isHardClause()) continue;
02133         // Weight could be negative
02134       long double wt = abs(fclause->getWt());
02135       long double numGndings = fclause->getNumGroundings(domain_);
02136       sumSoftWts += wt*numGndings;
02137     }
02138     assert(sumSoftWts >= 0);
02139       // Add constant so weight isn't zero if no soft clauses present
02140     hardWt_ = sumSoftWts + 10.0;
02141     cout << "Set hard weight to " << hardWt_ << endl;
02142   }
02143 
02144  private:
02145 
02146     // If true, this is a lazy variable state, else eager.
02147   bool lazy_;
02148 
02149     // Weight used for hard clauses (sum of soft weights + constant)
02150   long double hardWt_;
02151   
02152     // mln and domain are used to build MRF in eager state and to
02153     // retrieve active atoms in lazy state.
02154   MLN* mln_;
02155   Domain* domain_;
02156 
02157     // Eager version: Pointer to gndPreds_ and gndClauses_ in MRF
02158     // Lazy version: Holds active atoms and clauses
02159   Array<GroundPredicate*>* gndPreds_;
02160   Array<GroundClause*>* gndClauses_;
02161   
02162     // Predicates corresponding to the groundings of the unknown non-evidence
02163     // predicates
02164   GroundPredicateHashArray* unePreds_;
02165 
02166     // Predicates corresponding to the groundings of the known non-evidence
02167     // predicates
02168   GroundPredicateHashArray* knePreds_;
02169     // Actual truth values of ground known non-evidence preds
02170   Array<TruthValue>* knePredValues_;  
02171   
02172 
02174     // Number of distinct atoms in the first set of unsatisfied clauses
02175   int baseNumAtoms_;
02176     // If true, atoms are not deactivated when mem. is full
02177   bool noApprox_;
02178     // Indicates whether deactivation of atoms has taken place yet
02179   bool haveDeactivated_;
02180     // Max. amount of memory to use
02181   int memLimit_;
02182     // Max. amount of clauses memory can hold
02183   int clauseLimit_;
02185 
02186 
02188     // MRF is used with eager states. If lazy, this stays NULL.
02189   MRF* mrf_;
02191 
02192     // Holds the new active clauses
02193   Array<GroundClause*> newClauses_;
02194     // Holds the new gnd preds
02195   Array<GroundPredicate*> newPreds_;
02196     // Holds the ground predicates in a hash array.
02197     // Fast access is needed for comparing preds when activating clauses.
02198   GroundPredicateHashArray gndPredHashArray_;
02199 
02200     // Clauses to be satisfied
02201     // Indexed as clause_[clause_num][literal_num]
02202   Array<Array<int> > clause_;
02203     // Cost of each clause (can be negative)
02204   Array<long double> clauseCost_;
02205     // Highest cost of false clause
02206   long double highestCost_;
02207     // If true, more than one clause has highest cost
02208   bool eqHighest_;
02209     // Number of clauses with highest cost
02210   int numHighest_;
02211     // Clauses which are pos. and unsatisfied or neg. and satisfied
02212   Array<int> falseClause_;
02213     // Where each clause is listed in falseClause_
02214   Array<int> whereFalse_;
02215     // Number of true literals in each clause
02216   Array<int> numTrueLits_;
02217     // watch1_[c] contains the id of the first atom which c is watching
02218   Array<int> watch1_;
02219     // watch2_[c] contains the id of the second atom which c is watching
02220   Array<int> watch2_;
02221     // Which clauses are satisfied by fixed atoms
02222   Array<bool> isSatisfied_;
02223     // Clauses which are not to be considered
02224   Array<bool> deadClause_;
02225     // Use threshold to exclude clauses from the state?
02226   bool useThreshold_;
02227     // Pre-computed thresholds for each clause
02228   Array<long double> threshold_;
02229 
02230     // Holds the index of clauses in which each literal occurs
02231     // Indexed as occurence_[2*abs(lit) - (lit > 0)][occurence_num]
02232   Array<Array<int> > occurence_;
02233 
02234     // Current assigment of atoms
02235   Array<bool> atom_;
02236     // Cost of clauses which would become satisfied by flipping each atom
02237   Array<long double> makeCost_;
02238     // Cost of clauses which would become unsatisfied by flipping each atom
02239   Array<long double> breakCost_;
02240     // Indicates if an atom is fixed to a value (0 = not fixed, -1 = false,
02241     // 1 = true)
02242   Array<int> fixedAtom_;
02243 
02244     // Assigment of atoms producing lowest cost so far
02245   Array<bool> lowAtom_;
02246     // Cost of false clauses in the currently best state
02247   long double lowCost_;
02248     // Number of false clauses in the currently best state
02249   int lowBad_;
02250 
02251     // Current no. of unsatisfied clauses
02252   int numFalseClauses_;
02253     // Cost associated with the number of false clauses
02254   long double costOfFalseClauses_;
02255   
02256     // For block inference: blocks_ and blockEvidence_
02257     // All atom indices in (*blocks_)[i] are to be treated as in one block
02258   Array<Array<int> >* blocks_;
02259     // (*blockEvidence_)[i] states whether block i has true evidence and
02260     // thus all should be false
02261   Array<bool >* blockEvidence_;
02262 
02263     // Number of active atoms in state.  
02264   int activeAtoms_;
02265   
02266 };
02267 
02268 #endif /*VARIABLESTATE_H_*/

Generated on Tue Jan 16 05:30:04 2007 for Alchemy by  doxygen 1.5.1