00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef VARIABLESTATE_H_
00067 #define VARIABLESTATE_H_
00068
#include <cmath>
#include <cstdlib>

#include "mrf.h"
#include "timer.h"
00071
// Sentinel returned when no atom / clause index can be produced
// (e.g. no atoms in the state, or no false clauses).
static const int NOVALUE = 100000000;
// NOTE(review): DISANDCONT, MULTIDIS and NOSOL are not used in this part of
// the file; presumably sentinels consumed elsewhere -- confirm before
// documenting their meaning.
static const int DISANDCONT = 200000000;
static const int MULTIDIS = 300000000;
static const double NOSOL = 1234567890;
// When true, VariableState prints verbose tracing to cout.
static const bool vsdebug = false;
00077
00097 class VariableState
00098 {
00099 public:
00100
00119 VariableState(GroundPredicateHashArray* const& unknownQueries,
00120 GroundPredicateHashArray* const& knownQueries,
00121 Array<TruthValue>* const & knownQueryValues,
00122 const Array<int>* const & allPredGndingsAreQueries,
00123 const bool& markHardGndClauses,
00124 const bool& trackParentClauseWts,
00125 const MLN* const & mln, const Domain* const & domain,
00126 const bool& lazy)
00127 {
00128 stillActivating_ = true;
00129 breakHardClauses_ = false;
00130
00131 inferenceMode_ = MODE_MWS;
00132 Timer timer;
00133 double startTime = timer.time();
00134
00135 this->mln_ = (MLN*)mln;
00136 this->domain_ = (Domain*)domain;
00137 this->lazy_ = lazy;
00138
00139
00140 baseNumAtoms_ = 0;
00141 activeAtoms_ = 0;
00142 numFalseClauses_ = 0;
00143 costOfFalseClauses_ = 0.0;
00144 lowCost_ = LDBL_MAX;
00145 lowBad_ = INT_MAX;
00146
00147
00148 gndClauses_ = new GroundClauseHashArray;
00149 gndPreds_ = new Array<GroundPredicate*>;
00150
00151
00152 setHardClauseWeight();
00153
00154
00155 if (lazy_)
00156 {
00157
00158 domain_->computeNumNonEvidAtoms();
00159 numNonEvAtoms_ = domain_->getNumNonEvidenceAtoms();
00160
00161 domain_->getDB()->setPerformingInference(true);
00162 clauseLimit_ = INT_MAX;
00163 noApprox_ = false;
00164 haveDeactivated_ = false;
00165
00167
00168
00169
00170
00171 initBlocksRandom();
00172
00173
00174 bool ignoreActivePreds = true;
00175 cout << "Getting initial active atoms ... " << endl;
00176 getActiveClauses(newClauses_, ignoreActivePreds);
00177 cout << "done." << endl;
00178 int defaultCnt = newClauses_.size();
00179 long double defaultCost = 0;
00180
00181 for (int i = 0; i < defaultCnt; i++)
00182 {
00183 if (newClauses_[i]->isHardClause())
00184 defaultCost += hardWt_;
00185 else
00186 defaultCost += abs(newClauses_[i]->getWt());
00187 }
00188
00189
00190 for (int i = 0; i < gndPredHashArray_.size(); i++)
00191 gndPredHashArray_[i]->removeGndClauses();
00192
00193
00194 for (int i = 0; i < newClauses_.size(); i++)
00195 delete newClauses_[i];
00196 newClauses_.clear();
00197
00198 baseNumAtoms_ = gndPredHashArray_.size();
00199 cout << "Number of Baseatoms = " << baseNumAtoms_ << endl;
00200 cout << "Default => Cost\t" << "******\t" << " Clause Cnt\t" << endl;
00201 cout << " " << defaultCost << "\t" << "******\t" << defaultCnt
00202 << "\t" << endl << endl;
00203
00204
00205 for (int i = 0; i < baseNumAtoms_; i++)
00206 {
00207 domain_->getDB()->setActiveStatus(gndPredHashArray_[i], true);
00208 activeAtoms_++;
00209 }
00210
00211
00212 ignoreActivePreds = false;
00213 cout << "Getting initial active clauses ... ";
00214 getActiveClauses(newClauses_, ignoreActivePreds);
00215 cout << "done." << endl;
00216 }
00217
00218 else
00219 {
00220 unePreds_ = unknownQueries;
00221 knePreds_ = knownQueries;
00222 knePredValues_ = knownQueryValues;
00223
00224
00225 int size = 0;
00226 if (unknownQueries) size += unknownQueries->size();
00227 if (knownQueries) size += knownQueries->size();
00228 GroundPredicateHashArray* queries = new GroundPredicateHashArray(size);
00229 if (unknownQueries) queries->append(unknownQueries);
00230 if (knownQueries) queries->append(knownQueries);
00231 mrf_ = new MRF(queries, allPredGndingsAreQueries, domain_,
00232 domain_->getDB(), mln_, markHardGndClauses,
00233 trackParentClauseWts, -1);
00234
00235
00236
00237 mrf_->deleteGndPredsGndClauseSets();
00238
00239 delete queries;
00240
00241
00242 newClauses_ = *(Array<GroundClause*>*)mrf_->getGndClauses();
00243
00244
00245 const GroundPredicateHashArray* gndPreds = mrf_->getGndPreds();
00246 for (int i = 0; i < gndPreds->size(); i++)
00247 gndPredHashArray_.append((*gndPreds)[i]);
00248
00249
00250 baseNumAtoms_ = gndPredHashArray_.size();
00251 }
00252
00253
00254
00255
00256
00257
00258 bool initial = true;
00259 addNewClauses(initial);
00260
00261 cout << "[VS] ";
00262 Timer::printTime(cout,timer.time()-startTime);
00263 cout << endl;
00264 cout << ">>> DONE: Initial num. of clauses: " << getNumClauses() << endl;
00265 }
00266
00270 ~VariableState()
00271 {
00272 if (lazy_)
00273 {
00274 if (gndClauses_)
00275 for (int i = 0; i < gndClauses_->size(); i++)
00276 delete (*gndClauses_)[i];
00277
00278 for (int i = 0; i < gndPredHashArray_.size(); i++)
00279 {
00280 gndPredHashArray_[i]->removeGndClauses();
00281 delete gndPredHashArray_[i];
00282 }
00283 }
00284 else
00285 {
00286
00287 if (mrf_)
00288 {
00289 delete mrf_;
00290 mrf_ = NULL;
00291 }
00292
00293
00294
00295 }
00296
00297
00298
00299 }
00300
00301
00309 void addNewClauses(bool initial)
00310 {
00311 if (initial) addNewClauses(ADD_CLAUSE_INITIAL, newClauses_);
00312 else addNewClauses(ADD_CLAUSE_REGULAR, newClauses_);
00313 }
00314
00318 void init()
00319 {
00320
00321 initMakeBreakCostWatch();
00322 }
00323
00327 void reinit()
00328 {
00329 clause_.clearAndCompress();
00330 clauseCost_.clearAndCompress();
00331 falseClause_.clearAndCompress();
00332 whereFalse_.clearAndCompress();
00333 numTrueLits_.clearAndCompress();
00334 watch1_.clearAndCompress();
00335 watch2_.clearAndCompress();
00336 isSatisfied_.clearAndCompress();
00337 deadClause_.clearAndCompress();
00338 threshold_.clearAndCompress();
00339
00340
00341 for (int i = 0; i < gndClauses_->size(); i++)
00342 newClauses_.append((*gndClauses_)[i]);
00343
00344 gndClauses_->clearAndCompress();
00345 gndPreds_->clearAndCompress();
00346 for (int i = 0; i < occurence_.size(); i++)
00347 occurence_[i].clearAndCompress();
00348 occurence_.clearAndCompress();
00349
00350
00351 bool initial = true;
00352 addNewClauses(initial);
00353 baseNumAtoms_ = gndPredHashArray_.size();
00354 init();
00355 }
00356
00362 void initRandom()
00363 {
00364
00365 if (inferenceMode_ == MODE_SAMPLESAT)
00366 {
00367 unitPropagation();
00368 }
00369
00370
00371 initBlocksRandom();
00372
00373 for (int i = 1; i <= gndPreds_->size(); i++)
00374 {
00375 if (fixedAtom_[i] != 0)
00376 {
00377 continue;
00378 }
00379
00380
00381 if (getBlockIndex(i - 1) >= 0)
00382 {
00383 if (vsdebug) cout << "Atom " << i << " in block" << endl;
00384 continue;
00385 }
00386
00387 else
00388 {
00389 if (vsdebug) cout << "Atom " << i << " not in block" << endl;
00390
00391
00392 if (!lazy_ || isActive(i))
00393 {
00394 bool isTrue = random() % 2;
00395 bool activate = false;
00396 setValueOfAtom(i, isTrue, activate, -1);
00397 }
00398 else if (inferenceMode_ == MODE_SAMPLESAT)
00399 {
00400 bool isInPrevUnsat = false;
00401 Array<int> oca = getPosOccurenceArray(i);
00402 for (int k = 0; k < oca.size(); k++)
00403 if (!prevSatisfiedClause_[oca[k]])
00404 {
00405 isInPrevUnsat = true;
00406 break;
00407 }
00408
00409 if (!isInPrevUnsat)
00410 {
00411 oca = getNegOccurenceArray(i);
00412 for (int k = 0; k < oca.size(); k++)
00413 if (!prevSatisfiedClause_[oca[k]])
00414 {
00415 isInPrevUnsat = true;
00416 break;
00417 }
00418 }
00419 if (isInPrevUnsat)
00420 {
00421 bool isTrue = random() % 2;
00422 if (isTrue)
00423 {
00424 if (activateAtom(i, false, false))
00425 setValueOfAtom(i, true, false, -1);
00426 }
00427 }
00428 }
00429 }
00430 }
00431
00432 init();
00433 }
00434
00438 void initBlocksRandom()
00439 {
00440 if (vsdebug)
00441 {
00442 cout << "Initializing blocks randomly" << endl;
00443 cout << "Num. of blocks: " << domain_->getNumPredBlocks() << endl;
00444 }
00445
00446 for (int i = 0; i < domain_->getNumPredBlocks(); i++)
00447 {
00448 int trueFixedAtom = getTrueFixedAtomInBlock(i);
00449
00450 if (trueFixedAtom >= 0)
00451 {
00452 if (vsdebug)
00453 {
00454 cout << "True fixed atom " << trueFixedAtom << " in block "
00455 << i << endl;
00456 }
00457 setOthersInBlockToFalse(trueFixedAtom, i);
00458 continue;
00459 }
00460
00461
00462 if (domain_->getBlockEvidence(i))
00463 {
00464 if (vsdebug) cout << "Block evidence in block " << i << endl;
00465
00466 setOthersInBlockToFalse(-1, i);
00467 continue;
00468 }
00469
00470
00471 bool ok = false;
00472 while (!ok)
00473 {
00474 const Predicate* pred = domain_->getRandomPredInBlock(i);
00475 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00476 int atomIdx = gndPredHashArray_.find(gndPred);
00477 delete pred;
00478
00479 assert(lazy_ || atomIdx >= 0);
00480
00481 if (atomIdx == -1)
00482 {
00483 atomIdx = gndPredHashArray_.append(gndPred);
00484 bool initial = false;
00485 addNewClauses(initial);
00486 ok = true;
00487 }
00488
00489 else
00490 {
00491 delete gndPred;
00492 if (fixedAtom_[atomIdx + 1] == 0)
00493 {
00494 if (vsdebug) cout << "Atom " << atomIdx + 1
00495 << " chosen in block " << i << endl;
00496 ok = true;
00497 }
00498 else
00499 {
00500 if (vsdebug) cout << "Atom " << atomIdx + 1
00501 << " is fixed to " << fixedAtom_[atomIdx + 1]
00502 << endl;
00503
00504 continue;
00505 }
00506 }
00507 bool activate = false;
00508 setValueOfAtom(atomIdx + 1, true, activate, i);
00509 setOthersInBlockToFalse(atomIdx, i);
00510 }
00511 }
00512 if (vsdebug) cout << "Done initializing blocks randomly" << endl;
00513 }
00514
00518 void resetMakeBreakCostWatch()
00519 {
00520
00521 for (int i = 0; i < getNumClauses(); i++) numTrueLits_[i] = 0;
00522 numFalseClauses_ = 0;
00523 costOfFalseClauses_ = 0.0;
00524 lowCost_ = LDBL_MAX;
00525 lowBad_ = INT_MAX;
00526
00527 assert(makeCost_.size() == breakCost_.size());
00528
00529 for (int i = 0; i < makeCost_.size(); i++)
00530 {
00531 makeCost_[i] = breakCost_[i] = 0.0;
00532 }
00533 }
00534
00535
00540 void initMakeBreakCostWatch()
00541 {
00542 resetMakeBreakCostWatch();
00543 initMakeBreakCostWatch(0);
00544 }
00545
/**
 * Computes, from the current truth assignment, the per-clause bookkeeping
 * for clauses startClause .. getNumClauses()-1: true-literal counts, watched
 * true literals, false-clause membership and the atoms' make/break costs.
 *
 * Nothing is reset here; for a full recomputation call
 * resetMakeBreakCostWatch() first (initMakeBreakCostWatch() does this).
 * Dead and already-satisfied clauses are skipped.
 *
 * A clause with cost >= 0 is false when none of its literals is true.
 * A clause with cost < 0 is false when at least one literal is true.
 *
 * @param startClause index of the first clause to process
 */
void initMakeBreakCostWatch(const int& startClause)
{
    // Last true literal seen in the current clause. It is only read when the
    // clause has exactly one true literal, so it is then that literal.
  int theTrueLit = -1;

  for (int i = startClause; i < getNumClauses(); i++)
  {
      // Dead / satisfied (fixed) clauses do not contribute.
    if (deadClause_[i] || isSatisfied_[i]) continue;

    int trueLit1 = 0;
    int trueLit2 = 0;
    long double cost = clauseCost_[i];
    numTrueLits_[i] = 0;
      // Count true literals; remember the first two as watch candidates.
    for (int j = 0; j < getClauseSize(i); j++)
    {
      if (isTrueLiteral(clause_[i][j]))
      {
        numTrueLits_[i]++;
        theTrueLit = abs(clause_[i][j]);
        if (!trueLit1) trueLit1 = theTrueLit;
        else if (trueLit1 && !trueLit2) trueLit2 = theTrueLit;
      }
    }

      // The clause is currently false.
    if ((numTrueLits_[i] == 0 && cost >= 0) ||
        (numTrueLits_[i] > 0 && cost < 0))
    {
        // Register it in the false-clause list and accumulate |cost|.
        // NOTE(review): abs() on a long double relies on a floating-point
        // abs overload being in scope; confirm it does not resolve to
        // int abs(int).
      whereFalse_[i] = numFalseClauses_;
      falseClause_[numFalseClauses_] = i;
      numFalseClauses_++;
      costOfFalseClauses_ += abs(cost);
      if (highestCost_ == abs(cost)) {eqHighest_ = true; numHighest_++;}

        // Positive clause with no true literal: flipping any of its atoms
        // makes it true.
      if (numTrueLits_[i] == 0)
        for (int j = 0; j < getClauseSize(i); j++)
        {
          makeCost_[abs(clause_[i][j])] += cost;
        }

        // Negative clause with a single true literal: flipping that atom
        // makes the clause satisfied. cost < 0, so this raises its make cost.
      if (numTrueLits_[i] == 1)
      {
        makeCost_[theTrueLit] -= cost;
        watch1_[i] = theTrueLit;
      }
      else if (numTrueLits_[i] > 1)
      {
        watch1_[i] = trueLit1;
        watch2_[i] = trueLit2;
      }
    }
      // Positive clause held true by exactly one literal: flipping that atom
      // breaks it.
    else if (numTrueLits_[i] == 1 && cost >= 0)
    {
      breakCost_[theTrueLit] += cost;
      watch1_[i] = theTrueLit;
    }
      // Positive clause with two or more true literals: just watch two.
    else if (cost >= 0)
    {
      watch1_[i] = trueLit1;
      watch2_[i] = trueLit2;
    }
      // Negative clause with no true literal (currently satisfied): flipping
      // any of its atoms breaks it.
    else if (numTrueLits_[i] == 0 && cost < 0)
    {
      for (int j = 0; j < getClauseSize(i); j++)
        breakCost_[abs(clause_[i][j])] -= cost;
    }
  }
}
00628
00629 int getNumAtoms() { return gndPreds_->size(); }
00630
00631 int getNumClauses() { return gndClauses_->size(); }
00632
00633 int getNumDeadClauses()
00634 {
00635 int count = 0;
00636 for (int i = 0; i < deadClause_.size(); i++)
00637 if (deadClause_[i]) count++;
00638 return count;
00639 }
00640
00646 int getIndexOfRandomAtom()
00647 {
00648
00649 if (lazy_)
00650 {
00651 Predicate* pred = domain_->getNonEvidenceAtom(random() % numNonEvAtoms_);
00652 GroundPredicate* gndPred = new GroundPredicate(pred);
00653 delete pred;
00654
00655 int idx = gndPredHashArray_.find(gndPred);
00656
00657 if (idx >= 0)
00658 {
00659 delete gndPred;
00660 return idx + 1;
00661 }
00662
00663 else
00664 {
00665 if (vsdebug)
00666 {
00667 cout << "Adding randomly ";
00668 gndPred->print(cout, domain_);
00669 cout << " to the state" << endl;
00670 }
00671 gndPredHashArray_.append(gndPred);
00672 bool initial = false;
00673 addNewClauses(initial);
00674
00675 return gndPredHashArray_.size();
00676 }
00677 }
00678
00679 else
00680 {
00681 int numAtoms = getNumAtoms();
00682 if (numAtoms == 0) return NOVALUE;
00683 return random()%numAtoms + 1;
00684 }
00685 }
00686
00692 int getIndexOfAtomInRandomFalseClause()
00693 {
00694 if (numFalseClauses_ == 0) return NOVALUE;
00695 int clauseIdx = falseClause_[random()%numFalseClauses_];
00696
00697
00698 if (clauseCost_[clauseIdx] >= 0)
00699 {
00700
00701 while (true)
00702 {
00703 int i = random()%getClauseSize(clauseIdx);
00704 if (!fixedAtom_[abs(clause_[clauseIdx][i])])
00705 return abs(clause_[clauseIdx][i]);
00706 }
00707 }
00708
00709 else
00710 return getRandomTrueLitInClause(clauseIdx);
00711 }
00712
00717 int getRandomFalseClauseIndex()
00718 {
00719 if (numFalseClauses_ == 0) return NOVALUE;
00720 return falseClause_[random()%numFalseClauses_];
00721 }
00722
00727 long double getCostOfFalseClauses()
00728 {
00729 return costOfFalseClauses_;
00730 }
00731
00736 int getNumFalseClauses()
00737 {
00738 return numFalseClauses_;
00739 }
00740
00747 bool getValueOfAtom(const int& atomIdx)
00748 {
00749 return atom_[atomIdx];
00750 }
00751
00758 bool getValueOfLowAtom(const int& atomIdx)
00759 {
00760 return lowAtom_[atomIdx];
00761 }
00762
00773 void setValueOfAtom(const int& atomIdx, const bool& value,
00774 const bool& activate, const int& blockIdx)
00775 {
00776
00777 if (atom_[atomIdx] == value) return;
00778 if (vsdebug) cout << "Setting value of atom " << atomIdx
00779 << " to " << value << endl;
00780
00781 GroundPredicate* p = gndPredHashArray_[atomIdx - 1];
00782 if (value)
00783 domain_->getDB()->setValue(p, TRUE);
00784 else
00785 domain_->getDB()->setValue(p, FALSE);
00786
00787
00788 if (activate && lazy_ && !isActive(atomIdx))
00789 {
00790 bool ignoreActivePreds = false;
00791 bool groundOnly = false;
00792 activateAtom(atomIdx, ignoreActivePreds, groundOnly);
00793 }
00794 atom_[atomIdx] = value;
00795
00796
00797 if (blockIdx > -1 && value)
00798 {
00799 Predicate* pred = p->createEquivalentPredicate(domain_);
00800 domain_->setTruePredInBlock(blockIdx, pred);
00801 if (vsdebug)
00802 {
00803 cout << "Set true pred in block " << blockIdx << " to ";
00804 pred->printWithStrVar(cout, domain_);
00805 cout << endl;
00806 }
00807 }
00808 }
00809
00813 Array<int>& getNegOccurenceArray(const int& atomIdx)
00814 {
00815 int litIdx = 2*atomIdx;
00816 return getOccurenceArray(litIdx);
00817 }
00818
00822 Array<int>& getPosOccurenceArray(const int& atomIdx)
00823 {
00824 int litIdx = 2*atomIdx - 1;
00825 return getOccurenceArray(litIdx);
00826 }
00827
00834 void flipAtom(const int& toFlip, const int& blockIdx)
00835 {
00836 bool toFlipValue = getValueOfAtom(toFlip);
00837 register int clauseIdx;
00838 int sign;
00839 int oppSign;
00840 int litIdx;
00841 if (toFlipValue)
00842 sign = 1;
00843 else
00844 sign = 0;
00845 oppSign = sign ^ 1;
00846
00847 flipAtomValue(toFlip, blockIdx);
00848
00849 litIdx = 2*toFlip - sign;
00850 Array<int>& posOccArray = getOccurenceArray(litIdx);
00851 for (int i = 0; i < posOccArray.size(); i++)
00852 {
00853 clauseIdx = posOccArray[i];
00854
00855 if (deadClause_[clauseIdx] || isSatisfied_[clauseIdx]) continue;
00856
00857
00858 int numTrueLits = decrementNumTrueLits(clauseIdx);
00859 long double cost = getClauseCost(clauseIdx);
00860 int watch1 = getWatch1(clauseIdx);
00861 int watch2 = getWatch2(clauseIdx);
00862
00863
00864 if (numTrueLits == 0)
00865 {
00866
00867 if (cost >= 0)
00868 {
00869
00870 addFalseClause(clauseIdx);
00871
00872 addBreakCost(toFlip, -cost);
00873
00874 addMakeCostToAtomsInClause(clauseIdx, cost);
00875 }
00876
00877 else
00878 {
00879 assert(cost < 0);
00880
00881 removeFalseClause(clauseIdx);
00882
00883 addBreakCostToAtomsInClause(clauseIdx, -cost);
00884
00885 addMakeCost(toFlip, cost);
00886 }
00887 }
00888
00889
00890 else if (numTrueLits == 1)
00891 {
00892 if (watch1 == toFlip)
00893 {
00894 assert(watch1 != watch2);
00895 setWatch1(clauseIdx, watch2);
00896 watch1 = getWatch1(clauseIdx);
00897 }
00898
00899
00900 if (cost >= 0)
00901 {
00902 addBreakCost(watch1, cost);
00903 }
00904
00905 else
00906 {
00907 assert(cost < 0);
00908 addMakeCost(watch1, -cost);
00909 }
00910 }
00911
00912
00913 else
00914 {
00915
00916 if (watch1 == toFlip)
00917 {
00918
00919 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00920 setWatch1(clauseIdx, diffTrueLit);
00921 }
00922
00923 else if (watch2 == toFlip)
00924 {
00925
00926 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00927 setWatch2(clauseIdx, diffTrueLit);
00928 }
00929 }
00930 }
00931
00932
00933 litIdx = 2*toFlip - oppSign;
00934 Array<int>& negOccArray = getOccurenceArray(litIdx);
00935 for (int i = 0; i < negOccArray.size(); i++)
00936 {
00937 clauseIdx = negOccArray[i];
00938
00939 if (deadClause_[clauseIdx] || isSatisfied_[clauseIdx]) continue;
00940
00941
00942 int numTrueLits = incrementNumTrueLits(clauseIdx);
00943 long double cost = getClauseCost(clauseIdx);
00944 int watch1 = getWatch1(clauseIdx);
00945
00946
00947
00948 if (numTrueLits == 1)
00949 {
00950
00951 if (cost >= 0)
00952 {
00953
00954 removeFalseClause(clauseIdx);
00955
00956 addBreakCost(toFlip, cost);
00957
00958 addMakeCostToAtomsInClause(clauseIdx, -cost);
00959 }
00960
00961 else
00962 {
00963 assert(cost < 0);
00964
00965 addFalseClause(clauseIdx);
00966
00967 addBreakCostToAtomsInClause(clauseIdx, cost);
00968
00969 addMakeCost(toFlip, -cost);
00970 }
00971
00972 setWatch1(clauseIdx, toFlip);
00973 }
00974
00975
00976 else
00977 if (numTrueLits == 2)
00978 {
00979 if (cost >= 0)
00980 {
00981
00982
00983 addBreakCost(watch1, -cost);
00984 }
00985 else
00986 {
00987
00988 assert(cost < 0);
00989
00990 addMakeCost(watch1, cost);
00991 }
00992
00993
00994 setWatch2(clauseIdx, toFlip);
00995 }
00996 }
00997 }
00998
01006 void flipAtomValue(const int& atomIdx, const int& blockIdx)
01007 {
01008 bool opposite = !atom_[atomIdx];
01009 bool activate = true;
01010 setValueOfAtom(atomIdx, opposite, activate, blockIdx);
01011 }
01012
01024 long double getImprovementByFlipping(const int& atomIdx)
01025 {
01026 if (!breakHardClauses_ && breakCost_[atomIdx] >= hardWt_)
01027 return -LDBL_MAX;
01028 long double improvement = makeCost_[atomIdx] - breakCost_[atomIdx];
01029 return improvement;
01030 }
01031
01037 void setActive(const int& atomIdx)
01038 {
01039 if (lazy_)
01040 {
01041 Predicate* p =
01042 gndPredHashArray_[atomIdx - 1]->createEquivalentPredicate(domain_);
01043 domain_->getDB()->setActiveStatus(p, true);
01044 activeAtoms_++;
01045 delete p;
01046 }
01047 }
01048
/**
 * Lazy mode: activates an inactive atom by retrieving the clauses that
 * getActiveClauses() returns for it and adding them to the state.
 *
 * If the atom is false, it is temporarily set to true (with costs updated)
 * before retrieval. It is set back to false afterwards unless a clause got
 * fixed.
 *
 * With useThreshold_, the retrieved clauses are processed in order:
 * - A negative-weight clause, or a positive-weight unit clause, is selected
 *   with probability about 1 - exp(-|w|).
 * - For the first selected clause: dead clauses gathered so far are added,
 *   the clause is added as satisfied, its atoms are fixed via fixAtom(),
 *   and processing stops.
 * - Unselected clauses of those kinds become dead clauses.
 * - All other clauses are added normally.
 *
 * @param atomIdx           1-based atom index
 * @param ignoreActivePreds forwarded to getActiveClauses()
 * @param groundOnly        if true, clauses are added but the atom is not
 *                          marked active in the DB
 * @return false if a clause was fixed (the atom is then neither set back nor
 *         marked active); true otherwise, including when there was nothing
 *         to do (eager mode or atom already active).
 */
bool activateAtom(const int& atomIdx, const bool& ignoreActivePreds,
                  const bool& groundOnly)
{
  if (lazy_ && !isActive(atomIdx))
  {
    if (vsdebug)
    {
      cout << "\tActivating ";
      gndPredHashArray_[atomIdx-1]->print(cout,domain_);
      cout << " " <<
      domain_->getDB()->getActiveStatus(gndPredHashArray_[atomIdx-1]) << endl;
    }

      // Temporarily make the atom true so its clauses are retrieved with it
      // set to true.
    bool needToFlipBack = false;
    if (!atom_[atomIdx])
    {
      bool activate = false;
      setValueOfAtom(atomIdx, true, activate, -1);
      updateMakeBreakCostAfterFlip(atomIdx);
      needToFlipBack = true;
    }

    Predicate* p =
      gndPredHashArray_[atomIdx - 1]->createEquivalentPredicate(domain_);

    bool isFixed = false;
    Array<GroundClause*> unsatClauses;
    Array<GroundClause*> deadClauses;
    Array<GroundClause*> toAddClauses;

      // NOTE(review): meaning of the third argument (true) is defined in
      // getActiveClauses(), not visible here.
    getActiveClauses(p, unsatClauses, true, ignoreActivePreds);

    if (useThreshold_)
    {
      for (int ni = 0; ni < unsatClauses.size(); ni++)
      {
        GroundClause* c = unsatClauses[ni];
        if (c->getWt() < 0)
        {
            // Select with probability ~ 1 - exp(w) = 1 - exp(-|w|).
          double threshold = RAND_MAX*(1 - exp(c->getWt()));
          if (random() <= threshold)
          {
            isFixed = true;

            if (deadClauses.size() > 0)
              addNewClauses(ADD_CLAUSE_DEAD, deadClauses);

            Array<GroundClause*> fixedClauses;
            fixedClauses.append(c);
            addNewClauses(ADD_CLAUSE_SAT, fixedClauses);

              // Fix every atom so that its literal is false.
            for (int pi = 0; pi < c->getNumGroundPredicates(); pi++)
            {
              int lit = c->getGroundPredicateIndex(pi);
              fixAtom(abs(lit), (lit < 0));
            }

              // NOTE(review): clauses already in toAddClauses and those after
              // this one in unsatClauses are neither added nor deleted here;
              // possible leak if getActiveClauses() allocates them -- confirm.
            break;
          }
          else
          {
            deadClauses.append(c);
          }
        }
        else if (c->getNumGroundPredicates() == 1)
        {
            // Positive unit clause: select with probability ~ 1 - exp(-w).
          double threshold = RAND_MAX*(1 - exp(-c->getWt()));
          if (random() <= threshold)
          {
            isFixed = true;

            if (deadClauses.size() > 0)
              addNewClauses(ADD_CLAUSE_DEAD, deadClauses);

            Array<GroundClause*> fixedClauses;
            fixedClauses.append(c);
            addNewClauses(ADD_CLAUSE_SAT, fixedClauses);

              // Fix the single atom so that its literal is true.
            int lit = c->getGroundPredicateIndex(0);
            fixAtom(abs(lit), (lit > 0));

            break;
          }
          else
          {
            deadClauses.append(c);
          }
        }
        else toAddClauses.append(c);
      }
    }

    if (!isFixed)
    {
      if (deadClauses.size() > 0)
        addNewClauses(ADD_CLAUSE_DEAD, deadClauses);

      if (useThreshold_)
        addNewClauses(ADD_CLAUSE_REGULAR, toAddClauses);
      else
        addNewClauses(ADD_CLAUSE_REGULAR, unsatClauses);

        // Restore the original (false) value set aside above.
      if (needToFlipBack)
      {
        bool activate = false;
        setValueOfAtom(atomIdx, false, activate, -1);
        updateMakeBreakCostAfterFlip(atomIdx);
      }

      if (!groundOnly)
      {
        domain_->getDB()->setActiveStatus(p, true);
        activeAtoms_++;
      }
    }

    delete p;
    unsatClauses.clear();
    deadClauses.clear();
    toAddClauses.clear();

    return !isFixed;
  }
  return true;
}
01216
01220 void setInferenceMode(int mode)
01221 {
01222 if (mode == inferenceMode_) return;
01223 inferenceMode_ = mode;
01224 if (inferenceMode_ == MODE_MWS)
01225 {
01226 for (int i = 0; i < gndClauses_->size(); i++)
01227 {
01228 if ((*gndClauses_)[i]->isHardClause())
01229 clauseCost_[i] = hardWt_;
01230 else
01231 clauseCost_[i] = (*gndClauses_)[i]->getWt();
01232 }
01233 initMakeBreakCostWatch();
01234 }
01235 else
01236 {
01237 if (inferenceMode_ == MODE_HARD) eliminateSoftClauses();
01238 else if (inferenceMode_ == MODE_SAMPLESAT) makeUnitCosts();
01239 }
01240 }
01241
01245 bool isFixedAtom(int atomIdx)
01246 {
01247 return fixedAtom_[atomIdx];
01248 }
01249
01250 bool isSatisfiedClause(const int& clauseIdx)
01251 {
01252 return isSatisfied_[clauseIdx];
01253 }
01254
01255 void setSatisfiedClause(const int& clauseIdx)
01256 {
01257 isSatisfied_[clauseIdx] = true;
01258 }
01259
01260
01261 void updatePrevSatisfied()
01262 {
01263 prevSatisfiedClause_.clearAndCompress();
01264 prevSatisfiedClause_.growToSize(clause_.size(),false);
01265 for (int i=0; i<clause_.size(); i++)
01266 {
01267 long double cost = clauseCost_[i];
01268 bool isTrue = false;
01269 for (int j = 0; j < getClauseSize(i); j++)
01270 {
01271 if (isTrueLiteral(clause_[i][j]))
01272 {
01273 isTrue = true;
01274 break;
01275 }
01276 }
01277 prevSatisfiedClause_[i] = ((isTrue && cost>0) || (!isTrue && cost<0));
01278 }
01279 }
01280
01284 void unitPropagation()
01285 {
01286
01287
01288
01289 for (int i = 0; i < getNumClauses(); i++)
01290 {
01291
01292
01293 if (isDeadClause(i) || isSatisfiedClause(i) || getClauseCost(i) >= 0)
01294 continue;
01295
01296
01297 Array<int> atoms = getAtomsInClause(i);
01298 for (int j = 0; j < atoms.size(); j++)
01299 {
01300 int lit = atoms[j];
01301 bool value = (lit < 0)? true : false;
01302 fixAtom(abs(lit), value);
01303 }
01304 }
01305
01306
01307
01308
01309 bool done = false;
01310
01311 while (!done)
01312 {
01313 if (vsdebug) cout << endl << endl;
01314 done = true;
01315
01316
01317
01318 for (int ci = 0; ci < getNumClauses(); ci++)
01319 {
01320
01321 if (isDeadClause(ci) || isSatisfiedClause(ci) || getClauseCost(ci) <= 0)
01322 continue;
01323
01324
01325 int numNonfixedAtoms = 0;
01326 int nonfixedAtom = 0;
01327
01328 bool isSat = false;
01329 for (int li = 0; li < getClauseSize(ci); li++)
01330 {
01331 int lit = getAtomInClause(li,ci);
01332 int fixedValue = fixedAtom_[abs(lit)];
01333
01334 if (fixedValue==0)
01335 {
01336 numNonfixedAtoms++;
01337 nonfixedAtom = lit;
01338 continue;
01339 }
01340
01341 if ((fixedValue == 1 && lit > 0) || (fixedValue == -1 && lit < 0))
01342 {
01343 isSat = true;
01344 break;
01345 }
01346 }
01347
01348 if (isSat) setSatisfiedClause(ci);
01349 else if (numNonfixedAtoms == 1)
01350 {
01351 fixAtom(abs(nonfixedAtom), (nonfixedAtom > 0) ? true : false);
01352 done = false;
01353 }
01354 }
01355 }
01356
01357 saveLowState();
01358 if (vsdebug) cout << ">>> [vs.unitpropagation] DONE" << endl;
01359 }
01360
01367 void addNewClauses(int addType, Array<GroundClause*> & clauses)
01368 {
01369 if (vsdebug)
01370 cout << "Adding " << clauses.size() << " new clauses.." << endl;
01371
01372
01373 if (!useThreshold_ &&
01374 (addType == ADD_CLAUSE_DEAD || addType == ADD_CLAUSE_SAT))
01375 {
01376 cout << ">>> [ERR] add_dead/sat but useThreshold_ is false" << endl;
01377 exit(0);
01378 }
01379
01380
01381 int oldNumClauses = getNumClauses();
01382 int oldNumAtoms = getNumAtoms();
01383
01384
01385 for (int i = 0; i < clauses.size(); i++)
01386 {
01387 gndClauses_->append(clauses[i]);
01388 clauses[i]->appendToGndPreds(&gndPredHashArray_);
01389 }
01390
01391 gndPreds_->growToSize(gndPredHashArray_.size());
01392
01393 int numAtoms = getNumAtoms();
01394 int numClauses = getNumClauses();
01395
01396 if (numAtoms == oldNumAtoms && numClauses == oldNumClauses) return;
01397
01398 if (vsdebug) cout << "Clauses: " << numClauses << endl;
01399
01400 atom_.growToSize(numAtoms + 1, false);
01401
01402 makeCost_.growToSize(numAtoms + 1, 0.0);
01403 breakCost_.growToSize(numAtoms + 1, 0.0);
01404 lowAtom_.growToSize(numAtoms + 1, false);
01405 fixedAtom_.growToSize(numAtoms + 1, 0);
01406
01407
01408 for (int i = oldNumAtoms; i < gndPredHashArray_.size(); i++)
01409 {
01410 (*gndPreds_)[i] = gndPredHashArray_[i];
01411
01412 if (vsdebug)
01413 {
01414 cout << "New pred " << i + 1 << ": ";
01415 (*gndPreds_)[i]->print(cout, domain_);
01416 cout << endl;
01417 }
01418
01419 lowAtom_[i + 1] = atom_[i + 1] =
01420 (domain_->getDB()->getValue((*gndPreds_)[i]) == TRUE) ? true : false;
01421 }
01422 clauses.clear();
01423
01424 clause_.growToSize(numClauses);
01425 clauseCost_.growToSize(numClauses);
01426 falseClause_.growToSize(numClauses);
01427 whereFalse_.growToSize(numClauses);
01428 numTrueLits_.growToSize(numClauses);
01429 watch1_.growToSize(numClauses);
01430 watch2_.growToSize(numClauses);
01431 isSatisfied_.growToSize(numClauses, false);
01432 deadClause_.growToSize(numClauses, false);
01433 threshold_.growToSize(numClauses, false);
01434
01435 occurence_.growToSize(2*numAtoms + 1);
01436
01437 for (int i = oldNumClauses; i < numClauses; i++)
01438 {
01439 GroundClause* gndClause = (*gndClauses_)[i];
01440
01441 if (vsdebug)
01442 {
01443 cout << "New clause " << i << ": ";
01444 gndClause->print(cout, domain_, &gndPredHashArray_);
01445 cout << endl;
01446 }
01447
01448
01449 if (gndClause->isHardClause()) threshold_[i] = RAND_MAX;
01450 else
01451 {
01452 double w = gndClause->getWt();
01453 threshold_[i] = RAND_MAX*(1 - exp(-abs(w)));
01454 if (vsdebug)
01455 {
01456 cout << "Weight: " << w << endl;
01457 }
01458 }
01459 if (vsdebug)
01460 cout << "Threshold: " << threshold_[i] << endl;
01461
01462 int numGndPreds = gndClause->getNumGroundPredicates();
01463 clause_[i].growToSize(numGndPreds);
01464 for (int j = 0; j < numGndPreds; j++)
01465 {
01466 int lit = gndClause->getGroundPredicateIndex(j);
01467 clause_[i][j] = lit;
01468 int litIdx = 2*abs(lit) - (lit > 0);
01469 occurence_[litIdx].append(i);
01470 }
01471
01472
01473 if (inferenceMode_==MODE_SAMPLESAT)
01474 {
01475 if (gndClause->getWt()>0) clauseCost_[i] = 1;
01476 else clauseCost_[i] = -1;
01477 }
01478 else if (gndClause->isHardClause())
01479 clauseCost_[i] = hardWt_;
01480 else
01481 clauseCost_[i] = gndClause->getWt();
01482
01483
01484 if (inferenceMode_ == MODE_HARD && !gndClause->isHardClause())
01485 {
01486
01487 deadClause_[i] = true;
01488 }
01489 }
01490
01491 if (addType == ADD_CLAUSE_DEAD)
01492 {
01493
01494 for (int i = oldNumClauses; i < numClauses; i++)
01495 {
01496 deadClause_[i] = true;
01497 }
01498 }
01499 else if (addType == ADD_CLAUSE_SAT)
01500 {
01501
01502 for (int i = oldNumClauses; i < numClauses; i++)
01503 {
01504 isSatisfied_[i]=true;
01505 }
01506 }
01507 else if (addType == ADD_CLAUSE_REGULAR)
01508 {
01509 if (useThreshold_)
01510 killClauses(oldNumClauses);
01511 else
01512 initMakeBreakCostWatch(oldNumClauses);
01513 }
01514
01515 if (vsdebug)
01516 cout << "Done adding new clauses.." << endl;
01517 }
01518
01525 void updateSatisfiedClauses(const int& toFix)
01526 {
01527
01528 bool toFlipValue = getValueOfAtom(toFix);
01529
01530 register int clauseIdx;
01531 int sign;
01532 int litIdx;
01533 if (toFlipValue)
01534 sign = 1;
01535 else
01536 sign = 0;
01537
01538
01539 litIdx = 2*toFix - sign;
01540 Array<int>& posOccArray = getOccurenceArray(litIdx);
01541 for (int i = 0; i < posOccArray.size(); i++)
01542 {
01543 clauseIdx = posOccArray[i];
01544
01545
01546 if (deadClause_[clauseIdx] || isSatisfied_[clauseIdx]) continue;
01547
01548 if (getClauseCost(clauseIdx) < 0)
01549 {
01550 cout << "ERR: in MC-SAT, active neg-wt clause (" << clauseIdx
01551 << ") is sat by fixed "<<endl;
01552 exit(0);
01553 }
01554 isSatisfied_[clauseIdx] = true;
01555 }
01556 }
01557
01563 void updateMakeBreakCostAfterFlip(const int& toFlip)
01564 {
01565
01566 bool toFlipValue = !getValueOfAtom(toFlip);
01567
01568 register int clauseIdx;
01569 int sign;
01570 int oppSign;
01571 int litIdx;
01572 if (toFlipValue)
01573 sign = 1;
01574 else
01575 sign = 0;
01576 oppSign = sign ^ 1;
01577
01578
01579 litIdx = 2*toFlip - sign;
01580 Array<int>& posOccArray = getOccurenceArray(litIdx);
01581
01582 for (int i = 0; i < posOccArray.size(); i++)
01583 {
01584 clauseIdx = posOccArray[i];
01585
01586 if (deadClause_[clauseIdx] || isSatisfied_[clauseIdx]) continue;
01587
01588
01589 int numTrueLits = decrementNumTrueLits(clauseIdx);
01590 long double cost = getClauseCost(clauseIdx);
01591 int watch1 = getWatch1(clauseIdx);
01592 int watch2 = getWatch2(clauseIdx);
01593
01594
01595
01596 if (numTrueLits == 0)
01597 {
01598
01599 if (cost >= 0)
01600 {
01601
01602 addFalseClause(clauseIdx);
01603
01604 addBreakCost(toFlip, -cost);
01605
01606 addMakeCostToAtomsInClause(clauseIdx, cost);
01607 }
01608
01609 else
01610 {
01611 assert(cost < 0);
01612
01613 removeFalseClause(clauseIdx);
01614
01615 addBreakCostToAtomsInClause(clauseIdx, -cost);
01616
01617 addMakeCost(toFlip, cost);
01618 }
01619 }
01620
01621
01622 else if (numTrueLits == 1)
01623 {
01624 if (watch1 == toFlip)
01625 {
01626 assert(watch1 != watch2);
01627 setWatch1(clauseIdx, watch2);
01628 watch1 = getWatch1(clauseIdx);
01629 }
01630
01631 if (cost >= 0)
01632 {
01633 addBreakCost(watch1, cost);
01634 }
01635
01636 else
01637 {
01638 assert(cost < 0);
01639 addMakeCost(watch1, -cost);
01640 }
01641 }
01642
01643
01644 else
01645 {
01646
01647
01648 if (watch1 == toFlip)
01649 {
01650
01651 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
01652 setWatch1(clauseIdx, diffTrueLit);
01653 }
01654
01655 else if (watch2 == toFlip)
01656 {
01657
01658 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
01659 setWatch2(clauseIdx, diffTrueLit);
01660 }
01661 }
01662 }
01663
01664
01665 litIdx = 2*toFlip - oppSign;
01666 Array<int>& negOccArray = getOccurenceArray(litIdx);
01667 for (int i = 0; i < negOccArray.size(); i++)
01668 {
01669 clauseIdx = negOccArray[i];
01670
01671 if (deadClause_[clauseIdx] || isSatisfied_[clauseIdx]) continue;
01672
01673
01674 int numTrueLits = incrementNumTrueLits(clauseIdx);
01675 long double cost = getClauseCost(clauseIdx);
01676 int watch1 = getWatch1(clauseIdx);
01677
01678
01679
01680 if (numTrueLits == 1)
01681 {
01682
01683 if (cost >= 0)
01684 {
01685
01686 removeFalseClause(clauseIdx);
01687
01688 addBreakCost(toFlip, cost);
01689
01690 addMakeCostToAtomsInClause(clauseIdx, -cost);
01691 }
01692
01693 else
01694 {
01695 assert(cost < 0);
01696
01697 addFalseClause(clauseIdx);
01698
01699 addBreakCostToAtomsInClause(clauseIdx, cost);
01700
01701 addMakeCost(toFlip, -cost);
01702 }
01703
01704 setWatch1(clauseIdx, toFlip);
01705 }
01706
01707
01708 else if (numTrueLits == 2)
01709 {
01710 if (cost >= 0)
01711 {
01712
01713
01714 addBreakCost(watch1, -cost);
01715 }
01716 else
01717 {
01718
01719 assert(cost < 0);
01720
01721 addMakeCost(watch1, cost);
01722 }
01723
01724 setWatch2(clauseIdx, toFlip);
01725 }
01726 }
01727 }
01728
01729
01736 bool isActive(const int& atomIdx)
01737 {
01738 return domain_->getDB()->getActiveStatus(gndPredHashArray_[atomIdx-1]);
01739 }
01740
01747 bool isActive(const Predicate* pred)
01748 {
01749 return domain_->getDB()->getActiveStatus(pred);
01750 }
01751
01755 Array<int>& getOccurenceArray(const int& idx)
01756 {
01757 return occurence_[idx];
01758 }
01759
01763 int incrementNumTrueLits(const int& clauseIdx)
01764 {
01765 return ++numTrueLits_[clauseIdx];
01766 }
01767
01771 int decrementNumTrueLits(const int& clauseIdx)
01772 {
01773 return --numTrueLits_[clauseIdx];
01774 }
01775
01779 int getNumTrueLits(const int& clauseIdx)
01780 {
01781 return numTrueLits_[clauseIdx];
01782 }
01783
01787 long double getClauseCost(const int& clauseIdx)
01788 {
01789 return clauseCost_[clauseIdx];
01790 }
01791
01795 Array<int>& getAtomsInClause(const int& clauseIdx)
01796 {
01797 return clause_[clauseIdx];
01798 }
01799
01803 void addFalseClause(const int& clauseIdx)
01804 {
01805 falseClause_[numFalseClauses_] = clauseIdx;
01806 whereFalse_[clauseIdx] = numFalseClauses_;
01807 numFalseClauses_++;
01808 costOfFalseClauses_ += abs(clauseCost_[clauseIdx]);
01809 }
01810
01814 void removeFalseClause(const int& clauseIdx)
01815 {
01816 numFalseClauses_--;
01817 falseClause_[whereFalse_[clauseIdx]] = falseClause_[numFalseClauses_];
01818 whereFalse_[falseClause_[numFalseClauses_]] = whereFalse_[clauseIdx];
01819 costOfFalseClauses_ -= abs(clauseCost_[clauseIdx]);
01820 }
01821
01825 void addBreakCost(const int& atomIdx, const long double& cost)
01826 {
01827 breakCost_[atomIdx] += cost;
01828 }
01829
01833 void subtractBreakCost(const int& atomIdx, const long double& cost)
01834 {
01835 breakCost_[atomIdx] -= cost;
01836 }
01837
01844 void addBreakCostToAtomsInClause(const int& clauseIdx,
01845 const long double& cost)
01846 {
01847 register int size = getClauseSize(clauseIdx);
01848 for (int i = 0; i < size; i++)
01849 {
01850 register int lit = clause_[clauseIdx][i];
01851 breakCost_[abs(lit)] += cost;
01852 }
01853 }
01854
01861 void subtractBreakCostFromAtomsInClause(const int& clauseIdx,
01862 const long double& cost)
01863 {
01864 register int size = getClauseSize(clauseIdx);
01865 for (int i = 0; i < size; i++)
01866 {
01867 register int lit = clause_[clauseIdx][i];
01868 breakCost_[abs(lit)] -= cost;
01869 }
01870 }
01871
01878 void addMakeCost(const int& atomIdx, const long double& cost)
01879 {
01880 makeCost_[atomIdx] += cost;
01881 }
01882
01889 void subtractMakeCost(const int& atomIdx, const long double& cost)
01890 {
01891 makeCost_[atomIdx] -= cost;
01892 }
01893
01900 void addMakeCostToAtomsInClause(const int& clauseIdx,
01901 const long double& cost)
01902 {
01903 register int size = getClauseSize(clauseIdx);
01904 for (int i = 0; i < size; i++)
01905 {
01906 register int lit = clause_[clauseIdx][i];
01907 makeCost_[abs(lit)] += cost;
01908 }
01909 }
01910
01917 void subtractMakeCostFromAtomsInClause(const int& clauseIdx,
01918 const long double& cost)
01919 {
01920 register int size = getClauseSize(clauseIdx);
01921 for (int i = 0; i < size; i++)
01922 {
01923 register int lit = clause_[clauseIdx][i];
01924 makeCost_[abs(lit)] -= cost;
01925 }
01926 }
01927
01937 const int getTrueLiteralOtherThan(const int& clauseIdx,
01938 const int& atomIdx1,
01939 const int& atomIdx2)
01940 {
01941 register int size = getClauseSize(clauseIdx);
01942 for (int i = 0; i < size; i++)
01943 {
01944 register int lit = clause_[clauseIdx][i];
01945 register int v = abs(lit);
01946 if (isTrueLiteral(lit) && v != atomIdx1 && v != atomIdx2)
01947 return v;
01948 }
01949
01950 assert(false);
01951 return -1;
01952 }
01953
01957 const bool isTrueLiteral(const int& literal)
01958 {
01959 return ((literal > 0) == atom_[abs(literal)]);
01960 }
01961
01965 const int getAtomInClause(const int& atomIdxInClause, const int& clauseIdx)
01966 {
01967 return clause_[clauseIdx][atomIdxInClause];
01968 }
01969
01973 const int getRandomAtomInClause(const int& clauseIdx)
01974 {
01975 return clause_[clauseIdx][random()%getClauseSize(clauseIdx)];
01976 }
01977
01984 const int getRandomTrueLitInClause(const int& clauseIdx)
01985 {
01986 assert(numTrueLits_[clauseIdx] > 0);
01987 int trueLit = random()%numTrueLits_[clauseIdx];
01988 int whichTrueLit = 0;
01989 for (int i = 0; i < getClauseSize(clauseIdx); i++)
01990 {
01991 int lit = clause_[clauseIdx][i];
01992 int atm = abs(lit);
01993
01994 if (isTrueLiteral(lit))
01995 if (trueLit == whichTrueLit++)
01996 return atm;
01997 }
01998
01999 assert(false);
02000 return -1;
02001 }
02002
02003 const double getMaxClauseWeight()
02004 {
02005 double maxWeight = 0.0;
02006 for (int i = 0; i < getNumClauses(); i++)
02007 {
02008 double weight = abs(clauseCost_[i]);
02009 if (weight > maxWeight) maxWeight = weight;
02010 }
02011 return maxWeight;
02012 }
02013
02014 const long double getMakeCost(const int& atomIdx)
02015 {
02016 return makeCost_[atomIdx];
02017 }
02018
02019 const long double getBreakCost(const int& atomIdx)
02020 {
02021 return breakCost_[atomIdx];
02022 }
02023
02024 const int getClauseSize(const int& clauseIdx)
02025 {
02026 return clause_[clauseIdx].size();
02027 }
02028
02029 const int getWatch1(const int& clauseIdx)
02030 {
02031 return watch1_[clauseIdx];
02032 }
02033
02034 void setWatch1(const int& clauseIdx, const int& atomIdx)
02035 {
02036 watch1_[clauseIdx] = atomIdx;
02037 }
02038
02039 const int getWatch2(const int& clauseIdx)
02040 {
02041 return watch2_[clauseIdx];
02042 }
02043
02044 void setWatch2(const int& clauseIdx, const int& atomIdx)
02045 {
02046 watch2_[clauseIdx] = atomIdx;
02047 }
02048
02053 const int getBlockIndex(const int& atomIdx)
02054 {
02055 const GroundPredicate* gndPred = (*gndPreds_)[atomIdx];
02056 return domain_->getBlock(gndPred);
02057 }
02058
02062 const long double getLowCost()
02063 {
02064 return lowCost_;
02065 }
02066
02070 const int getLowBad()
02071 {
02072 return lowBad_;
02073 }
02074
02079 void makeUnitCosts()
02080 {
02081 for (int i = 0; i < clauseCost_.size(); i++)
02082 {
02083 if (clauseCost_[i] >= 0) clauseCost_[i] = 1.0;
02084 else
02085 {
02086 assert(clauseCost_[i] < 0);
02087 clauseCost_[i] = -1.0;
02088 }
02089 }
02090 if (vsdebug) cout << "Made unit costs" << endl;
02091 initMakeBreakCostWatch();
02092 }
02093
02097 void saveLowState()
02098 {
02099 if (vsdebug) cout << "Saving low state: " << endl;
02100 for (int i = 1; i <= getNumAtoms(); i++)
02101 {
02102 lowAtom_[i] = atom_[i];
02103 if (vsdebug) cout << lowAtom_[i] << endl;
02104 }
02105 lowCost_ = costOfFalseClauses_;
02106 lowBad_ = numFalseClauses_;
02107 }
02108
02112 int getTrueFixedAtomInBlock(const int& blockIdx)
02113 {
02114 const Predicate* truePred = domain_->getTruePredInBlock(blockIdx);
02115 if (truePred)
02116 {
02117 if (vsdebug)
02118 {
02119 cout << "True pred in block " << blockIdx << ": ";
02120 truePred->printWithStrVar(cout, domain_);
02121 cout << endl;
02122 }
02123
02124 GroundPredicate* trueGndPred = new GroundPredicate((Predicate*)truePred);
02125
02126 int atomIdx = gndPredHashArray_.find(trueGndPred);
02127 delete trueGndPred;
02128 if (atomIdx > -1 && fixedAtom_[atomIdx + 1] > 0)
02129 return atomIdx;
02130 }
02131 return -1;
02132 }
02133
02134 const GroundPredicateHashArray* getGndPredHashArrayPtr() const
02135 {
02136 return &gndPredHashArray_;
02137 }
02138
02139 const GroundPredicateHashArray* getUnePreds() const
02140 {
02141 return unePreds_;
02142 }
02143
02144 const GroundPredicateHashArray* getKnePreds() const
02145 {
02146 return knePreds_;
02147 }
02148
02149 const Array<TruthValue>* getKnePredValues() const
02150 {
02151 return knePredValues_;
02152 }
02153
02157 void setGndClausesWtsToSumOfParentWts()
02158 {
02159 for (int i = 0; i < gndClauses_->size(); i++)
02160 {
02161 GroundClause* gndClause = (*gndClauses_)[i];
02162 gndClause->setWtToSumOfParentWts(mln_);
02163 if (gndClause->isHardClause())
02164 clauseCost_[i] = hardWt_;
02165 else
02166 clauseCost_[i] = gndClause->getWt();
02167
02168 if (vsdebug) cout << "Setting cost of clause " << i << " to "
02169 << clauseCost_[i] << endl;
02170
02171
02172 if (gndClause->isHardClause()) threshold_[i] = RAND_MAX;
02173 else
02174 {
02175 double w = gndClause->getWt();
02176 threshold_[i] = RAND_MAX*(1 - exp(-abs(w)));
02177 if (vsdebug)
02178 {
02179 cout << "Weight: " << w << endl;
02180 }
02181 }
02182 if (vsdebug)
02183 cout << "Threshold: " << threshold_[i] << endl;
02184 }
02185 }
02186
02197 void getNumClauseGndings(Array<double>* const & numGndings, bool tv)
02198 {
02199
02200
02201
02202
02203
02204
02205
02206
02207
02208 Array<double> lazyFalseGndings(numGndings->size(), 0);
02209 Array<double> lazyTrueGndings(numGndings->size(), 0);
02210
02211 IntBoolPairItr itr;
02212 IntBoolPair *clauseFrequencies;
02213
02214
02215 int clauseCnt = numGndings->size();
02216 assert(clauseCnt == mln_->getNumClauses());
02217 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
02218 assert ((*numGndings)[clauseno] >= 0);
02219
02220 for (int i = 0; i < gndClauses_->size(); i++)
02221 {
02222 GroundClause *gndClause = (*gndClauses_)[i];
02223 int satLitcnt = 0;
02224 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
02225 {
02226 int lit = gndClause->getGroundPredicateIndex(j);
02227 if (isTrueLiteral(lit)) satLitcnt++;
02228 }
02229
02230 clauseFrequencies = gndClause->getClauseFrequencies();
02231 for (itr = clauseFrequencies->begin();
02232 itr != clauseFrequencies->end(); itr++)
02233 {
02234 int clauseno = itr->first;
02235 int frequency = itr->second.first;
02236 bool invertWt = itr->second.second;
02237
02238 if (invertWt)
02239 {
02240
02241 if (tv && satLitcnt > 0)
02242 {
02243
02244 if (lazy_) lazyFalseGndings[clauseno] += frequency;
02245 continue;
02246 }
02247
02248 if (!tv && satLitcnt == 0)
02249 {
02250
02251 if (lazy_) lazyTrueGndings[clauseno] += frequency;
02252 continue;
02253 }
02254 }
02255 else
02256 {
02257
02258 if (tv && satLitcnt == 0)
02259 {
02260
02261 if (lazy_) lazyFalseGndings[clauseno] += frequency;
02262 continue;
02263 }
02264
02265 if (!tv && satLitcnt > 0)
02266 {
02267
02268 if (lazy_) lazyTrueGndings[clauseno] += frequency;
02269 continue;
02270 }
02271 }
02272 (*numGndings)[clauseno] += frequency;
02273 }
02274 }
02275
02276
02277
02278
02279 if (lazy_)
02280 {
02281 for (int c = 0; c < mln_->getNumClauses(); c++)
02282 {
02283 const Clause* clause = mln_->getClause(c);
02284
02285 if (tv && clause->getWt() >= 0)
02286 {
02287 double totalGndings = domain_->getNumNonEvidGroundings(c);
02288 assert(totalGndings >= (*numGndings)[c] + lazyFalseGndings[c]);
02289 double remainingTrueGndings = totalGndings - lazyFalseGndings[c] -
02290 (*numGndings)[c];
02291 (*numGndings)[c] += remainingTrueGndings;
02292 }
02293
02294 else if (!tv && clause->getWt() < 0)
02295 {
02296 double totalGndings = domain_->getNumNonEvidGroundings(c);
02297 assert(totalGndings >= (*numGndings)[c] + lazyTrueGndings[c]);
02298 double remainingFalseGndings = totalGndings - lazyTrueGndings[c] -
02299 (*numGndings)[c];
02300 (*numGndings)[c] += remainingFalseGndings;
02301 }
02302 }
02303 }
02304
02305 }
02306
02318 void getNumClauseGndingsWithUnknown(double numGndings[], int clauseCnt,
02319 bool tv,
02320 const Array<bool>* const& unknownPred)
02321 {
02322
02323 assert(unknownPred->size() == getNumAtoms());
02324 IntBoolPairItr itr;
02325 IntBoolPair *clauseFrequencies;
02326
02327 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
02328 numGndings[clauseno] = 0;
02329
02330 for (int i = 0; i < gndClauses_->size(); i++)
02331 {
02332 GroundClause *gndClause = (*gndClauses_)[i];
02333 int satLitcnt = 0;
02334 bool unknown = false;
02335 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
02336 {
02337 int lit = gndClause->getGroundPredicateIndex(j);
02338 if ((*unknownPred)[abs(lit) - 1])
02339 {
02340 unknown = true;
02341 continue;
02342 }
02343 if (isTrueLiteral(lit)) satLitcnt++;
02344 }
02345
02346 clauseFrequencies = gndClause->getClauseFrequencies();
02347 for (itr = clauseFrequencies->begin();
02348 itr != clauseFrequencies->end(); itr++)
02349 {
02350 int clauseno = itr->first;
02351 int frequency = itr->second.first;
02352 bool invertWt = itr->second.second;
02353
02354 if (invertWt)
02355 {
02356
02357 if (tv && (satLitcnt > 0 || unknown))
02358 continue;
02359
02360 if (!tv && satLitcnt == 0)
02361 continue;
02362 }
02363 else
02364 {
02365
02366 if (tv && satLitcnt == 0)
02367 continue;
02368
02369 if (!tv && (satLitcnt > 0 || unknown))
02370 continue;
02371 }
02372 numGndings[clauseno] += frequency;
02373 }
02374 }
02375 }
02376
02384 void setOthersInBlockToFalse(const int& atomIdx, const int& blockIdx)
02385 {
02386 if (vsdebug)
02387 {
02388 cout << "Set all in block " << blockIdx << " to false except "
02389 << atomIdx << endl;
02390 }
02391 int blockSize = domain_->getBlockSize(blockIdx);
02392 for (int i = 0; i < blockSize; i++)
02393 {
02394 const Predicate* pred = domain_->getPredInBlock(i, blockIdx);
02395 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
02396 int idx = gndPredHashArray_.find(gndPred);
02397
02398 if (vsdebug)
02399 {
02400 cout << "Gnd pred in block ";
02401 gndPred->print(cout, domain_);
02402 cout << " (idx " << idx << ")" << endl;
02403 }
02404
02405 delete gndPred;
02406 delete pred;
02407
02408
02409 if (idx >= 0)
02410 {
02411
02412 if (idx != atomIdx && fixedAtom_[idx + 1] == 0)
02413 {
02414 if (vsdebug)
02415 cout << "Set " << idx + 1 << " to false" << endl;
02416
02417 bool activate = true;
02418 setValueOfAtom(idx + 1, false, activate, -1);
02419 }
02420 }
02421 }
02422 }
02423
02424
02433 void fixAtom(const int& atomIdx, const bool& value)
02434 {
02435 assert(atomIdx > 0);
02436 if (vsdebug)
02437 {
02438 cout << "Fixing ";
02439 (*gndPreds_)[atomIdx - 1]->print(cout, domain_);
02440 cout << " to " << value << endl;
02441 }
02442
02443
02444 if (!useThreshold_)
02445 {
02446 cout << ">>> [ERR] useThreshold_ is false" << endl;
02447 exit(0);
02448 }
02449
02450
02451 if ((fixedAtom_[atomIdx] == 1 && value == false) ||
02452 (fixedAtom_[atomIdx] == -1 && value == true))
02453 {
02454 cout << "Contradiction: Tried to fix atom " << atomIdx <<
02455 " to true and false ... exiting." << endl;
02456 exit(0);
02457 }
02458
02459
02460 if (fixedAtom_[atomIdx] != 0) return;
02461
02462 fixedAtom_[atomIdx] = (value) ? 1 : -1;
02463 if (atom_[atomIdx] != value)
02464 {
02465 bool activate = false;
02466 int blockIdx = getBlockIndex(atomIdx - 1);
02467 setValueOfAtom(atomIdx, value, activate, blockIdx);
02468 updateMakeBreakCostAfterFlip(atomIdx);
02469 }
02470
02471
02472 updateSatisfiedClauses(atomIdx);
02473
02474
02475 if (value)
02476 {
02477 int blockIdx = getBlockIndex(atomIdx - 1);
02478 if (blockIdx > -1)
02479 {
02480 int blockSize = domain_->getBlockSize(blockIdx);
02481 for (int i = 0; i < blockSize; i++)
02482 {
02483 const Predicate* pred = domain_->getPredInBlock(i, blockIdx);
02484 GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
02485 int idx = gndPredHashArray_.find(gndPred);
02486 delete gndPred;
02487 delete pred;
02488
02489
02490 if (idx >= 0)
02491 {
02492
02493 if (idx != (atomIdx - 1))
02494 {
02495
02496 if (fixedAtom_[idx + 1] == 1)
02497 {
02498 cout << "Contradiction: Tried to fix atom " << idx + 1 <<
02499 " to true and false ... exiting." << endl;
02500 exit(0);
02501 }
02502
02503 if (fixedAtom_[idx + 1] == -1) continue;
02504 if (vsdebug)
02505 {
02506 cout << "Fixing ";
02507 (*gndPreds_)[idx]->print(cout, domain_);
02508 cout << " to 0" << endl;
02509 }
02510 fixedAtom_[idx + 1] = -1;
02511 if (atom_[idx + 1] != false)
02512 {
02513 bool activate = false;
02514 setValueOfAtom(idx + 1, value, activate, blockIdx);
02515 updateMakeBreakCostAfterFlip(idx + 1);
02516 }
02517
02518 updateSatisfiedClauses(idx + 1);
02519 }
02520 }
02521 }
02522 }
02523 }
02524
02525
02526
02527 if (lazy_ && !isActive(atomIdx) && value)
02528 {
02529
02530
02531 Predicate* p =
02532 gndPredHashArray_[atomIdx - 1]->createEquivalentPredicate(domain_);
02533
02534 bool ignoreActivePreds = false;
02535 Array<GroundClause*> unsatClauses;
02536 getActiveClauses(p, unsatClauses, true, ignoreActivePreds);
02537
02538
02539
02540 addNewClauses(ADD_CLAUSE_REGULAR, unsatClauses);
02541
02542
02543 domain_->getDB()->setActiveStatus(p, true);
02544 activeAtoms_++;
02545
02546 delete p;
02547 }
02548 }
02549
02561 Array<int>* simplifyClauseFromFixedAtoms(const int& clauseIdx)
02562 {
02563 Array<int>* returnArray = new Array<int>;
02564
02565 if (isSatisfied_[clauseIdx]) return returnArray;
02566
02567
02568
02569 bool isGood = (clauseCost_[clauseIdx] >= 0) ? false : true;
02570
02571 bool allFalseAtoms = (clauseCost_[clauseIdx] >= 0) ? true : false;
02572
02573 for (int i = 0; i < getClauseSize(clauseIdx); i++)
02574 {
02575 int lit = clause_[clauseIdx][i];
02576 int fixedValue = fixedAtom_[abs(lit)];
02577
02578 if (clauseCost_[clauseIdx] >= 0)
02579 {
02580 if ((fixedValue == 1 && lit > 0) ||
02581 (fixedValue == -1 && lit < 0))
02582 {
02583 isGood = true;
02584 allFalseAtoms = false;
02585 returnArray->clear();
02586 break;
02587 }
02588 else if (fixedValue == 0)
02589 {
02590 allFalseAtoms = false;
02591 returnArray->append(lit);
02592 }
02593 }
02594 else
02595 {
02596 assert(clauseCost_[clauseIdx] < 0);
02597 if ((fixedValue == 1 && lit > 0) ||
02598 (fixedValue == -1 && lit < 0))
02599 {
02600 cout << "Contradiction: Tried to fix atom " << abs(lit) <<
02601 " to true in a negative clause ... exiting." << endl;
02602 exit(0);
02603 }
02604 else
02605 {
02606 returnArray->append(lit);
02607
02608 if (fixedValue == 0) isGood = false;
02609 }
02610 }
02611 }
02612 if (allFalseAtoms)
02613 {
02614 cout << "Contradiction: All atoms in clause " << clauseIdx <<
02615 " fixed to false ... exiting." << endl;
02616 exit(0);
02617 }
02618 if (isGood) isSatisfied_[clauseIdx] = true;
02619 return returnArray;
02620 }
02621
02628 const bool isDeadClause(const int& clauseIdx)
02629 {
02630 return deadClause_[clauseIdx];
02631 }
02632
02636 void eliminateSoftClauses()
02637 {
02638 bool atLeastOneDead = false;
02639 for (int i = 0; i < getNumClauses(); i++)
02640 {
02641 if (!(*gndClauses_)[i]->isHardClause())
02642 {
02643 atLeastOneDead = true;
02644 deadClause_[i] = true;
02645 }
02646 }
02647
02648 if (atLeastOneDead) initMakeBreakCostWatch();
02649 }
02650
02651 void LoadDisEviValuesFromRst(const char* szDisEvi)
02652 {
02653 for(int i = 1; i < atomEvi_.size(); i++)
02654 {
02655 atomEvi_[i] = false;
02656 }
02657 map<string, int> gndPredCont;
02658 for(int i = 0; i < gndPreds_->size(); i++)
02659 {
02660 string str = (*gndPreds_)[i]->getPredicateStr(domain_);
02661 gndPredCont.insert(map<string, int>::value_type(str, i+1));
02662 }
02663 ifstream is(szDisEvi);
02664 string strLine;
02665 while (getline(is, strLine))
02666 {
02667 stringstream ss(strLine);
02668 string strtmp;
02669 getline(ss,strtmp, ' ');
02670
02671
02672 map<string, int>::const_iterator citer;
02673 citer = gndPredCont.find(strtmp);
02674 if (citer == gndPredCont.end())
02675 {
02676 cout << "dis evi file error, non-existent query gndings" << endl;
02677 cout << strLine << endl;
02678 exit(1);
02679 }
02680 int atomIdx = citer->second;
02681
02682
02683 getline(ss, strtmp);
02684 int b = atoi(strtmp.c_str());
02685 if (b==1)
02686 {
02687 atomEvi_[atomIdx] = true;
02688 }
02689 else
02690 {
02691 atomEvi_[atomIdx] = false;
02692 }
02693 }
02694 }
02695
02703 void killClauses(const int& startClause)
02704 {
02705
02706 if (inferenceMode_ != MODE_HARD)
02707 {
02708 for (int i = startClause; i < getNumClauses(); i++)
02709 {
02710 GroundClause* clause = (*gndClauses_)[i];
02711 if ((clauseGoodInPrevious(i)) &&
02712 (clause->isHardClause() || random() <= threshold_[i]))
02713 {
02714 if (vsdebug)
02715 {
02716 cout << "Keeping clause "<< i << " ";
02717 clause->print(cout, domain_, &gndPredHashArray_);
02718 cout << endl;
02719 }
02720 deadClause_[i] = false;
02721 }
02722 else
02723 {
02724 deadClause_[i] = true;
02725 }
02726 }
02727 }
02728
02729 initMakeBreakCostWatch(startClause);
02730 }
02731
02732
02740 const bool clauseGoodInPrevious(const int& clauseIdx)
02741 {
02742
02743 return (clauseIdx >= prevSatisfiedClause_.size() ||
02744 prevSatisfiedClause_[clauseIdx]);
02745 }
02746
02750 void resetDeadClauses()
02751 {
02752 for (int i = 0; i < deadClause_.size(); i++)
02753 deadClause_[i] = false;
02754 initMakeBreakCostWatch();
02755 }
02756
02760 void resetFixedAtoms()
02761 {
02762 for (int i = 0; i < fixedAtom_.size(); i++)
02763 fixedAtom_[i] = 0;
02764 for (int i = 0; i < isSatisfied_.size(); i++)
02765 isSatisfied_[i] = false;
02766 }
02767
02768 void setLazy(const bool& l) { lazy_ = l; }
02769 const bool getLazy() { return lazy_; }
02770
02771 void setUseThreshold(const bool& t) { useThreshold_ = t;}
02772 const bool getUseThreshold() { return useThreshold_; }
02773
02774 long double getHardWt() { return hardWt_; }
02775
02776 const Domain* getDomain() { return domain_; }
02777
02778 const MLN* getMLN() { return mln_; }
02779
02785 void printLowState(ostream& out)
02786 {
02787 for (int i = 0; i < getNumAtoms(); i++)
02788 {
02789 (*gndPreds_)[i]->print(out, domain_);
02790 out << " " << lowAtom_[i + 1] << endl;
02791 }
02792 }
02793
02800 void printGndPred(const int& predIndex, ostream& out)
02801 {
02802 (*gndPreds_)[predIndex]->print(out, domain_);
02803 }
02804
02811 int getIndexOfGroundPredicate(GroundPredicate* const & gndPred)
02812 {
02813 return gndPredHashArray_.find(gndPred);
02814 }
02815
02826 void setAsEvidence(const GroundPredicate* const & predicate,
02827 const bool& trueEvidence)
02828 {
02829 if (vsdebug)
02830 {
02831 cout << "Setting to evidence " ;
02832 predicate->print(cout, domain_);
02833 cout << endl;
02834 }
02835 Database* db = domain_->getDB();
02836 int atomIdx = gndPredHashArray_.find((GroundPredicate*)predicate);
02837
02838 if (atomIdx <= 0)
02839 {
02840
02841 if (db->getValue(predicate) == trueEvidence)
02842 return;
02843
02844
02845 if (trueEvidence)
02846 db->setValue(predicate, TRUE);
02847 else
02848 db->setValue(predicate, FALSE);
02849 }
02850 else
02851 {
02852 Array<int> gndClauseIndexes;
02853 int deleted;
02854 gndClauseIndexes = getNegOccurenceArray(atomIdx + 1);
02855 gndClauseIndexes.bubbleSort();
02856
02857
02858 deleted = 0;
02859 for (int i = 0; i < gndClauseIndexes.size(); i++)
02860 {
02861
02862
02863
02864 if (!trueEvidence ||
02865 (*gndClauses_)[gndClauseIndexes[i]]->getNumGroundPredicates() == 1)
02866 {
02867 if (vsdebug)
02868 cout << "Deleting ground clause " << gndClauseIndexes[i] << endl;
02869
02870
02871 delete (*gndClauses_)[gndClauseIndexes[i] - deleted];
02872 gndClauses_->removeItem(gndClauseIndexes[i] - deleted);
02873 deleted++;
02874 }
02875 else
02876 {
02877 if (vsdebug)
02878 {
02879 cout << "Removing gnd pred " << -(atomIdx + 1)
02880 << " from ground clause " << gndClauseIndexes[i] << endl;
02881 }
02882 (*gndClauses_)[gndClauseIndexes[i]]->removeGndPred(-(atomIdx + 1));
02883 }
02884 }
02885
02886 gndClauseIndexes = getPosOccurenceArray(atomIdx + 1);
02887 gndClauseIndexes.bubbleSort();
02888
02889
02890 deleted = 0;
02891 for (int i = 0; i < gndClauseIndexes.size(); i++)
02892 {
02893
02894
02895
02896 if (trueEvidence ||
02897 (*gndClauses_)[gndClauseIndexes[i]]->getNumGroundPredicates() == 1)
02898 {
02899 if (vsdebug)
02900 cout << "Deleting ground clause " << gndClauseIndexes[i] << endl;
02901
02902
02903 delete (*gndClauses_)[gndClauseIndexes[i] - deleted];
02904 gndClauses_->removeItem(gndClauseIndexes[i] - deleted);
02905 deleted++;
02906 }
02907 else
02908 {
02909 if (vsdebug)
02910 {
02911 cout << "Removing gnd pred " << -(atomIdx + 1)
02912 << " from ground clause " << gndClauseIndexes[i] << endl;
02913 }
02914 (*gndClauses_)[gndClauseIndexes[i]]->removeGndPred(atomIdx + 1);
02915 }
02916 }
02917
02918 gndPredHashArray_.removeItemFastDisorder(atomIdx);
02919 gndPredHashArray_.compress();
02920 gndPreds_->removeItemFastDisorder(atomIdx);
02921 gndPreds_->compress();
02922
02923
02924
02925 int oldIdx = gndPredHashArray_.size();
02926 replaceAtomIndexInAllClauses(oldIdx, atomIdx);
02927 }
02928 }
02929
02938 void setAsQuery(const GroundPredicate* const & predicate)
02939 {
02940 if (vsdebug)
02941 {
02942 cout << "Setting to query " ;
02943 predicate->print(cout, domain_);
02944 cout << endl;
02945 }
02946 Database* db = domain_->getDB();
02947
02948 if (gndPredHashArray_.contains((GroundPredicate*)predicate))
02949 return;
02950 else
02951 {
02952
02953
02954 gndPredHashArray_.append((GroundPredicate*)predicate);
02955 Predicate* p = predicate->createEquivalentPredicate(domain_);
02956 db->setEvidenceStatus(p, false);
02957 bool ignoreActivePreds = true;
02958 getActiveClauses(p, newClauses_, true, ignoreActivePreds);
02959 }
02960 }
02961
02963
02970 GroundPredicate* getGndPred(const int& index)
02971 {
02972 return (*gndPreds_)[index];
02973 }
02974
02981 GroundClause* getGndClause(const int& index)
02982 {
02983 return (*gndClauses_)[index];
02984 }
02985
02989 void saveLowStateToGndPreds()
02990 {
02991 for (int i = 0; i < getNumAtoms(); i++)
02992 (*gndPreds_)[i]->setTruthValue(lowAtom_[i + 1]);
02993 }
02994
02998 void saveLowStateToDB()
02999 {
03000 for (int i = 0; i < getNumAtoms(); i++)
03001 {
03002 GroundPredicate* p = gndPredHashArray_[i];
03003 bool value = lowAtom_[i + 1];
03004 if (value)
03005 {
03006 domain_->getDB()->setValue(p, TRUE);
03007 }
03008 else
03009 {
03010 domain_->getDB()->setValue(p, FALSE);
03011 }
03012 }
03013 }
03014
03021 const int getGndPredIndex(GroundPredicate* const& gndPred)
03022 {
03023 return gndPreds_->find(gndPred);
03024 }
03025
03026
03028
03029
03031
03045 void getActiveClauses(Predicate *inputPred,
03046 Array<GroundClause*>& activeClauses,
03047 bool const & active,
03048 bool const & ignoreActivePreds)
03049 {
03050 Timer timer;
03051 double currTime;
03052
03053 Clause *fclause;
03054 GroundClause* newClause;
03055 int clauseCnt;
03056 GroundClauseHashArray clauseHashArray;
03057
03058 Array<GroundClause*>* newClauses = new Array<GroundClause*>;
03059
03060 const Array<IndexClause*>* indexClauses = NULL;
03061
03062
03063 if (inputPred == NULL)
03064 {
03065 clauseCnt = mln_->getNumClauses();
03066 }
03067
03068 else
03069 {
03070 if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
03071 int predId = inputPred->getId();
03072 indexClauses = mln_->getClausesContainingPred(predId);
03073 clauseCnt = indexClauses->size();
03074 }
03075
03076
03077 int clauseno = 0;
03078
03079 while (clauseno < clauseCnt)
03080 {
03081 if (inputPred)
03082 fclause = (Clause *) (*indexClauses)[clauseno]->clause;
03083 else
03084 fclause = (Clause *) mln_->getClause(clauseno);
03085
03086 if (vsdebug)
03087 {
03088 cout << "Getting active clauses for FO clause: ";
03089 fclause->print(cout, domain_);
03090 cout << endl;
03091 }
03092
03093 currTime = timer.time();
03094
03095 const double* parentWtPtr = NULL;
03096 if (!fclause->isHardClause()) parentWtPtr = fclause->getWtPtr();
03097 const int clauseId = mln_->findClauseIdx(fclause);
03098 newClauses->clear();
03099
03100 if (stillActivating_)
03101 stillActivating_ = fclause->getActiveClauses(inputPred, domain_,
03102 newClauses,
03103 &gndPredHashArray_,
03104 ignoreActivePreds);
03105
03106 for (int i = 0; i < newClauses->size(); i++)
03107 {
03108 long double wt = fclause->getWt();
03109 newClause = (*newClauses)[i];
03110
03111
03112 if (gndClauses_->find(newClause) >= 0)
03113 {
03114 delete newClause;
03115 continue;
03116 }
03117
03118 bool invertWt = false;
03119
03120 if (!fclause->isHardClause() &&
03121 newClause->getNumGroundPredicates() == 1 &&
03122 !newClause->getGroundPredicateSense(0))
03123 {
03124 newClause->setGroundPredicateSense(0, true);
03125 newClause->setWt(-newClause->getWt());
03126 wt = -wt;
03127 invertWt = true;
03128 int addToIndex = gndClauses_->find(newClause);
03129 if (addToIndex >= 0)
03130 {
03131 (*gndClauses_)[addToIndex]->addWt(wt);
03132 if (parentWtPtr)
03133 (*gndClauses_)[addToIndex]->incrementClauseFrequency(clauseId, 1,
03134 invertWt);
03135 delete newClause;
03136 continue;
03137 }
03138 }
03139
03140 int pos = clauseHashArray.find(newClause);
03141
03142 if (pos >= 0)
03143 {
03144
03145
03146 if (clauseHashArray[pos]->getClauseFrequency(clauseId) > 0)
03147 {
03148 delete newClause;
03149 continue;
03150 }
03151 if (vsdebug)
03152 {
03153 cout << "Adding weight " << wt << " to clause ";
03154 clauseHashArray[pos]->print(cout, domain_, &gndPredHashArray_);
03155 cout << endl;
03156 }
03157 clauseHashArray[pos]->addWt(wt);
03158 if (parentWtPtr)
03159 clauseHashArray[pos]->incrementClauseFrequency(clauseId, 1,
03160 invertWt);
03161 delete newClause;
03162 continue;
03163 }
03164
03165
03166 newClause->setWt(wt);
03167 if (parentWtPtr)
03168 newClause->incrementClauseFrequency(clauseId, 1, invertWt);
03169
03170 if (vsdebug)
03171 {
03172 cout << "Appending clause ";
03173 newClause->print(cout, domain_, &gndPredHashArray_);
03174 cout << endl;
03175 }
03176 clauseHashArray.append(newClause);
03177 }
03178 clauseno++;
03179
03180 }
03181
03182 for (int i = 0; i < clauseHashArray.size(); i++)
03183 {
03184 newClause = clauseHashArray[i];
03185 activeClauses.append(newClause);
03186 }
03187
03188 newClauses->clear();
03189 delete newClauses;
03190
03191 clauseHashArray.clear();
03192 }
03193
03201 void getActiveClauses(Array<GroundClause*> &allClauses,
03202 bool const & ignoreActivePreds)
03203 {
03204 getActiveClauses(NULL, allClauses, true, ignoreActivePreds);
03205 }
03206
03207 int getNumActiveAtoms()
03208 {
03209 return activeAtoms_;
03210 }
03211
03212 const void setBreakHardClauses(const bool& breakHardClauses)
03213 {
03214 breakHardClauses_ = breakHardClauses;
03215 }
03217
03218
03219 private:
03220
03226 void setHardClauseWeight()
03227 {
03228
03229 long double sumSoftWts = 0.0;
03230
03231 int clauseCnt = mln_->getNumClauses();
03232
03233 for (int i = 0; i < clauseCnt; i++)
03234 {
03235 Clause* fclause = (Clause *) mln_->getClause(i);
03236
03237 if (fclause->isHardClause()) continue;
03238
03239 long double wt = abs(fclause->getWt());
03240 long double numGndings = fclause->getNumGroundings(domain_);
03241 sumSoftWts += wt*numGndings;
03242 }
03243 assert(sumSoftWts >= 0);
03244
03245 hardWt_ = sumSoftWts + 10.0;
03246 cout << "Set hard weight to " << hardWt_ << endl;
03247 }
03248
03256 void replaceAtomIndexInAllClauses(const int& oldIdx, const int& newIdx)
03257 {
03258 Array<int> gndClauseIndexes;
03259
03260 gndClauseIndexes = getNegOccurenceArray(oldIdx + 1);
03261 for (int i = 0; i < gndClauseIndexes.size(); i++)
03262 {
03263
03264 if ((*gndClauses_)[gndClauseIndexes[i]])
03265 (*gndClauses_)[gndClauseIndexes[i]]->changeGndPredIndex(-(oldIdx + 1),
03266 -(newIdx + 1));
03267 }
03268
03269 gndClauseIndexes = getPosOccurenceArray(oldIdx + 1);
03270 for (int i = 0; i < gndClauseIndexes.size(); i++)
03271 {
03272
03273 if ((*gndClauses_)[gndClauseIndexes[i]])
03274 (*gndClauses_)[gndClauseIndexes[i]]->changeGndPredIndex(oldIdx + 1,
03275 newIdx + 1);
03276 }
03277 }
03278
03279 public:
03280
03281
// Inference modes; the current one is stored in inferenceMode_
// (the constructor starts in MODE_MWS).
static const int MODE_MWS       = 0;
static const int MODE_HARD      = 1;
static const int MODE_SAMPLESAT = 2;

// Clause-addition codes. NOTE(review): presumably passed to a clause-adding
// routine elsewhere in this class to say how a clause is added — confirm.
static const int ADD_CLAUSE_INITIAL = 0;
static const int ADD_CLAUSE_REGULAR = 1;
static const int ADD_CLAUSE_DEAD    = 2;
static const int ADD_CLAUSE_SAT     = 3;
03290
03291 private:
03292
// Copy of the constructor's `lazy` argument; when true the constructor takes
// its lazy-inference initialization path.
bool lazy_;

// Weight given to hard clauses. Set by setHardClauseWeight() to the total
// |weight| * #groundings of all soft clauses plus 10.
long double hardWt_;

// MLN and domain passed to the constructor (const-ness cast away there).
// NOTE(review): ownership is not visible here — presumably not owned.
MLN* mln_;
Domain* domain_;

// Ground predicates; allocated with new in the constructor.
Array<GroundPredicate*>* gndPreds_;

// Ground clauses; allocated with new in the constructor. Searched with find()
// so that a duplicate clause's weight is added to the existing entry.
GroundClauseHashArray* gndClauses_;

// Presumably the unknown and known non-evidence (query) predicates and the
// known predicates' truth values, cf. the constructor's unknownQueries /
// knownQueries / knownQueryValues parameters — confirm.
GroundPredicateHashArray* unePreds_;

GroundPredicateHashArray* knePreds_;

Array<TruthValue>* knePredValues_;

// Initialized to 0 by the constructor. NOTE(review): meaning not visible here.
int baseNumAtoms_;

// Set to false in the constructor's lazy branch.
bool noApprox_;

// NOTE(review): purpose not visible in this section.
bool haveDeactivated_;

// NOTE(review): purpose not visible in this section; presumably a memory cap.
int memLimit_;

// Set to INT_MAX in the constructor's lazy branch; presumably a cap on the
// number of clauses — confirm.
int clauseLimit_;

// NOTE(review): presumably the MRF used for eager (non-lazy) inference — confirm.
MRF* mrf_;

// Presumably clauses / predicates created since the last update — confirm.
Array<GroundClause*> newClauses_;

Array<GroundPredicate*> newPreds_;

// Ground predicates; passed to GroundClause::print() in debug output so that
// clause literals can be printed.
GroundPredicateHashArray gndPredHashArray_;

// ---- Per-clause search state. NOTE(review): presumably indexed by clause
// number and used by the local-search (MODE_MWS / MODE_SAMPLESAT) solver — confirm.

// Presumably the literals of each clause.
Array<Array<int> > clause_;

// Presumably the cost of each clause.
Array<long double> clauseCost_;

// Presumably the highest clause cost and bookkeeping for clauses that share it.
long double highestCost_;

bool eqHighest_;

int numHighest_;

// Presumably the list of currently false clauses and each clause's position
// in that list.
Array<int> falseClause_;

Array<int> whereFalse_;

// Presumably the number of true literals per clause and two watched literals.
Array<int> numTrueLits_;

Array<int> watch1_;

Array<int> watch2_;

// Presumably per-clause satisfied / dead flags.
Array<bool> isSatisfied_;

Array<bool> deadClause_;

// Presumably whether per-clause thresholds are used, and their values.
bool useThreshold_;

Array<long double> threshold_;

// Occurrence lists per literal. NOTE(review): the getPos/NegOccurenceArray()
// accessors presumably read from here — confirm.
Array<Array<int> > occurence_;

// ---- Per-atom state. NOTE(review): presumably indexed by atom — confirm.

// Presumably the current truth assignment and evidence flags.
Array<bool> atom_;
Array<bool> atomEvi_;

// Presumably the cost change from flipping each atom.
Array<long double> makeCost_;

Array<long double> breakCost_;

// NOTE(review): presumably marks atoms fixed to a value — confirm.
Array<int> fixedAtom_;

// Presumably the best (lowest-cost) assignment found so far. The constructor
// initializes lowCost_ to LDBL_MAX and lowBad_ to INT_MAX.
Array<bool> lowAtom_;

long double lowCost_;

int lowBad_;

// Count and total cost of false clauses; the constructor initializes both to 0.
int numFalseClauses_;

long double costOfFalseClauses_;

// Initialized to 0 by the constructor; returned by getNumActiveAtoms().
int activeAtoms_;

// NOTE(review): presumably each clause's previous satisfied state — confirm.
Array<bool> prevSatisfiedClause_;

// In lazy mode, set from domain_->getNumNonEvidenceAtoms() by the constructor.
int numNonEvAtoms_;

// One of the MODE_* constants; the constructor sets MODE_MWS.
int inferenceMode_;

// Set to true by the constructor.
bool stillActivating_;

// Set to false by the constructor; changed through setBreakHardClauses().
bool breakHardClauses_;
03426
03427 };
03428
03429 #endif