00001 #ifndef _LW_Info_H_
00002 #define _LW_Info_H_
00003
#include <sys/time.h>

#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

#include "clause.h"
#include "clausefactory.h"
#include "lwutil.h"
#include "intclause.h"
#include "mln.h"
#include "timer.h"
00013
00014 class LWInfo
00015 {
00016 public:
00017
00018
00019 LWInfo(MLN *mln, Domain *domain)
00020 {
00021
00022 int seed;
00023 struct timeval tv;
00024 struct timezone tzp;
00025 gettimeofday(&tv,&tzp);
00026 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00027 srandom(seed);
00028
00029 this->mln_ = mln;
00030 this->domain_ = domain;
00031 copyMLN();
00032 setHardClauseWeight();
00033 sampleSat_ = false;
00034 prevDB_ = NULL;
00035 numDBAtoms_ = domain_->getNumNonEvidenceAtoms();
00036 initBlocks();
00037 inBlock_ = false;
00038 }
00039
00040
00041 ~LWInfo()
00042 {
00043
00044 for(int i = 0; i < predHashArray_.size(); i++)
00045 delete predHashArray_[i];
00046
00047
00048 delete mln_;
00049
00050 if (prevDB_) delete prevDB_;
00051
00052 for(int i = 0; i < deadClauses_.size(); i++)
00053 {
00054 deadClauses_[i]->deleteIntPredicates();
00055 delete deadClauses_[i];
00056 }
00057 }
00058
00059 inline int getVarCount() { return predHashArray_.size(); }
00060
00061
00062
00063
00064 inline void copyMLN()
00065 {
00066
00067 MLN* posmln = new MLN();
00068 int clauseCnt = mln_->getNumClauses();
00069
00070
00071 for (int i = 0; i < clauseCnt; i++)
00072 {
00073 Clause* clause = (Clause *) mln_->getClause(i);
00074 if (clause->getWt() == 0) continue;
00075 int numPreds = clause->getNumPredicates();
00076
00077
00078 int idx;
00079 ostringstream oss;
00080 Clause* newClause = new Clause(*clause);
00081 newClause->printWithoutWtWithStrVar(oss, domain_);
00082 bool app = posmln->appendClause(oss.str(), false, newClause,
00083 clause->getWt(), clause->isHardClause(), idx);
00084 if (app)
00085 {
00086 posmln->setFormulaNumPreds(oss.str(), numPreds);
00087 posmln->setFormulaIsHard(oss.str(), clause->isHardClause());
00088 posmln->setFormulaPriorMean(oss.str(), clause->getWt());
00089 }
00090 }
00091 mln_ = posmln;
00092
00093 }
00094
00095
00096
00097
00098 inline void setHardClauseWeight()
00099 {
00100
00101 LWInfo::HARD_WT = 10.0;
00102
00103 int clauseCnt = mln_->getNumClauses();
00104 double sumSoftWts = 0.0;
00105 double minWt = DBL_MAX;
00106 double maxWt = DBL_MIN;
00107
00108
00109 int maxAllowedWt = 4000;
00110
00111
00112 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00113 {
00114 Clause* fclause = (Clause *) mln_->getClause(clauseno);
00115
00116 if (fclause->isHardClause()) continue;
00117
00118
00119 double wt = abs(fclause->getWt());
00120
00121 double numGndings = fclause->getNumGroundings(domain_);
00122
00123 if (wt < minWt) minWt = wt;
00124 if (wt > maxWt) maxWt = wt;
00125 sumSoftWts += wt*numGndings;
00126 assert(minWt >= 0);
00127 assert(maxWt >= 0);
00128 }
00129 assert(sumSoftWts >= 0);
00130
00131
00132 if (sumSoftWts > 0)
00133 {
00134
00135 LWInfo::WSCALE = 1.0;
00136 if (maxWt > maxAllowedWt) LWInfo::WSCALE = maxAllowedWt/maxWt;
00137 else
00138 {
00139 if (minWt < 10)
00140 {
00141 LWInfo::WSCALE = 10/minWt;
00142 if (LWInfo::WSCALE*maxWt > maxAllowedWt) LWInfo::WSCALE = maxAllowedWt/maxWt;
00143 }
00144 }
00145
00146 LWInfo::HARD_WT = (sumSoftWts + 10.0)*LWInfo::WSCALE;
00147 }
00148
00149 }
00150
00151
00152
00153
00154 inline void removeSoftClauses()
00155 {
00156
00157 for (int i = 0; i < mln_->getNumClauses(); i++)
00158 {
00159 Clause* clause = (Clause *) mln_->getClause(i);
00160 if (!clause->isHardClause())
00161 {
00162 mln_->removeClause(i);
00163 i--;
00164 }
00165 }
00166 }
00167
00168 inline void reset()
00169 {
00170 for(int i = 0; i < predHashArray_.size(); i++)
00171 {
00172 delete predHashArray_[i];
00173 }
00174
00175 predHashArray_.clear();
00176 predArray_.clear();
00177 }
00178
00179
00180 inline bool isActive(int atom)
00181 {
00182 return (domain_->getDB()->getActiveStatus(predArray_[atom]));
00183 }
00184
00185
00186 inline bool isDeactivated(int atom)
00187 {
00188 return (domain_->getDB()->getDeactivatedStatus(predArray_[atom]));
00189 }
00190
00191
00192 inline void setActive(int atom)
00193 {
00194 domain_->getDB()->setActiveStatus(predArray_[atom], true);
00195 }
00196
00197
00198 inline void setInactive(int atom)
00199 {
00200 domain_->getDB()->setActiveStatus(predArray_[atom], false);
00201 }
00202
00203 void updatePredArray()
00204 {
00205 int startindex = 0;
00206 if(predArray_.size() > 0)
00207 startindex = predArray_.size()-1;
00208
00209 predArray_.growToSize(predHashArray_.size()+1);
00210 for(int i = startindex; i < predHashArray_.size(); i++)
00211 predArray_[i+1] = predHashArray_[i];
00212 }
00213
00214
00215 void getSupersetClauses(Array<IntClause *> &supersetClauses)
00216 {
00217
00218 int seed;
00219 struct timeval tv;
00220 struct timezone tzp;
00221 gettimeofday(&tv,&tzp);
00222 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00223 srandom(seed);
00224
00225 Clause *fclause;
00226 IntClause *intClause;
00227 int clauseCnt;
00228 IntClauseHashArray clauseHashArray;
00229
00230 Array<IntClause *>* intClauses = new Array<IntClause *>;
00231
00232 clauseCnt = mln_->getNumClauses();
00233
00234 int clauseno = 0;
00235 while(clauseno < clauseCnt)
00236 {
00237 fclause = (Clause *) mln_->getClause(clauseno);
00238
00239 double wt = fclause->getWt();
00240 intClauses->clear();
00241 bool ignoreActivePreds = false;
00242 fclause->getActiveClauses(NULL, domain_, intClauses,
00243 &predHashArray_, ignoreActivePreds);
00244 updatePredArray();
00245
00246 for (int i = 0; i < intClauses->size(); i++)
00247 {
00248 intClause = (*intClauses)[i];
00249
00250 int pos = clauseHashArray.find(intClause);
00251 if(pos >= 0)
00252 {
00253 clauseHashArray[pos]->addWt(wt);
00254 intClause->deleteIntPredicates();
00255 delete intClause;
00256 continue;
00257 }
00258
00259 intClause->setWt(wt);
00260 clauseHashArray.append(intClause);
00261 }
00262 clauseno++;
00263 }
00264
00265 for(int i = 0; i < clauseHashArray.size(); i++)
00266 {
00267 intClause = clauseHashArray[i];
00268 supersetClauses.append(intClause);
00269 }
00270 delete intClauses;
00271 }
00272
00273
00274
00275
00276 void selectClauses(const Array<IntClause *> &supersetClauses,
00277 Array<Array<int> *> &walksatClauses,
00278 Array<int> &walksatClauseWts)
00279 {
00280 assert(sampleSat_);
00281
00282
00283 int seed;
00284 struct timeval tv;
00285 struct timezone tzp;
00286 gettimeofday(&tv,&tzp);
00287 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00288 srandom(seed);
00289
00290 for(int i = 0; i < deadClauses_.size(); i++)
00291 {
00292 deadClauses_[i]->deleteIntPredicates();
00293 delete deadClauses_[i];
00294 }
00295 deadClauses_.clearAndCompress();
00296
00297
00298
00299 for (int i = 0; i < supersetClauses.size(); i++)
00300 {
00301 assert(prevDB_);
00302 IntClause* intClause = supersetClauses[i];
00303 double wt = intClause->getWt();
00304 if (wt == 0) continue;
00305
00306
00307
00308 bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00309 if ((wt > 0 && !sat) ||
00310 (wt < 0 && sat))
00311 {
00312 continue;
00313 }
00314
00315
00316 double threshold =
00317 intClause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00318
00319 if (random() > threshold)
00320 {
00321 deadClauses_.append(new IntClause(*intClause));
00322 continue;
00323 }
00324
00325
00326 Array<int>* litClause = (Array<int> *)intClause->getIntPredicates();
00327 walksatClauses.append(new Array<int>(*litClause));
00328 if (wt >= 0) walksatClauseWts.append(1);
00329 else walksatClauseWts.append(-1);
00330 }
00331 }
00332
00333
00341 void getWalksatClauses(Predicate *inputPred,
00342 Array<Array<int> *> &walksatClauses,
00343 Array<int> &walksatClauseWts,
00344 bool const & active)
00345 {
00346
00347 int seed;
00348 struct timeval tv;
00349 struct timezone tzp;
00350 gettimeofday(&tv,&tzp);
00351 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00352 srandom(seed);
00353
00354 Clause *fclause;
00355 IntClause *intClause;
00356 int clauseCnt;
00357 IntClauseHashArray clauseHashArray;
00358
00359 Array<IntClause *>* intClauses = new Array<IntClause *>;
00360
00361 const Array<IndexClause*>* indexClauses = NULL;
00362
00363 if(inputPred == NULL)
00364 {
00365 clauseCnt = mln_->getNumClauses();
00366 }
00367 else
00368 {
00369 if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
00370 int predid = inputPred->getId();
00371 indexClauses = mln_->getClausesContainingPred(predid);
00372 clauseCnt = indexClauses->size();
00373 }
00374
00375 int clauseno = 0;
00376 while(clauseno < clauseCnt)
00377 {
00378 if(inputPred)
00379 fclause = (Clause *) (*indexClauses)[clauseno]->clause;
00380 else
00381 fclause = (Clause *) mln_->getClause(clauseno);
00382
00383 double wt = fclause->getWt();
00384 intClauses->clear();
00385 bool ignoreActivePreds = false;
00386
00387 if (active)
00388 {
00389 fclause->getActiveClauses(inputPred, domain_, intClauses,
00390 &predHashArray_, ignoreActivePreds);
00391 }
00392 else
00393 {
00394 fclause->getInactiveClauses(inputPred, domain_, intClauses,
00395 &predHashArray_);
00396 }
00397 updatePredArray();
00398 cout << "intClauses size " << intClauses->size() << endl;
00399 for (int i = 0; i < intClauses->size(); i++)
00400 {
00401 intClause = (*intClauses)[i];
00402
00403
00404 if (sampleSat_)
00405 {
00406 assert(prevDB_);
00407
00408
00409
00410 bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00411 if ((wt >= 0 && !sat) ||
00412 (wt < 0 && sat))
00413 {
00414 intClause->deleteIntPredicates();
00415 delete intClause;
00416 continue;
00417 }
00418
00419
00420 if (deadClauses_.contains(intClause))
00421 {
00422 intClause->deleteIntPredicates();
00423 delete intClause;
00424 continue;
00425 }
00426
00427
00428 double threshold =
00429 fclause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00430
00431 if (random() > threshold)
00432 {
00433 deadClauses_.append(intClause);
00434 continue;
00435 }
00436 }
00437
00438 int pos = clauseHashArray.find(intClause);
00439 if(pos >= 0)
00440 {
00441 clauseHashArray[pos]->addWt(wt);
00442 intClause->deleteIntPredicates();
00443 delete intClause;
00444 continue;
00445 }
00446
00447 intClause->setWt(wt);
00448 clauseHashArray.append(intClause);
00449 }
00450 clauseno++;
00451 }
00452
00453 Array<int>* litClause;
00454 for(int i = 0; i < clauseHashArray.size(); i++)
00455 {
00456 intClause = clauseHashArray[i];
00457 double weight = intClause->getWt();
00458 litClause = (Array<int> *)intClause->getIntPredicates();
00459 walksatClauses.append(litClause);
00460 if (sampleSat_)
00461 {
00462 if (weight >= 0) walksatClauseWts.append(1);
00463 else walksatClauseWts.append(-1);
00464 }
00465 else
00466 {
00467 if (weight >= 0)
00468 walksatClauseWts.append((int)(weight*LWInfo::WSCALE + 0.5));
00469 else
00470 walksatClauseWts.append((int)(weight*LWInfo::WSCALE - 0.5));
00471 }
00472
00473 delete intClause;
00474 }
00475
00476 delete intClauses;
00477 }
00478
00479
00480
00481 void getWalksatClauses(Array<Array<int> *> &allClauses, Array<int> &allClauseWeights)
00482 {
00483 getWalksatClauses(NULL, allClauses, allClauseWeights, true);
00484 }
00485
00486
00487 void getWalksatClausesWhenFlipped(int atom,
00488 Array<Array<int> *> &walksatClauses,
00489 Array<int> &walksatClauseWts)
00490 {
00491 Predicate *pred = predArray_[atom];
00492 TruthValue oldval = domain_->getDB()->getValue(pred);
00493 TruthValue val;
00494 (oldval == TRUE)? val=FALSE : val = TRUE;
00495
00496
00497 Predicate* otherPred = NULL;
00498 inBlock_ = false;
00499
00500 int blockIdx = domain_->getBlock(pred);
00501 if (blockIdx >= 0)
00502 {
00503
00504 const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00505 if (block->size() > 1)
00506 {
00507 inBlock_ = true;
00508 int chosen = -1;
00509
00510 if (oldval == TRUE)
00511 {
00512 bool ok = false;
00513 while(!ok)
00514 {
00515 chosen = random() % block->size();
00516 if (!pred->same((*block)[chosen]))
00517 ok = true;
00518 }
00519 assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00520 }
00521
00522 else
00523 {
00524 for (int i = 0; i < block->size(); i++)
00525 {
00526 if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00527 {
00528 chosen = i;
00529 assert(!pred->same((*block)[chosen]));
00530 break;
00531 }
00532 }
00533 }
00534 assert(chosen >= 0);
00535 otherPred = (*block)[chosen];
00536
00537 domain_->getDB()->setValue(otherPred, oldval);
00538 }
00539 }
00540
00541 domain_->getDB()->setValue(pred, val);
00542
00543 getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00544 if (inBlock_)
00545 getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
00546
00547
00548 domain_->getDB()->setValue(pred,oldval);
00549 if (inBlock_)
00550 {
00551 domain_->getDB()->setValue(otherPred, val);
00552
00553
00554 for(int atom = 1; atom <= getVarCount(); atom++)
00555 {
00556 if (otherPred->same(predArray_[atom]))
00557 {
00558 otherAtom_ = atom;
00559 break;
00560 }
00561 }
00562 }
00563 }
00564
00565
00566 int getUnSatCostPerPred(Predicate* pred,
00567 Array<Array<int> *> &walksatClauses,
00568 Array<int> &walksatClauseWts)
00569 {
00570
00571 int seed;
00572 struct timeval tv;
00573 struct timezone tzp;
00574 gettimeofday(&tv,&tzp);
00575 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00576 srandom(seed);
00577
00578 int unSatCost = 0;
00579 Clause *clause;
00580 IntClause *intClause;
00581 IntClauseHashArray clauseHashArray;
00582
00583 Array<IntClause *>* intClauses = new Array<IntClause *>;
00584
00585 const Array<IndexClause*>* indclauses;
00586 int predid = pred->getId();
00587 indclauses = mln_->getClausesContainingPred(predid);
00588
00589 for(int j = 0; j < indclauses->size(); j++)
00590 {
00591 clause = (Clause *) (*indclauses)[j]->clause;
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603
00604
00605
00606
00607
00608
00609
00610
00611
00612
00613 double wt = clause->getWt();
00614 intClauses->clear();
00615
00616 if(abs(wt) < WEIGHT_EPSILON)
00617 continue;
00618
00619 bool ignoreActivePreds = true;
00620 clause->getActiveClauses(pred, domain_, intClauses,
00621 &predHashArray_, ignoreActivePreds);
00622 updatePredArray();
00623
00624 for (int i = 0; i < intClauses->size(); i++)
00625 {
00626 intClause = (*intClauses)[i];
00627
00628
00629 if (sampleSat_)
00630 {
00631 assert(prevDB_);
00632
00633
00634
00635
00636
00637 bool sat = intClause->isSatisfied(&predHashArray_, prevDB_);
00638 if ((wt >= 0 && !sat) ||
00639 (wt < 0 && sat))
00640 {
00641 intClause->deleteIntPredicates();
00642 delete intClause;
00643 continue;
00644 }
00645
00646
00647 if (deadClauses_.contains(intClause))
00648 {
00649 intClause->deleteIntPredicates();
00650 delete intClause;
00651 continue;
00652 }
00653
00654
00655 double threshold =
00656 clause->isHardClause() ? RAND_MAX : RAND_MAX*(1-exp(-abs(wt)));
00657
00658 if (random() > threshold)
00659 {
00660 deadClauses_.append(intClause);
00661 continue;
00662 }
00663 }
00664
00665 int pos = clauseHashArray.find(intClause);
00666 if(pos >= 0)
00667 {
00668 clauseHashArray[pos]->addWt(wt);
00669 intClause->deleteIntPredicates();
00670 delete intClause;
00671 continue;
00672 }
00673
00674
00675
00676 if (wt == LWInfo::HARD_WT) intClause->setWtToHardWt();
00677 else intClause->setWt(wt);
00678 clauseHashArray.append(intClause);
00679
00680 }
00681
00682 }
00683
00684 Array<int>* litClause;
00685 for(int i = 0; i < clauseHashArray.size(); i++)
00686 {
00687 intClause = clauseHashArray[i];
00688
00689 int weight = (int)(intClause->getWt());
00690 litClause = (Array<int> *)intClause->getIntPredicates();
00691 walksatClauses.append(litClause);
00692 if (sampleSat_)
00693 {
00694 if (weight >= 0) walksatClauseWts.append(1);
00695 else walksatClauseWts.append(-1);
00696 unSatCost += 1;
00697 }
00698 else
00699 {
00700 if (weight >= 0)
00701 {
00702 walksatClauseWts.append((int)(weight*LWInfo::WSCALE + 0.5));
00703 unSatCost += (int)(weight*LWInfo::WSCALE + 0.5);
00704 }
00705 else
00706 {
00707 walksatClauseWts.append((int)(weight*LWInfo::WSCALE - 0.5));
00708 unSatCost += (int)(weight*LWInfo::WSCALE - 0.5);
00709 }
00710 }
00711
00712
00713
00714
00715 delete intClause;
00716 }
00717 delete intClauses;
00718
00719 return unSatCost;
00720 }
00721
00722
00723
00724
00725
00726 int getUnSatCostWhenFlipped(int atom,
00727 Array<Array<int> *> &walksatClauses,
00728 Array<int> &walksatClauseWts)
00729 {
00730 int unSatCost = 0;
00731 TruthValue val;
00732 Predicate *pred = predArray_[atom];
00733 TruthValue oldval = domain_->getDB()->getValue(pred);
00734 (oldval == TRUE) ? val = FALSE: val = TRUE;
00735
00736
00737 Predicate* otherPred = NULL;
00738 bool inBlock = false;
00739
00740 int blockIdx = domain_->getBlock(pred);
00741 if (blockIdx >= 0)
00742 {
00743
00744 const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00745 if (block->size() > 1)
00746 {
00747 inBlock = true;
00748 int chosen = -1;
00749
00750 if (oldval == TRUE)
00751 {
00752 bool ok = false;
00753 while(!ok)
00754 {
00755 chosen = random() % block->size();
00756 if (!pred->same((*block)[chosen]))
00757 ok = true;
00758 }
00759 assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00760 }
00761
00762 else
00763 {
00764 for (int i = 0; i < block->size(); i++)
00765 {
00766 if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00767 {
00768 chosen = i;
00769 assert(!pred->same((*block)[chosen]));
00770 break;
00771 }
00772 }
00773 }
00774 assert(chosen >= 0);
00775 otherPred = (*block)[chosen];
00776
00777 domain_->getDB()->setValue(otherPred, oldval);
00778 }
00779 }
00780
00781 domain_->getDB()->setValue(pred, val);
00782
00783
00784 unSatCost += getUnSatCostPerPred(pred, walksatClauses, walksatClauseWts);
00785 if (inBlock)
00786 unSatCost += getUnSatCostPerPred(otherPred, walksatClauses, walksatClauseWts);
00787
00788
00789 domain_->getDB()->setValue(pred,oldval);
00790 if (inBlock)
00791 domain_->getDB()->setValue(otherPred, val);
00792 return unSatCost;
00793 }
00794
00795
00796 void setVarVals(int newVals[])
00797 {
00798 for(int atom = 1; atom <= getVarCount(); atom++)
00799 {
00800 Predicate *ipred = predArray_[atom];
00801 if(newVals[atom] == 1)
00802 domain_->getDB()->setValue(ipred,TRUE);
00803 else
00804 domain_->getDB()->setValue(ipred,FALSE);
00805 }
00806 }
00807
00808
00809 void flipVar(int atom)
00810 {
00811 Predicate *ipred = predArray_[atom];
00812 TruthValue val;
00813 val = domain_->getDB()->getValue(ipred);
00814 if(val == TRUE)
00815 domain_->getDB()->setValue(ipred,FALSE);
00816 else
00817 domain_->getDB()->setValue(ipred,TRUE);
00818 }
00819
00820
00821 bool getVarVal(int atom)
00822 {
00823 Predicate *ipred = predArray_[atom];
00824 bool val;
00825 (domain_->getDB()->getValue(ipred) == TRUE)? val = true : val = false;
00826 return val;
00827 }
00828
00829
00830 void setVarVal(int atom, bool val)
00831 {
00832 Predicate *ipred = predArray_[atom];
00833 if (val)
00834 domain_->getDB()->setValue(ipred, TRUE);
00835 else
00836 domain_->getDB()->setValue(ipred, FALSE);
00837 }
00838
00839 Predicate* getVar(int atom)
00840 {
00841 return predArray_[atom];
00842 }
00843
00844 void setAllActive()
00845 {
00846 LWUtil::setAllActive(domain_);
00847 }
00848
00849 void setAllInactive()
00850 {
00851 for(int atom = 1; atom <= getVarCount(); atom++)
00852 {
00853 Predicate* ipred = predArray_[atom];
00854 domain_->getDB()->setActiveStatus(ipred, false);
00855 }
00856 }
00857
00858 void setAllFalse()
00859 {
00860 for(int atom = 1; atom <= getVarCount(); atom++)
00861 {
00862 Predicate* ipred = predArray_[atom];
00863 domain_->getDB()->setValue(ipred, FALSE);
00864 }
00865 }
00866
00867
00868 void removeVars(Array<int> indices)
00869 {
00870 for (int i = 0; i < indices.size(); i++)
00871 {
00872 Predicate* pred = predArray_[indices[i]];
00873 domain_->getDB()->setActiveStatus(pred, false);
00874 domain_->getDB()->setDeactivatedStatus(pred, true);
00875 }
00876 reset();
00877
00878 }
00879
00880 void setSampleSat(bool s)
00881 {
00882 sampleSat_ = s;
00883 }
00884
00885 bool getSampleSat()
00886 {
00887 return sampleSat_;
00888 }
00889
00890
00891 void setPrevDB()
00892 {
00893 if (prevDB_) { delete prevDB_; prevDB_ = NULL; }
00894 prevDB_ = new Database((*domain_->getDB()));
00895 }
00896
00897 int getNumDBAtoms()
00898 {
00899 return numDBAtoms_;
00900 }
00901
00902
00903
00904
00905
00906 bool activateRandomAtom(Array<Array<int> *> &walksatClauses,
00907 Array<int> &walksatClauseWts,
00908 int& toflip)
00909 {
00910
00911 int seed;
00912 struct timeval tv;
00913 struct timezone tzp;
00914 gettimeofday(&tv,&tzp);
00915 seed = (unsigned int)((( tv.tv_sec & 0177 ) * 1000000) + tv.tv_usec);
00916 srandom(seed);
00917
00918 Predicate* pred;
00919 pred = domain_->getNonEvidenceAtom(random() % numDBAtoms_);
00920
00921 int pos = -1;
00922 if ((pos = predHashArray_.find(pred)) == -1)
00923 {
00924 domain_->getDB()->setActiveStatus(pred, true);
00925 predHashArray_.append(pred);
00926 updatePredArray();
00927 toflip = predHashArray_.size();
00928 getWalksatClauses(pred, walksatClauses, walksatClauseWts, true);
00929 return true;
00930 }
00931 else
00932 {
00933 delete pred;
00934 toflip = pos + 1;
00935 return false;
00936 }
00937 }
00938
00939
00940 void chooseOtherToFlip(int atom,
00941 Array<Array<int> *> &walksatClauses,
00942 Array<int> &walksatClauseWts)
00943 {
00944
00945 assert(isActive(atom));
00946 Predicate *pred = predArray_[atom];
00947 TruthValue oldval = domain_->getDB()->getValue(pred);
00948 TruthValue val;
00949 (oldval == TRUE)? val = FALSE : val = TRUE;
00950
00951
00952 Predicate* otherPred = NULL;
00953 inBlock_ = false;
00954 int blockIdx = domain_->getBlock(pred);
00955
00956 if (blockIdx < 0)
00957 {
00958 return;
00959 }
00960 else
00961 {
00962
00963 const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
00964 if (block->size() > 1)
00965 {
00966 inBlock_ = true;
00967 int chosen = -1;
00968
00969 if (oldval == TRUE)
00970 {
00971 bool ok = false;
00972 while(!ok)
00973 {
00974 chosen = random() % block->size();
00975 if (!pred->same((*block)[chosen]))
00976 ok = true;
00977 }
00978 assert(domain_->getDB()->getValue((*block)[chosen]) == FALSE);
00979 }
00980
00981 else
00982 {
00983 for (int i = 0; i < block->size(); i++)
00984 {
00985 if (domain_->getDB()->getValue((*block)[i]) == TRUE)
00986 {
00987 chosen = i;
00988 assert(!pred->same((*block)[chosen]));
00989 break;
00990 }
00991 }
00992 }
00993 assert(chosen >= 0);
00994 otherPred = (*block)[chosen];
00995
00996 domain_->getDB()->setValue(otherPred, oldval);
00997 }
00998 }
00999
01000 domain_->getDB()->setValue(pred, val);
01001 if (inBlock_)
01002 getWalksatClauses(otherPred, walksatClauses, walksatClauseWts, true);
01003
01004
01005 domain_->getDB()->setValue(pred,oldval);
01006 if (inBlock_)
01007 {
01008 domain_->getDB()->setValue(otherPred, val);
01009
01010
01011 for(int atom = 1; atom <= getVarCount(); atom++)
01012 {
01013 if (otherPred->same(predArray_[atom]))
01014 {
01015 otherAtom_ = atom;
01016 break;
01017 }
01018 }
01019 }
01020 }
01021
01022
01023
01024 void setOthersInBlockToFalse(const int& atomIdx,
01025 const int& blockIdx)
01026 {
01027 const Array<Predicate*>* block = domain_->getPredBlock(blockIdx);
01028 for (int i = 0; i < block->size(); i++)
01029 {
01030 if (i != atomIdx)
01031 domain_->getDB()->setValue((*block)[i], FALSE);
01032 }
01033 }
01034
01035
01036
01037 void initBlocks()
01038 {
01039 const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01040 const Array<bool>* blockEvidence = domain_->getBlockEvidenceArray();
01041
01042 for (int i = 0; i < blocks->size(); i++)
01043 {
01044 Array<Predicate*>* block = (*blocks)[i];
01045 int chosen = -1;
01046
01047 if ((*blockEvidence)[i])
01048 {
01049 chosen = domain_->getEvidenceIdxInBlock(i);
01050 setOthersInBlockToFalse(chosen, i);
01051 continue;
01052 }
01053 else
01054 {
01055
01056 chosen = random() % block->size();
01057 }
01058 assert(chosen >= 0);
01059 domain_->getDB()->setValue((*block)[chosen], TRUE);
01060 setOthersInBlockToFalse(chosen, i);
01061 }
01062 }
01063
01064 int getBlock(int atom)
01065 {
01066 const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
01067 Predicate* pred = getVar(atom);
01068
01069 for (int i = 0; i < blocks->size(); i++)
01070 {
01071 Array<Predicate*>* block = (*blocks)[i];
01072 for (int j = 0; j < block->size(); j++)
01073 {
01074 if ((*block)[j]->same(pred))
01075 return i;
01076 }
01077 }
01078 return -1;
01079 }
01080
01081 bool inBlock(int atom)
01082 {
01083 return (getBlock(atom) >= 0);
01084 }
01085
01086 bool inBlockWithEvidence(int atom)
01087 {
01088 int blockIdx = getBlock(atom);
01089 if (blockIdx >= 0 && domain_->getBlockEvidence(blockIdx))
01090 return true;
01091 return false;
01092 }
01093
01094 bool getInBlock()
01095 {
01096 return inBlock_;
01097 }
01098
01099 void setInBlock(const bool val)
01100 {
01101 inBlock_ = val;
01102 }
01103
01104 int getOtherAtom()
01105 {
01106 assert(inBlock_ && otherAtom_ >= 0);
01107 return otherAtom_;
01108 }
01109
01110 void setOtherAtom(const int val)
01111 {
01112 otherAtom_ = val;
01113 }
01114
01115
01116 void setEvidence(const int atom, const bool val)
01117 {
01118 domain_->getDB()->setEvidenceStatus(predArray_[atom], val);
01119 }
01120
01121
01122 bool getEvidence(const int atom)
01123 {
01124 return domain_->getDB()->getEvidenceStatus(predArray_[atom]);
01125 }
01126
01127 void incrementNumDBAtoms()
01128 {
01129 numDBAtoms_++;
01130 }
01131
01132 void decrementNumDBAtoms()
01133 {
01134 numDBAtoms_--;
01135 }
01136
01137 void printIntClauses(Array<IntClause *> clauses)
01138 {
01139 for (int i = 0; i < clauses.size(); i++)
01140 {
01141 clauses[i]->printWithWtAndStrVar(cout, domain_, &predHashArray_);
01142 cout << endl;
01143 }
01144 }
01145
01146
01147
01148
01149
01150
01151 void propagateFixedAtoms(Array<Array<int> *> &clauses,
01152 Array<int> &clauseWeights,
01153 bool* fixedAtoms,
01154 int maxFixedAtoms)
01155 {
01156 Array<Array<int> *> tmpClauses;
01157 Array<int> tmpClauseWeights;
01158
01159
01160 int count = 0;
01161 for (int i = 0; i < maxFixedAtoms; i++)
01162 {
01163 if (fixedAtoms[i]) count++;
01164 }
01165 cout << "Fixed atoms before propagating: " << count << endl;
01166
01167
01168 for (int i = 0; i < maxFixedAtoms; i++)
01169 {
01170 if (fixedAtoms[i])
01171 {
01172
01173
01174 getWalksatClauses(predArray_[i], tmpClauses, tmpClauseWeights, false);
01175 cout << "Atom " << i << endl;
01176 cout << "Clauses " << tmpClauses.size() << endl;
01177
01178
01179
01180 for (int j = 0; j < tmpClauses.size(); j++)
01181 {
01182
01183 if (tmpClauses[j]->size() == 1)
01184 {
01185 clauses.append(tmpClauses[j]);
01186 clauseWeights.append(clauseWeights[j]);
01187 fixedAtoms[(*tmpClauses[j])[0]] = true;
01188 }
01189 else
01190 {
01191 delete tmpClauses[j];
01192 }
01193 }
01194 }
01195 tmpClauses.clear();
01196 tmpClauseWeights.clear();
01197 }
01198
01199
01200 count = 0;
01201 for (int i = 0; i < maxFixedAtoms; i++)
01202 {
01203 if (fixedAtoms[i]) count++;
01204 }
01205 cout << "Fixed atoms after propagating: " << count << endl;
01206
01207 exit(0);
01208 }
01209
01210
01211 public:
01212
01213 static double HARD_WT;
01214
01215 static double WSCALE;
01216
01217 private:
01218
01219 MLN *mln_;
01220 Domain *domain_;
01221
01222 Array<Predicate *> predArray_;
01223 PredicateHashArray predHashArray_;
01224
01225 Array<Array<int> *> predToClauseIds_;
01226
01227
01228 bool sampleSat_;
01229
01230 IntClauseHashArray deadClauses_;
01231
01232 Database* prevDB_;
01233
01234 int numDBAtoms_;
01235
01236 bool inBlock_;
01237
01238 int otherAtom_;
01239
01240 };
01241
01242 #endif
01243