00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef VARIABLESTATE_H_
00067 #define VARIABLESTATE_H_
00068
00069 #include "mrf.h"
00070
// Sentinel returned by the random-selection helpers in this header when
// there is nothing to choose from (no atoms, or no false clauses).
00071 const int NOVALUE = -1;
// Switch for the diagnostic output this header writes to cout; every debug
// print is guarded by it.
00072 const bool vsdebug = false;
00073
00093 class VariableState
00094 {
00095 public:
00096
00115 VariableState(GroundPredicateHashArray* const& unknownQueries,
00116 GroundPredicateHashArray* const& knownQueries,
00117 Array<TruthValue>* const & knownQueryValues,
00118 const Array<int>* const & allPredGndingsAreQueries,
00119 const bool& markHardGndClauses,
00120 const bool& trackParentClauseWts,
00121 const MLN* const & mln, const Domain* const & domain,
00122 const bool& lazy)
00123 {
00124 this->mln_ = (MLN*)mln;
00125 this->domain_ = (Domain*)domain;
00126 this->lazy_ = lazy;
00127
00128
00129 baseNumAtoms_ = 0;
00130 activeAtoms_ = 0;
00131 numFalseClauses_ = 0;
00132 costOfFalseClauses_ = 0.0;
00133 lowCost_ = LDBL_MAX;
00134 lowBad_ = INT_MAX;
00135
00136
00137 gndClauses_ = new Array<GroundClause*>;
00138 gndPreds_ = new Array<GroundPredicate*>;
00139
00140
00141 setHardClauseWeight();
00142
00143
00144 if (lazy_)
00145 {
00146
00147 domain_->getDB()->setPerformingInference(true);
00148
00149
00150 initLazyBlocks();
00151
00152 clauseLimit_ = INT_MAX;
00153 noApprox_ = false;
00154 haveDeactivated_ = false;
00155
00157
00158
00159
00160
00161 addOneAtomToEachBlock();
00162
00163
00164
00165
00166
00167 bool ignoreActivePreds = false;
00168 getActiveClauses(newClauses_, ignoreActivePreds);
00169 int defaultCnt = newClauses_.size();
00170 long double defaultCost = 0;
00171
00172 for (int i = 0; i < defaultCnt; i++)
00173 {
00174 if (newClauses_[i]->isHardClause())
00175 defaultCost += hardWt_;
00176 else
00177 defaultCost += abs(newClauses_[i]->getWt());
00178 }
00179
00180
00181 for (int i = 0; i < gndPredHashArray_.size(); i++)
00182 gndPredHashArray_[i]->removeGndClauses();
00183
00184
00185 for (int i = 0; i < newClauses_.size(); i++)
00186 delete newClauses_[i];
00187 newClauses_.clear();
00188
00189 baseNumAtoms_ = gndPredHashArray_.size();
00190 cout << "Number of Baseatoms = " << baseNumAtoms_ << endl;
00191 cout << "Default => Cost\t" << "******\t" << " Clause Cnt\t" << endl;
00192 cout << " " << defaultCost << "\t" << "******\t" << defaultCnt
00193 << "\t" << endl << endl;
00194
00195
00196 for (int i = 0; i < baseNumAtoms_; i++)
00197 {
00198 domain_->getDB()->setActiveStatus(gndPredHashArray_[i], true);
00199 activeAtoms_++;
00200 }
00201
00202
00203 fillLazyBlocks();
00204
00205
00206 ignoreActivePreds = false;
00207 getActiveClauses(newClauses_, ignoreActivePreds);
00208 }
00209
00210 else
00211 {
00212 unePreds_ = unknownQueries;
00213 knePreds_ = knownQueries;
00214 knePredValues_ = knownQueryValues;
00215
00216
00217 int size = 0;
00218 if (unknownQueries) size += unknownQueries->size();
00219 if (knownQueries) size += knownQueries->size();
00220 GroundPredicateHashArray* queries = new GroundPredicateHashArray(size);
00221 if (unknownQueries) queries->append(unknownQueries);
00222 if (knownQueries) queries->append(knownQueries);
00223 mrf_ = new MRF(queries, allPredGndingsAreQueries, domain_,
00224 domain_->getDB(), mln_, markHardGndClauses,
00225 trackParentClauseWts, -1);
00226
00227
00228 mrf_->deleteGndPredsGndClauseSets();
00229
00230 delete queries;
00231
00232
00233 blocks_ = mrf_->getBlocks();
00234 blockEvidence_ = mrf_->getBlockEvidence();
00235
00236
00237 newClauses_ = *(Array<GroundClause*>*)mrf_->getGndClauses();
00238
00239
00240 const GroundPredicateHashArray* gndPreds = mrf_->getGndPreds();
00241 for (int i = 0; i < gndPreds->size(); i++)
00242 gndPredHashArray_.append((*gndPreds)[i]);
00243
00244
00245 baseNumAtoms_ = gndPredHashArray_.size();
00246 }
00247
00248
00249
00250
00251
00252
00253 bool initial = true;
00254 addNewClauses(initial);
00255
00256 cout << "Initial num. of clauses: " << getNumClauses() << endl;
00257 }
00258
00263 ~VariableState()
00264 {
00265 if (lazy_)
00266 {
00267
00268 for (int i = 0; i < blocks_->size(); i++)
00269 (*blocks_)[i].clearAndCompress();
00270 delete blocks_;
00271
00272 delete blockEvidence_;
00273 }
00274 else
00275 {
00276
00277 if (mrf_) delete mrf_;
00278
00279
00280
00281 }
00282 }
00283
00284
00292 void addNewClauses(bool initial)
00293 {
00294 if (vsdebug)
00295 cout << "Adding " << newClauses_.size() << " new clauses.." << endl;
00296
00297
00298 int oldNumClauses = getNumClauses();
00299 int oldNumAtoms = getNumAtoms();
00300
00301 gndClauses_->append(newClauses_);
00302 gndPreds_->growToSize(gndPredHashArray_.size());
00303
00304 int numAtoms = getNumAtoms();
00305 int numClauses = getNumClauses();
00306
00307 if (numAtoms == oldNumAtoms && numClauses == oldNumClauses) return;
00308
00309 if (vsdebug) cout << "Clauses: " << numClauses << endl;
00310
00311 atom_.growToSize(numAtoms + 1, false);
00312
00313 makeCost_.growToSize(numAtoms + 1, 0.0);
00314 breakCost_.growToSize(numAtoms + 1, 0.0);
00315 lowAtom_.growToSize(numAtoms + 1, false);
00316 fixedAtom_.growToSize(numAtoms + 1, 0);
00317
00318
00319 for (int i = oldNumAtoms; i < gndPredHashArray_.size(); i++)
00320 {
00321 (*gndPreds_)[i] = gndPredHashArray_[i];
00322
00323 if (vsdebug)
00324 {
00325 cout << "New pred: ";
00326 (*gndPreds_)[i]->print(cout, domain_);
00327 cout << endl;
00328 }
00329
00330 lowAtom_[i + 1] = atom_[i + 1] =
00331 (domain_->getDB()->getValue((*gndPreds_)[i]) == TRUE) ? true : false;
00332 }
00333 newClauses_.clear();
00334
00335 clause_.growToSize(numClauses);
00336 clauseCost_.growToSize(numClauses);
00337 falseClause_.growToSize(numClauses);
00338 whereFalse_.growToSize(numClauses);
00339 numTrueLits_.growToSize(numClauses);
00340 watch1_.growToSize(numClauses);
00341 watch2_.growToSize(numClauses);
00342 isSatisfied_.growToSize(numClauses, false);
00343 deadClause_.growToSize(numClauses, false);
00344 threshold_.growToSize(numClauses, false);
00345
00346 occurence_.growToSize(2*numAtoms + 1);
00347
00348 for (int i = oldNumClauses; i < numClauses; i++)
00349 {
00350 GroundClause* gndClause = (*gndClauses_)[i];
00351
00352 if (vsdebug)
00353 {
00354 cout << "New clause: ";
00355 gndClause->print(cout, domain_, &gndPredHashArray_);
00356 cout << endl;
00357 }
00358
00359
00360 if (gndClause->isHardClause()) threshold_[i] = RAND_MAX;
00361 else
00362 {
00363 double w = gndClause->getWt();
00364 threshold_[i] = RAND_MAX*(1 - exp(-abs(w)));
00365 if (vsdebug)
00366 {
00367 cout << "Weight: " << w << endl;
00368 }
00369 }
00370 if (vsdebug)
00371 cout << "Threshold: " << threshold_[i] << endl;
00372
00373 int numGndPreds = gndClause->getNumGroundPredicates();
00374 clause_[i].growToSize(numGndPreds);
00375
00376 for (int j = 0; j < numGndPreds; j++)
00377 {
00378 int lit = gndClause->getGroundPredicateIndex(j);
00379 clause_[i][j] = lit;
00380 int litIdx = 2*abs(lit) - (lit > 0);
00381 occurence_[litIdx].append(i);
00382 }
00383
00384
00385 if (gndClause->isHardClause())
00386 clauseCost_[i] = hardWt_;
00387 else
00388 clauseCost_[i] = gndClause->getWt();
00389 }
00390
00391 if (!initial)
00392 {
00393
00394 if (useThreshold_)
00395 {
00396 killClauses(oldNumClauses);
00397 }
00398 else
00399 {
00400 initMakeBreakCostWatch(oldNumClauses);
00401 }
00402 }
00403 if (vsdebug) cout << "Done adding new clauses.." << endl;
00404 }
00405
00409 void init()
00410 {
00411
00412 initMakeBreakCostWatch();
00413 }
00414
00418 void reinit()
00419 {
00420 clause_.clearAndCompress();
00421 clauseCost_.clearAndCompress();
00422 falseClause_.clearAndCompress();
00423 whereFalse_.clearAndCompress();
00424 numTrueLits_.clearAndCompress();
00425 watch1_.clearAndCompress();
00426 watch2_.clearAndCompress();
00427 isSatisfied_.clearAndCompress();
00428 deadClause_.clearAndCompress();
00429 threshold_.clearAndCompress();
00430
00431 newClauses_.append(gndClauses_);
00432 gndClauses_->clearAndCompress();
00433 gndPreds_->clearAndCompress();
00434 for (int i = 0; i < occurence_.size(); i++)
00435 occurence_.clearAndCompress();
00436 occurence_.clearAndCompress();
00437
00438
00439 bool initial = true;
00440 addNewClauses(initial);
00441 baseNumAtoms_ = gndPredHashArray_.size();
00442 init();
00443 }
00444
00450 void initRandom()
00451 {
00452
00453 initBlocksRandom();
00454
00455
00456 for (int i = 1; i <= baseNumAtoms_; i++)
00457 {
00458
00459 if (fixedAtom_[i] != 0) setValueOfAtom(i, (fixedAtom_[i] == 1));
00460
00461 if (getBlockIndex(i - 1) >= 0)
00462 {
00463 if (vsdebug) cout << "Atom " << i << " in block" << endl;
00464 continue;
00465 }
00466
00467 else
00468 {
00469 if (vsdebug) cout << "Atom " << i << " not in block" << endl;
00470 setValueOfAtom(i, random() % 2);
00471 }
00472 }
00473 init();
00474 }
00475
00479 void initBlocksRandom()
00480 {
00481 if (vsdebug)
00482 {
00483 cout << "Initializing blocks randomly" << endl;
00484 cout << "Num. of blocks: " << blocks_->size() << endl;
00485 }
00486
00487
00488 for (int i = 0; i < blocks_->size(); i++)
00489 {
00490
00491 if (int trueFixedAtomInBlock = getTrueFixedAtomInBlock(i) >= 0)
00492 {
00493 if (vsdebug) cout << "True fixed atom in block " << i << endl;
00494 setOthersInBlockToFalse(trueFixedAtomInBlock, i);
00495 continue;
00496 }
00497
00498
00499 if ((*blockEvidence_)[i])
00500 {
00501
00502 setOthersInBlockToFalse(-1, i);
00503 continue;
00504 }
00505
00506
00507 Array<int>& block = (*blocks_)[i];
00508 bool ok = false;
00509 while (!ok)
00510 {
00511 int chosen = random() % block.size();
00512
00513 if (fixedAtom_[block[chosen] + 1] == 0)
00514 {
00515 if (vsdebug) cout << "Atom " << block[chosen] + 1
00516 << " chosen in block" << endl;
00517 setValueOfAtom(block[chosen] + 1, true);
00518 setOthersInBlockToFalse(chosen, i);
00519 ok = true;
00520 }
00521 }
00522 }
00523 if (vsdebug) cout << "Done initializing blocks randomly" << endl;
00524 }
00525
00530 void initMakeBreakCostWatch()
00531 {
00532
00533 for (int i = 0; i < getNumClauses(); i++) numTrueLits_[i] = 0;
00534 numFalseClauses_ = 0;
00535 costOfFalseClauses_ = 0.0;
00536 lowCost_ = LDBL_MAX;
00537 lowBad_ = INT_MAX;
00538
00539 assert(makeCost_.size() == breakCost_.size());
00540
00541 for (int i = 0; i < makeCost_.size(); i++)
00542 {
00543 makeCost_[i] = breakCost_[i] = 0.0;
00544 }
00545 initMakeBreakCostWatch(0);
00546 }
00547
00555 void initMakeBreakCostWatch(const int& startClause)
00556 {
00557 int theTrueLit = -1;
00558
00559 for (int i = startClause; i < getNumClauses(); i++)
00560 {
00561
00562 if (deadClause_[i]) continue;
00563 int trueLit1 = 0;
00564 int trueLit2 = 0;
00565 long double cost = clauseCost_[i];
00566 numTrueLits_[i] = 0;
00567 for (int j = 0; j < getClauseSize(i); j++)
00568 {
00569 if (isTrueLiteral(clause_[i][j]))
00570 {
00571 numTrueLits_[i]++;
00572 theTrueLit = abs(clause_[i][j]);
00573 if (!trueLit1) trueLit1 = theTrueLit;
00574 else if (trueLit1 && !trueLit2) trueLit2 = theTrueLit;
00575 }
00576 }
00577
00578
00579
00580 if ((numTrueLits_[i] == 0 && cost > 0) ||
00581 (numTrueLits_[i] > 0 && cost < 0))
00582 {
00583 whereFalse_[i] = numFalseClauses_;
00584 falseClause_[numFalseClauses_] = i;
00585 numFalseClauses_++;
00586 costOfFalseClauses_ += abs(cost);
00587 if (highestCost_ == abs(cost)) {eqHighest_ = true; numHighest_++;}
00588
00589
00590 if (numTrueLits_[i] == 0)
00591 for (int j = 0; j < getClauseSize(i); j++)
00592 {
00593 makeCost_[abs(clause_[i][j])] += cost;
00594 }
00595
00596
00597 if (numTrueLits_[i] == 1)
00598 {
00599
00600 makeCost_[theTrueLit] -= cost;
00601 watch1_[i] = theTrueLit;
00602 }
00603 else if (numTrueLits_[i] > 1)
00604 {
00605 watch1_[i] = trueLit1;
00606 watch2_[i] = trueLit2;
00607 }
00608 }
00609
00610 else if (numTrueLits_[i] == 1 && cost > 0)
00611 {
00612 breakCost_[theTrueLit] += cost;
00613 watch1_[i] = theTrueLit;
00614 }
00615
00616 else if (cost > 0)
00617 {
00618 watch1_[i] = trueLit1;
00619 watch2_[i] = trueLit2;
00620 }
00621
00622 else if (numTrueLits_[i] == 0 && cost < 0)
00623 {
00624 for (int j = 0; j < getClauseSize(i); j++)
00625 breakCost_[abs(clause_[i][j])] -= cost;
00626 }
00627 }
00628 }
00629
00630 int getNumAtoms() { return gndPreds_->size(); }
00631
00632 int getNumClauses() { return gndClauses_->size(); }
00633
00634 int getNumDeadClauses()
00635 {
00636 int count = 0;
00637 for (int i = 0; i < deadClause_.size(); i++)
00638 if (deadClause_[i]) count++;
00639 return count;
00640 }
00641
00645 int getIndexOfRandomAtom()
00646 {
00647 int numAtoms = getNumAtoms();
00648 if (numAtoms == 0) return NOVALUE;
00649 return random()%numAtoms + 1;
00650 }
00651
00657 int getIndexOfAtomInRandomFalseClause()
00658 {
00659 if (numFalseClauses_ == 0) return NOVALUE;
00660 int clauseIdx = falseClause_[random()%numFalseClauses_];
00661
00662 if (clauseCost_[clauseIdx] > 0)
00663 return abs(clause_[clauseIdx][random()%getClauseSize(clauseIdx)]);
00664
00665 else
00666 return getRandomTrueLitInClause(clauseIdx);
00667 }
00668
00673 int getRandomFalseClauseIndex()
00674 {
00675 if (numFalseClauses_ == 0) return NOVALUE;
00676 return falseClause_[random()%numFalseClauses_];
00677 }
00678
00683 long double getCostOfFalseClauses()
00684 {
00685 return costOfFalseClauses_;
00686 }
00687
00692 int getNumFalseClauses()
00693 {
00694 return numFalseClauses_;
00695 }
00696
00703 bool getValueOfAtom(const int& atomIdx)
00704 {
00705 return atom_[atomIdx];
00706 }
00707
00714 bool getValueOfLowAtom(const int& atomIdx)
00715 {
00716 return lowAtom_[atomIdx];
00717 }
00718
00726 void setValueOfAtom(const int& atomIdx, const bool& value)
00727 {
00728 if (vsdebug) cout << "Setting value of atom " << atomIdx
00729 << " to " << value << endl;
00730
00731 if (atom_[atomIdx] == value) return;
00732
00733 GroundPredicate* p = gndPredHashArray_[atomIdx - 1];
00734 if (value)
00735 {
00736 domain_->getDB()->setValue(p, TRUE);
00737 }
00738 else
00739 {
00740 domain_->getDB()->setValue(p, FALSE);
00741 }
00742
00743 if (lazy_ && !isActive(atomIdx))
00744 {
00745 bool ignoreActivePreds = false;
00746 activateAtom(atomIdx, ignoreActivePreds);
00747 }
00748 atom_[atomIdx] = value;
00749 }
00750
00754 Array<int>& getNegOccurenceArray(const int& atomIdx)
00755 {
00756 int litIdx = 2*atomIdx;
00757 return getOccurenceArray(litIdx);
00758 }
00759
00763 Array<int>& getPosOccurenceArray(const int& atomIdx)
00764 {
00765 int litIdx = 2*atomIdx - 1;
00766 return getOccurenceArray(litIdx);
00767 }
00768
00774 void flipAtom(const int& toFlip)
00775 {
00776 bool toFlipValue = getValueOfAtom(toFlip);
00777 register int clauseIdx;
00778 int sign;
00779 int oppSign;
00780 int litIdx;
00781 if (toFlipValue)
00782 sign = 1;
00783 else
00784 sign = 0;
00785 oppSign = sign ^ 1;
00786
00787 flipAtomValue(toFlip);
00788
00789
00790 litIdx = 2*toFlip - sign;
00791 Array<int>& posOccArray = getOccurenceArray(litIdx);
00792 for (int i = 0; i < posOccArray.size(); i++)
00793 {
00794 clauseIdx = posOccArray[i];
00795
00796 if (deadClause_[clauseIdx]) continue;
00797
00798 int numTrueLits = decrementNumTrueLits(clauseIdx);
00799 long double cost = getClauseCost(clauseIdx);
00800 int watch1 = getWatch1(clauseIdx);
00801 int watch2 = getWatch2(clauseIdx);
00802
00803
00804
00805 if (numTrueLits == 0)
00806 {
00807
00808 if (cost > 0)
00809 {
00810
00811 addFalseClause(clauseIdx);
00812
00813 addBreakCost(toFlip, -cost);
00814
00815 addMakeCostToAtomsInClause(clauseIdx, cost);
00816 }
00817
00818 else
00819 {
00820 assert(cost < 0);
00821
00822 removeFalseClause(clauseIdx);
00823
00824 addBreakCostToAtomsInClause(clauseIdx, -cost);
00825
00826 addMakeCost(toFlip, cost);
00827 }
00828 }
00829
00830
00831 else if (numTrueLits == 1)
00832 {
00833 if (watch1 == toFlip)
00834 {
00835 assert(watch1 != watch2);
00836 setWatch1(clauseIdx, watch2);
00837 watch1 = getWatch1(clauseIdx);
00838 }
00839
00840
00841 if (cost > 0)
00842 {
00843 addBreakCost(watch1, cost);
00844 }
00845
00846 else
00847 {
00848 assert(cost < 0);
00849 addMakeCost(watch1, -cost);
00850 }
00851 }
00852
00853
00854 else
00855 {
00856
00857 if (watch1 == toFlip)
00858 {
00859
00860 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00861 setWatch1(clauseIdx, diffTrueLit);
00862 }
00863
00864 else if (watch2 == toFlip)
00865 {
00866
00867 int diffTrueLit = getTrueLiteralOtherThan(clauseIdx, watch1, watch2);
00868 setWatch2(clauseIdx, diffTrueLit);
00869 }
00870 }
00871 }
00872
00873
00874 litIdx = 2*toFlip - oppSign;
00875 Array<int>& negOccArray = getOccurenceArray(litIdx);
00876 for (int i = 0; i < negOccArray.size(); i++)
00877 {
00878 clauseIdx = negOccArray[i];
00879
00880 if (deadClause_[clauseIdx]) continue;
00881
00882 int numTrueLits = incrementNumTrueLits(clauseIdx);
00883 long double cost = getClauseCost(clauseIdx);
00884 int watch1 = getWatch1(clauseIdx);
00885
00886
00887
00888 if (numTrueLits == 1)
00889 {
00890
00891 if (cost > 0)
00892 {
00893
00894 removeFalseClause(clauseIdx);
00895
00896 addBreakCost(toFlip, cost);
00897
00898 addMakeCostToAtomsInClause(clauseIdx, -cost);
00899 }
00900
00901 else
00902 {
00903 assert(cost < 0);
00904
00905 addFalseClause(clauseIdx);
00906
00907 addBreakCostToAtomsInClause(clauseIdx, cost);
00908
00909 addMakeCost(toFlip, -cost);
00910 }
00911
00912 setWatch1(clauseIdx, toFlip);
00913 }
00914
00915
00916 else
00917 if (numTrueLits == 2)
00918 {
00919 if (cost > 0)
00920 {
00921
00922
00923 addBreakCost(watch1, -cost);
00924 }
00925 else
00926 {
00927
00928 assert(cost < 0);
00929
00930 addMakeCost(watch1, cost);
00931 }
00932
00933
00934 setWatch2(clauseIdx, toFlip);
00935 }
00936 }
00937 }
00938
00943 void flipAtomValue(const int& atomIdx)
00944 {
00945 bool opposite = !atom_[atomIdx];
00946 setValueOfAtom(atomIdx, opposite);
00947 }
00948
00960 long double getImprovementByFlipping(const int& atomIdx)
00961 {
00962 if (lazy_ && !isActive(atomIdx))
00963 {
00964
00965 flipAtom(atomIdx);
00966 flipAtom(atomIdx);
00967 }
00968 long double improvement = makeCost_[atomIdx] - breakCost_[atomIdx];
00969 return improvement;
00970 }
00971
00978 void activateAtom(const int& atomIdx, const bool& ignoreActivePreds)
00979 {
00980
00981
00982 if (lazy_ && !isActive(atomIdx))
00983 {
00984 Predicate* p =
00985 gndPredHashArray_[atomIdx - 1]->createEquivalentPredicate(domain_);
00986 getActiveClauses(p, newClauses_, true, ignoreActivePreds);
00987
00988 bool initial = false;
00989 addNewClauses(initial);
00990
00991 domain_->getDB()->setActiveStatus(p, true);
00992 activeAtoms_++;
00993 delete p;
00994 }
00995 }
00996
01003 bool isActive(const int& atomIdx)
01004 {
01005 return domain_->getDB()->getActiveStatus(gndPredHashArray_[atomIdx-1]);
01006 }
01007
01014 bool isActive(const Predicate* pred)
01015 {
01016 return domain_->getDB()->getActiveStatus(pred);
01017 }
01018
01022 Array<int>& getOccurenceArray(const int& idx)
01023 {
01024 return occurence_[idx];
01025 }
01026
01030 int incrementNumTrueLits(const int& clauseIdx)
01031 {
01032 return ++numTrueLits_[clauseIdx];
01033 }
01034
01038 int decrementNumTrueLits(const int& clauseIdx)
01039 {
01040 return --numTrueLits_[clauseIdx];
01041 }
01042
01046 int getNumTrueLits(const int& clauseIdx)
01047 {
01048 return numTrueLits_[clauseIdx];
01049 }
01050
01054 long double getClauseCost(const int& clauseIdx)
01055 {
01056 return clauseCost_[clauseIdx];
01057 }
01058
01062 Array<int>& getAtomsInClause(const int& clauseIdx)
01063 {
01064 return clause_[clauseIdx];
01065 }
01066
01070 void addFalseClause(const int& clauseIdx)
01071 {
01072 falseClause_[numFalseClauses_] = clauseIdx;
01073 whereFalse_[clauseIdx] = numFalseClauses_;
01074 numFalseClauses_++;
01075 costOfFalseClauses_ += abs(clauseCost_[clauseIdx]);
01076 }
01077
01081 void removeFalseClause(const int& clauseIdx)
01082 {
01083 numFalseClauses_--;
01084 falseClause_[whereFalse_[clauseIdx]] = falseClause_[numFalseClauses_];
01085 whereFalse_[falseClause_[numFalseClauses_]] = whereFalse_[clauseIdx];
01086 costOfFalseClauses_ -= abs(clauseCost_[clauseIdx]);
01087 }
01088
01092 void addBreakCost(const int& atomIdx, const long double& cost)
01093 {
01094 breakCost_[atomIdx] += cost;
01095 }
01096
01100 void subtractBreakCost(const int& atomIdx, const long double& cost)
01101 {
01102 breakCost_[atomIdx] -= cost;
01103 }
01104
01111 void addBreakCostToAtomsInClause(const int& clauseIdx,
01112 const long double& cost)
01113 {
01114 register int size = getClauseSize(clauseIdx);
01115 for (int i = 0; i < size; i++)
01116 {
01117 register int lit = clause_[clauseIdx][i];
01118 breakCost_[abs(lit)] += cost;
01119 }
01120 }
01121
01128 void subtractBreakCostFromAtomsInClause(const int& clauseIdx,
01129 const long double& cost)
01130 {
01131 register int size = getClauseSize(clauseIdx);
01132 for (int i = 0; i < size; i++)
01133 {
01134 register int lit = clause_[clauseIdx][i];
01135 breakCost_[abs(lit)] -= cost;
01136 }
01137 }
01138
01145 void addMakeCost(const int& atomIdx, const long double& cost)
01146 {
01147 makeCost_[atomIdx] += cost;
01148 }
01149
01156 void subtractMakeCost(const int& atomIdx, const long double& cost)
01157 {
01158 makeCost_[atomIdx] -= cost;
01159 }
01160
01167 void addMakeCostToAtomsInClause(const int& clauseIdx,
01168 const long double& cost)
01169 {
01170 register int size = getClauseSize(clauseIdx);
01171 for (int i = 0; i < size; i++)
01172 {
01173 register int lit = clause_[clauseIdx][i];
01174 makeCost_[abs(lit)] += cost;
01175 }
01176 }
01177
01184 void subtractMakeCostFromAtomsInClause(const int& clauseIdx,
01185 const long double& cost)
01186 {
01187 register int size = getClauseSize(clauseIdx);
01188 for (int i = 0; i < size; i++)
01189 {
01190 register int lit = clause_[clauseIdx][i];
01191 makeCost_[abs(lit)] -= cost;
01192 }
01193 }
01194
01204 const int getTrueLiteralOtherThan(const int& clauseIdx,
01205 const int& atomIdx1,
01206 const int& atomIdx2)
01207 {
01208 register int size = getClauseSize(clauseIdx);
01209 for (int i = 0; i < size; i++)
01210 {
01211 register int lit = clause_[clauseIdx][i];
01212 register int v = abs(lit);
01213 if (isTrueLiteral(lit) && v != atomIdx1 && v != atomIdx2)
01214 return v;
01215 }
01216
01217 assert(false);
01218 return -1;
01219 }
01220
01224 const bool isTrueLiteral(const int& literal)
01225 {
01226 return ((literal > 0) == atom_[abs(literal)]);
01227 }
01228
01232 const int getAtomInClause(const int& atomIdxInClause, const int& clauseIdx)
01233 {
01234 return clause_[clauseIdx][atomIdxInClause];
01235 }
01236
01240 const int getRandomAtomInClause(const int& clauseIdx)
01241 {
01242 return clause_[clauseIdx][random()%getClauseSize(clauseIdx)];
01243 }
01244
01251 const int getRandomTrueLitInClause(const int& clauseIdx)
01252 {
01253 assert(numTrueLits_[clauseIdx] > 0);
01254 int trueLit = random()%numTrueLits_[clauseIdx];
01255 int whichTrueLit = 0;
01256 for (int i = 0; i < getClauseSize(clauseIdx); i++)
01257 {
01258 int lit = clause_[clauseIdx][i];
01259 int atm = abs(lit);
01260
01261 if (isTrueLiteral(lit))
01262 if (trueLit == whichTrueLit++)
01263 return atm;
01264 }
01265
01266 assert(false);
01267 return -1;
01268 }
01269
01270 const double getMaxClauseWeight()
01271 {
01272 double maxWeight = 0.0;
01273 for (int i = 0; i < getNumClauses(); i++)
01274 {
01275 double weight = abs(clauseCost_[i]);
01276 if (weight > maxWeight) maxWeight = weight;
01277 }
01278 return maxWeight;
01279 }
01280
01281 const long double getMakeCost(const int& atomIdx)
01282 {
01283 return makeCost_[atomIdx];
01284 }
01285
01286 const long double getBreakCost(const int& atomIdx)
01287 {
01288 return breakCost_[atomIdx];
01289 }
01290
01291 const int getClauseSize(const int& clauseIdx)
01292 {
01293 return clause_[clauseIdx].size();
01294 }
01295
01296 const int getWatch1(const int& clauseIdx)
01297 {
01298 return watch1_[clauseIdx];
01299 }
01300
01301 void setWatch1(const int& clauseIdx, const int& atomIdx)
01302 {
01303 watch1_[clauseIdx] = atomIdx;
01304 }
01305
01306 const int getWatch2(const int& clauseIdx)
01307 {
01308 return watch2_[clauseIdx];
01309 }
01310
01311 void setWatch2(const int& clauseIdx, const int& atomIdx)
01312 {
01313 watch2_[clauseIdx] = atomIdx;
01314 }
01315
01316 const bool isBlockEvidence(const int& blockIdx)
01317 {
01318 return (*blockEvidence_)[blockIdx];
01319 }
01320
01321 const int getBlockSize(const int& blockIdx)
01322 {
01323
01324 if (lazy_)
01325 return domain_->getPredBlock(blockIdx)->size();
01326 else
01327 return (*blocks_)[blockIdx].size();
01328 }
01329
01334 const int getBlockIndex(const int& atomIdx)
01335 {
01336 for (int i = 0; i < blocks_->size(); i++)
01337 {
01338 int blockIdx = (*blocks_)[i].find(atomIdx);
01339 if (blockIdx >= 0)
01340 return i;
01341 }
01342 return -1;
01343 }
01344
01345 Array<int>& getBlockArray(const int& blockIdx)
01346 {
01347 return (*blocks_)[blockIdx];
01348 }
01349
01350 bool getBlockEvidence(const int& blockIdx)
01351 {
01352 return (*blockEvidence_)[blockIdx];
01353 }
01354
01355 int getNumBlocks()
01356 {
01357 return blocks_->size();
01358 }
01359
01363 const long double getLowCost()
01364 {
01365 return lowCost_;
01366 }
01367
01371 const int getLowBad()
01372 {
01373 return lowBad_;
01374 }
01375
01380 void makeUnitCosts()
01381 {
01382 for (int i = 0; i < clauseCost_.size(); i++)
01383 {
01384 if (clauseCost_[i] > 0) clauseCost_[i] = 1.0;
01385 else
01386 {
01387 assert(clauseCost_[i] < 0);
01388 clauseCost_[i] = -1.0;
01389 }
01390 }
01391 if (vsdebug) cout << "Made unit costs" << endl;
01392 initMakeBreakCostWatch();
01393 }
01394
01398 void saveLowState()
01399 {
01400 if (vsdebug) cout << "Saving low state: " << endl;
01401 for (int i = 1; i <= getNumAtoms(); i++)
01402 {
01403 lowAtom_[i] = atom_[i];
01404 if (vsdebug) cout << lowAtom_[i] << endl;
01405 }
01406 lowCost_ = costOfFalseClauses_;
01407 lowBad_ = numFalseClauses_;
01408 }
01409
01413 int getTrueFixedAtomInBlock(const int blockIdx)
01414 {
01415 Array<int>& block = (*blocks_)[blockIdx];
01416 for (int i = 0; i < block.size(); i++)
01417 if (fixedAtom_[block[i] + 1] > 0) return i;
01418 return -1;
01419 }
01420
01421 const GroundPredicateHashArray* getGndPredHashArrayPtr() const
01422 {
01423 return &gndPredHashArray_;
01424 }
01425
01426 const GroundPredicateHashArray* getUnePreds() const
01427 {
01428 return unePreds_;
01429 }
01430
01431 const GroundPredicateHashArray* getKnePreds() const
01432 {
01433 return knePreds_;
01434 }
01435
01436 const Array<TruthValue>* getKnePredValues() const
01437 {
01438 return knePredValues_;
01439 }
01440
01444 void setGndClausesWtsToSumOfParentWts()
01445 {
01446 for (int i = 0; i < gndClauses_->size(); i++)
01447 {
01448 GroundClause* gndClause = (*gndClauses_)[i];
01449 gndClause->setWtToSumOfParentWts();
01450 if (vsdebug) cout << "Setting cost of clause " << i << " to "
01451 << gndClause->getWt() << endl;
01452 clauseCost_[i] = gndClause->getWt();
01453
01454
01455 if (gndClause->isHardClause()) threshold_[i] = RAND_MAX;
01456 else
01457 {
01458 double w = gndClause->getWt();
01459 threshold_[i] = RAND_MAX*(1 - exp(-abs(w)));
01460 if (vsdebug)
01461 {
01462 cout << "Weight: " << w << endl;
01463 }
01464 }
01465 if (vsdebug)
01466 cout << "Threshold: " << threshold_[i] << endl;
01467 }
01468 }
01469
01478 void getNumClauseGndings(Array<double>* const & numGndings, bool tv)
01479 {
01480
01481 IntPairItr itr;
01482 IntPair *clauseFrequencies;
01483
01484
01485 int clauseCnt = numGndings->size();
01486 assert(clauseCnt == mln_->getNumClauses());
01487 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01488 assert ((*numGndings)[clauseno] >= 0);
01489
01490 for (int i = 0; i < gndClauses_->size(); i++)
01491 {
01492 GroundClause *gndClause = (*gndClauses_)[i];
01493 int satLitcnt = 0;
01494 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01495 {
01496 int lit = gndClause->getGroundPredicateIndex(j);
01497 if (isTrueLiteral(lit)) satLitcnt++;
01498 }
01499
01500 if (tv && satLitcnt == 0)
01501 continue;
01502 if (!tv && satLitcnt > 0)
01503 continue;
01504
01505 clauseFrequencies = gndClause->getClauseFrequencies();
01506 for (itr = clauseFrequencies->begin();
01507 itr != clauseFrequencies->end(); itr++)
01508 {
01509 int clauseno = itr->first;
01510 int frequency = itr->second;
01511 (*numGndings)[clauseno] += frequency;
01512 }
01513 }
01514 }
01515
01524 void getNumClauseGndings(double numGndings[], int clauseCnt, bool tv)
01525 {
01526
01527 IntPairItr itr;
01528 IntPair *clauseFrequencies;
01529
01530 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01531 numGndings[clauseno] = 0;
01532
01533 for (int i = 0; i < gndClauses_->size(); i++)
01534 {
01535 GroundClause *gndClause = (*gndClauses_)[i];
01536 int satLitcnt = getNumTrueLits(i);
01537 if (tv && satLitcnt == 0)
01538 continue;
01539 if (!tv && satLitcnt > 0)
01540 continue;
01541
01542 clauseFrequencies = gndClause->getClauseFrequencies();
01543 for (itr = clauseFrequencies->begin();
01544 itr != clauseFrequencies->end(); itr++)
01545 {
01546 int clauseno = itr->first;
01547 int frequency = itr->second;
01548 numGndings[clauseno] += frequency;
01549 }
01550 }
01551 }
01552
01564 void getNumClauseGndingsWithUnknown(double numGndings[], int clauseCnt,
01565 bool tv,
01566 const Array<bool>* const& unknownPred)
01567 {
01568 assert(unknownPred->size() == getNumAtoms());
01569 IntPairItr itr;
01570 IntPair *clauseFrequencies;
01571
01572 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
01573 numGndings[clauseno] = 0;
01574
01575 for (int i = 0; i < gndClauses_->size(); i++)
01576 {
01577 GroundClause *gndClause = (*gndClauses_)[i];
01578 int satLitcnt = 0;
01579 bool unknown = false;
01580 for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01581 {
01582 int lit = gndClause->getGroundPredicateIndex(j);
01583 if ((*unknownPred)[abs(lit) - 1])
01584 {
01585 unknown = true;
01586 continue;
01587 }
01588 if (isTrueLiteral(lit)) satLitcnt++;
01589 }
01590
01591 if (tv && satLitcnt == 0)
01592 continue;
01593 if (!tv && (satLitcnt > 0 || unknown))
01594 continue;
01595
01596 clauseFrequencies = gndClause->getClauseFrequencies();
01597 for (itr = clauseFrequencies->begin();
01598 itr != clauseFrequencies->end(); itr++)
01599 {
01600 int clauseno = itr->first;
01601 int frequency = itr->second;
01602 numGndings[clauseno] += frequency;
01603 }
01604 }
01605 }
01606
01613 void setOthersInBlockToFalse(const int& atomIdx,
01614 const int& blockIdx)
01615 {
01616 Array<int>& block = (*blocks_)[blockIdx];
01617 for (int i = 0; i < block.size(); i++)
01618 {
01619
01620 if (i != atomIdx && fixedAtom_[block[i] + 1] == 0)
01621 setValueOfAtom(block[i] + 1, false);
01622 }
01623 }
01624
01625
01634 void fixAtom(const int& atomIdx, const bool& value)
01635 {
01636 assert(atomIdx > 0);
01637
01638 if ((fixedAtom_[atomIdx] == 1 && value == false) ||
01639 (fixedAtom_[atomIdx] == -1 && value == true))
01640 {
01641 cout << "Contradiction: Tried to fix atom " << atomIdx <<
01642 " to true and false ... exiting." << endl;
01643 exit(0);
01644 }
01645
01646 if (vsdebug)
01647 {
01648 cout << "Fixing ";
01649 (*gndPreds_)[atomIdx - 1]->print(cout, domain_);
01650 cout << " to " << value << endl;
01651 }
01652
01653 setValueOfAtom(atomIdx, value);
01654 fixedAtom_[atomIdx] = (value) ? 1 : -1;
01655 }
01656
01668 Array<int>* simplifyClauseFromFixedAtoms(const int& clauseIdx)
01669 {
01670 Array<int>* returnArray = new Array<int>;
01671
01672 if (isSatisfied_[clauseIdx]) return returnArray;
01673
01674
01675
01676 bool isGood = (clauseCost_[clauseIdx] > 0) ? false : true;
01677
01678 bool allFalseAtoms = (clauseCost_[clauseIdx] > 0) ? true : false;
01679
01680 for (int i = 0; i < getClauseSize(clauseIdx); i++)
01681 {
01682 int lit = clause_[clauseIdx][i];
01683 int fixedValue = fixedAtom_[abs(lit)];
01684
01685 if (clauseCost_[clauseIdx] > 0)
01686 {
01687 if ((fixedValue == 1 && lit > 0) ||
01688 (fixedValue == -1 && lit < 0))
01689 {
01690 isGood = true;
01691 allFalseAtoms = false;
01692 returnArray->clear();
01693 break;
01694 }
01695 else if (fixedValue == 0)
01696 {
01697 allFalseAtoms = false;
01698 returnArray->append(lit);
01699 }
01700 }
01701 else
01702 {
01703 assert(clauseCost_[clauseIdx] < 0);
01704 if ((fixedValue == 1 && lit > 0) ||
01705 (fixedValue == -1 && lit < 0))
01706 {
01707 cout << "Contradiction: Tried to fix atom " << abs(lit) <<
01708 " to true in a negative clause ... exiting." << endl;
01709 exit(0);
01710 }
01711 else
01712 {
01713 returnArray->append(lit);
01714
01715 if (fixedValue == 0) isGood = false;
01716 }
01717 }
01718 }
01719 if (allFalseAtoms)
01720 {
01721 cout << "Contradiction: All atoms in clause " << clauseIdx <<
01722 " fixed to false ... exiting." << endl;
01723 exit(0);
01724 }
01725 if (isGood) isSatisfied_[clauseIdx] = true;
01726 return returnArray;
01727 }
01728
01735 const bool isDeadClause(const int& clauseIdx)
01736 {
01737 return deadClause_[clauseIdx];
01738 }
01739
01743 void eliminateSoftClauses()
01744 {
01745 bool atLeastOneDead = false;
01746 for (int i = 0; i < getNumClauses(); i++)
01747 {
01748 if (!(*gndClauses_)[i]->isHardClause())
01749 {
01750 atLeastOneDead = true;
01751 deadClause_[i] = true;
01752 }
01753 }
01754 if (atLeastOneDead) initMakeBreakCostWatch();
01755 }
01756
01764 void killClauses(const int& startClause)
01765 {
01766 for (int i = startClause; i < getNumClauses(); i++)
01767 {
01768 GroundClause* clause = (*gndClauses_)[i];
01769 if ((clauseGoodInPrevious(i)) &&
01770 (clause->isHardClause() || random() <= threshold_[i]))
01771 {
01772 if (vsdebug)
01773 {
01774 cout << "Keeping clause "<< i << " ";
01775 clause->print(cout, domain_, &gndPredHashArray_);
01776 cout << endl;
01777 }
01778 deadClause_[i] = false;
01779 }
01780 else
01781 {
01782 deadClause_[i] = true;
01783 }
01784 }
01785 initMakeBreakCostWatch();
01786 }
01787
01788
01796 const bool clauseGoodInPrevious(const int& clauseIdx)
01797 {
01798
01799 int numSatLits = numTrueLits_[clauseIdx];
01800
01801 if ((numSatLits > 0 && clauseCost_[clauseIdx] > 0.0) ||
01802 (numSatLits == 0 && clauseCost_[clauseIdx] < 0.0))
01803 return true;
01804 else
01805 return false;
01806 }
01807
01811 void resetDeadClauses()
01812 {
01813 for (int i = 0; i < deadClause_.size(); i++)
01814 deadClause_[i] = false;
01815 initMakeBreakCostWatch();
01816 }
01817
01821 void resetFixedAtoms()
01822 {
01823 for (int i = 0; i < fixedAtom_.size(); i++)
01824 fixedAtom_[i] = 0;
01825 for (int i = 0; i < isSatisfied_.size(); i++)
01826 isSatisfied_[i] = false;
01827 }
01828
01829 void setLazy(const bool& l) { lazy_ = l; }
01830 const bool getLazy() { return lazy_; }
01831
01832 void setUseThreshold(const bool& t) { useThreshold_ = t;}
01833 const bool getUseThreshold() { return useThreshold_; }
01834
01835 long double getHardWt() { return hardWt_; }
01836
01837 const Domain* getDomain() { return domain_; }
01838
01839 const MLN* getMLN() { return mln_; }
01840
01846 void printLowState(ostream& out)
01847 {
01848 for (int i = 0; i < getNumAtoms(); i++)
01849 {
01850 (*gndPreds_)[i]->print(out, domain_);
01851 out << " " << lowAtom_[i + 1] << endl;
01852 }
01853 }
01854
01861 void printGndPred(const int& predIndex, ostream& out)
01862 {
01863 (*gndPreds_)[predIndex]->print(out, domain_);
01864 }
01865
01872 void getTruePreds(vector<string>& truePreds)
01873 {
01874 truePreds.clear();
01875 for (int i = 0; i < getNumAtoms(); i++)
01876 {
01877 if (getValueOfLowAtom(i + 1))
01878 {
01879 ostringstream oss(ostringstream::out);
01880 printGndPred(i, oss);
01881 truePreds.push_back(oss.str());
01882 }
01883 }
01884 }
01885
01896 void setAsEvidence(const GroundPredicate* const & predicate,
01897 const bool& trueEvidence)
01898 {
01899 if (vsdebug)
01900 {
01901 cout << "Setting to evidence " ;
01902 predicate->print(cout, domain_);
01903 cout << endl;
01904 }
01905 Database* db = domain_->getDB();
01906 int atomIdx = gndPredHashArray_.find((GroundPredicate*)predicate);
01907
01908 if (atomIdx <= 0)
01909 {
01910
01911 if (db->getValue(predicate) == trueEvidence)
01912 return;
01913
01914
01915 if (trueEvidence)
01916 db->setValue(predicate, TRUE);
01917 else
01918 db->setValue(predicate, FALSE);
01919 }
01920 else
01921 {
01922 Array<int> gndClauseIndexes;
01923 gndClauseIndexes = getNegOccurenceArray(atomIdx + 1);
01924 for (int i = 0; i < gndClauseIndexes.size(); i++)
01925 {
01926
01927
01928
01929 if (!trueEvidence ||
01930 (*gndClauses_)[gndClauseIndexes[i]]->getNumGroundPredicates() == 1)
01931 {
01932 if (vsdebug)
01933 cout << "Deleting ground clause " << gndClauseIndexes[i] << endl;
01934 delete (*gndClauses_)[gndClauseIndexes[i]];
01935 (*gndClauses_)[gndClauseIndexes[i]] = NULL;
01936 }
01937 else
01938 {
01939 if (vsdebug)
01940 {
01941 cout << "Removing gnd pred " << -(atomIdx + 1)
01942 << " from ground clause " << gndClauseIndexes[i] << endl;
01943 }
01944 (*gndClauses_)[gndClauseIndexes[i]]->removeGndPred(-(atomIdx + 1));
01945 }
01946 }
01947
01948 gndClauseIndexes = getPosOccurenceArray(atomIdx + 1);
01949 for (int i = 0; i < gndClauseIndexes.size(); i++)
01950 {
01951
01952
01953
01954 if (trueEvidence ||
01955 (*gndClauses_)[gndClauseIndexes[i]]->getNumGroundPredicates() == 1)
01956 {
01957 if (vsdebug)
01958 cout << "Deleting ground clause " << gndClauseIndexes[i] << endl;
01959 delete (*gndClauses_)[gndClauseIndexes[i]];
01960 (*gndClauses_)[gndClauseIndexes[i]] = NULL;
01961 }
01962 else
01963 {
01964 if (vsdebug)
01965 {
01966 cout << "Removing gnd pred " << -(atomIdx + 1)
01967 << " from ground clause " << gndClauseIndexes[i] << endl;
01968 }
01969 (*gndClauses_)[gndClauseIndexes[i]]->removeGndPred(atomIdx + 1);
01970 }
01971 }
01972
01973 gndPredHashArray_.removeItemFastDisorder(atomIdx);
01974 gndPredHashArray_.compress();
01975 gndPreds_->removeItemFastDisorder(atomIdx);
01976 gndPreds_->compress();
01977
01978
01979
01980 int oldIdx = gndPredHashArray_.size();
01981 replaceAtomIndexInAllClauses(oldIdx, atomIdx);
01982 gndClauses_->removeAllNull();
01983 }
01984 }
01985
01994 void setAsQuery(const GroundPredicate* const & predicate)
01995 {
01996 if (vsdebug)
01997 {
01998 cout << "Setting to query " ;
01999 predicate->print(cout, domain_);
02000 cout << endl;
02001 }
02002 Database* db = domain_->getDB();
02003
02004 if (gndPredHashArray_.contains((GroundPredicate*)predicate))
02005 return;
02006 else
02007 {
02008
02009
02010 gndPredHashArray_.append((GroundPredicate*)predicate);
02011 Predicate* p = predicate->createEquivalentPredicate(domain_);
02012 db->setEvidenceStatus(p, false);
02013 bool ignoreActivePreds = true;
02014 getActiveClauses(p, newClauses_, true, ignoreActivePreds);
02015 }
02016 }
02017
02019
02026 GroundPredicate* getGndPred(const int& index)
02027 {
02028 return (*gndPreds_)[index];
02029 }
02030
02037 GroundClause* getGndClause(const int& index)
02038 {
02039 return (*gndClauses_)[index];
02040 }
02041
02045 void saveLowStateToGndPreds()
02046 {
02047 for (int i = 0; i < getNumAtoms(); i++)
02048 (*gndPreds_)[i]->setTruthValue(lowAtom_[i + 1]);
02049 }
02050
02054 void saveLowStateToDB()
02055 {
02056 for (int i = 0; i < getNumAtoms(); i++)
02057 {
02058 GroundPredicate* p = gndPredHashArray_[i];
02059 bool value = lowAtom_[i + 1];
02060 if (value)
02061 {
02062 domain_->getDB()->setValue(p, TRUE);
02063 }
02064 else
02065 {
02066 domain_->getDB()->setValue(p, FALSE);
02067 }
02068 }
02069 }
02070
02077 const int getGndPredIndex(GroundPredicate* const& gndPred)
02078 {
02079 return gndPreds_->find(gndPred);
02080 }
02081
02082
02084
02085
02087
02101 void getActiveClauses(Predicate *inputPred,
02102 Array<GroundClause*>& activeClauses,
02103 bool const & active,
02104 bool const & ignoreActivePreds)
02105 {
02106 Clause *fclause;
02107 GroundClause* newClause;
02108 int clauseCnt;
02109 GroundClauseHashArray clauseHashArray;
02110
02111 Array<GroundClause*>* newClauses = new Array<GroundClause*>;
02112
02113 const Array<IndexClause*>* indexClauses = NULL;
02114
02115
02116 if (inputPred == NULL)
02117 {
02118 clauseCnt = mln_->getNumClauses();
02119 }
02120
02121 else
02122 {
02123 if (domain_->getDB()->getDeactivatedStatus(inputPred)) return;
02124 int predId = inputPred->getId();
02125 indexClauses = mln_->getClausesContainingPred(predId);
02126 clauseCnt = indexClauses->size();
02127 }
02128
02129
02130 int clauseno = 0;
02131 while (clauseno < clauseCnt)
02132 {
02133 if (inputPred)
02134 fclause = (Clause *) (*indexClauses)[clauseno]->clause;
02135 else
02136 fclause = (Clause *) mln_->getClause(clauseno);
02137
02138 if (vsdebug)
02139 {
02140 cout << "Getting active clauses for FO clause: ";
02141 fclause->print(cout, domain_);
02142 cout << endl;
02143 }
02144
02145 long double wt = fclause->getWt();
02146 const double* parentWtPtr = NULL;
02147 if (!fclause->isHardClause()) parentWtPtr = fclause->getWtPtr();
02148 const int clauseId = mln_->findClauseIdx(fclause);
02149 newClauses->clear();
02150
02151 fclause->getActiveClauses(inputPred, domain_, newClauses,
02152 &gndPredHashArray_, ignoreActivePreds);
02153
02154 for (int i = 0; i < newClauses->size(); i++)
02155 {
02156 newClause = (*newClauses)[i];
02157 int pos = clauseHashArray.find(newClause);
02158
02159 if (pos >= 0)
02160 {
02161 if (vsdebug)
02162 {
02163 cout << "Adding weight " << wt << " to clause ";
02164 clauseHashArray[pos]->print(cout, domain_, &gndPredHashArray_);
02165 cout << endl;
02166 }
02167 clauseHashArray[pos]->addWt(wt);
02168 if (parentWtPtr)
02169 {
02170 clauseHashArray[pos]->appendParentWtPtr(parentWtPtr);
02171 clauseHashArray[pos]->incrementClauseFrequency(clauseId, 1);
02172 }
02173 delete newClause;
02174 continue;
02175 }
02176
02177
02178 newClause->setWt(wt);
02179 newClause->appendToGndPreds(&gndPredHashArray_);
02180 if (parentWtPtr)
02181 {
02182 newClause->appendParentWtPtr(parentWtPtr);
02183 newClause->incrementClauseFrequency(clauseId, 1);
02184 assert(newClause->getWt() == *parentWtPtr);
02185 }
02186
02187 if (vsdebug)
02188 {
02189 cout << "Appending clause ";
02190 newClause->print(cout, domain_, &gndPredHashArray_);
02191 cout << endl;
02192 }
02193 clauseHashArray.append(newClause);
02194 }
02195 clauseno++;
02196 }
02197
02198 for (int i = 0; i < clauseHashArray.size(); i++)
02199 {
02200 newClause = clauseHashArray[i];
02201 activeClauses.append(newClause);
02202 }
02203 delete newClauses;
02204 }
02205
02213 void getActiveClauses(Array<GroundClause*> &allClauses,
02214 bool const & ignoreActivePreds)
02215 {
02216 getActiveClauses(NULL, allClauses, true, ignoreActivePreds);
02217 }
02218
02219 int getNumActiveAtoms()
02220 {
02221 return activeAtoms_;
02222 }
02223
02228 void addOneAtomToEachBlock()
02229 {
02230 assert(lazy_);
02231
02232 for (int i = 0; i < blocks_->size(); i++)
02233 {
02234
02235 if ((*blockEvidence_)[i])
02236 {
02237
02238 setOthersInBlockToFalse(-1, i);
02239 continue;
02240 }
02241
02242
02243
02244 const Array<Predicate*>* block = domain_->getPredBlock(i);
02245
02246 int chosen = random() % block->size();
02247 Predicate* pred = (*block)[chosen];
02248 GroundPredicate* groundPred = new GroundPredicate(pred);
02249
02250
02251 int index = gndPredHashArray_.find(groundPred);
02252 if (index < 0)
02253 {
02254
02255 index = gndPredHashArray_.append(groundPred);
02256 (*blocks_)[i].append(index);
02257 chosen = (*blocks_)[i].size() - 1;
02258
02259
02260 bool initial = false;
02261 addNewClauses(initial);
02262 }
02263 else
02264 {
02265 delete groundPred;
02266 chosen = (*blocks_)[i].find(index);
02267 }
02268 setValueOfAtom(index + 1, true);
02269 setOthersInBlockToFalse(chosen, i);
02270 }
02271 }
02272
02276 void initLazyBlocks()
02277 {
02278 assert(lazy_);
02279 blocks_ = new Array<Array<int> >;
02280 blocks_->growToSize(domain_->getNumPredBlocks());
02281 blockEvidence_ = new Array<bool>(*(domain_->getBlockEvidenceArray()));
02282 }
02283
02287 void fillLazyBlocks()
02288 {
02289 assert(lazy_);
02290 const Array<Array<Predicate*>*>* blocks = domain_->getPredBlocks();
02291 for (int i = 0; i < blocks->size(); i++)
02292 {
02293 if (vsdebug) cout << "Block " << i << endl;
02294 Array<Predicate*>* block = (*blocks)[i];
02295 for (int j = 0; j < block->size(); j++)
02296 {
02297 Predicate* pred = (*block)[j];
02298 if (vsdebug)
02299 {
02300 cout << "\tPred: ";
02301 pred->printWithStrVar(cout, domain_);
02302 cout << endl;
02303 }
02304
02305 if (domain_->getDB()->getEvidenceStatus(pred))
02306 continue;
02307 GroundPredicate* groundPred = new GroundPredicate(pred);
02308
02309
02310 int index = gndPredHashArray_.find(groundPred);
02311 if (index < 0)
02312 index = gndPredHashArray_.append(groundPred);
02313 else
02314 delete groundPred;
02315
02316
02317 if (!(*blocks_)[i].contains(index))
02318 (*blocks_)[i].append(index);
02319 }
02320 }
02321
02322 bool initial = true;
02323 addNewClauses(initial);
02324 }
02325
02327
02328
02329 private:
02330
02336 void setHardClauseWeight()
02337 {
02338
02339 long double sumSoftWts = 0.0;
02340
02341 int clauseCnt = mln_->getNumClauses();
02342
02343 for (int i = 0; i < clauseCnt; i++)
02344 {
02345 Clause* fclause = (Clause *) mln_->getClause(i);
02346
02347 if (fclause->isHardClause()) continue;
02348
02349 long double wt = abs(fclause->getWt());
02350 long double numGndings = fclause->getNumGroundings(domain_);
02351 sumSoftWts += wt*numGndings;
02352 }
02353 assert(sumSoftWts >= 0);
02354
02355 hardWt_ = sumSoftWts + 10.0;
02356 cout << "Set hard weight to " << hardWt_ << endl;
02357 }
02358
02366 void replaceAtomIndexInAllClauses(const int& oldIdx, const int& newIdx)
02367 {
02368 Array<int> gndClauseIndexes;
02369
02370 gndClauseIndexes = getNegOccurenceArray(oldIdx + 1);
02371 for (int i = 0; i < gndClauseIndexes.size(); i++)
02372 {
02373
02374 if ((*gndClauses_)[gndClauseIndexes[i]])
02375 (*gndClauses_)[gndClauseIndexes[i]]->changeGndPredIndex(-(oldIdx + 1),
02376 -(newIdx + 1));
02377 }
02378
02379 gndClauseIndexes = getPosOccurenceArray(oldIdx + 1);
02380 for (int i = 0; i < gndClauseIndexes.size(); i++)
02381 {
02382
02383 if ((*gndClauses_)[gndClauseIndexes[i]])
02384 (*gndClauses_)[gndClauseIndexes[i]]->changeGndPredIndex(oldIdx + 1,
02385 newIdx + 1);
02386 }
02387 }
02388
02389 private:
02390
02391
// Lazy-mode flag, taken from the constructor argument; the lazy-only methods
// assert it.
02392 bool lazy_;
02393
02394
// Weight for hard clauses as computed by setHardClauseWeight(): total
// |weight| * groundings of the soft clauses plus 10.
02395 long double hardWt_;
02396
02397
02398
// MLN and domain this state was constructed with.
02399 MLN* mln_;
02400 Domain* domain_;
02401
02402
02403
// Ground predicates and ground clauses of the state; both arrays are
// allocated in the constructor and indexed from 0.
02404 Array<GroundPredicate*>* gndPreds_;
02405 Array<GroundClause*>* gndClauses_;
02406
02407
02408
// NOTE(review): presumably the unknown query predicates handed to the
// constructor - not visible in this part of the file, confirm.
02409 GroundPredicateHashArray* unePreds_;
02410
02411
02412
// NOTE(review): presumably the known query predicates handed to the
// constructor - confirm.
02413 GroundPredicateHashArray* knePreds_;
02414
// NOTE(review): presumably the truth values belonging to knePreds_ - confirm.
02415 Array<TruthValue>* knePredValues_;
02416
02417
02419
// Initialised to 0 by the constructor.
02420 int baseNumAtoms_;
02421
// Set to false by the constructor in lazy mode.
02422 bool noApprox_;
02423
// Set to false by the constructor in lazy mode.
02424 bool haveDeactivated_;
02425
// NOTE(review): use not visible in this part of the file.
02426 int memLimit_;
02427
// Set to INT_MAX by the constructor in lazy mode.
02428 int clauseLimit_;
02430
02431
02433
// NOTE(review): use not visible in this part of the file.
02434 MRF* mrf_;
02436
02437
// Buffer that getActiveClauses() results are appended to (constructor,
// setAsQuery()).
02438 Array<GroundClause*> newClauses_;
02439
// NOTE(review): use not visible in this part of the file.
02440 Array<GroundPredicate*> newPreds_;
02441
02442
// Ground predicates by hash; a predicate at index i is addressed as atom
// i + 1 in the per-atom arrays and as literal +/-(i + 1) in clauses.
02443 GroundPredicateHashArray gndPredHashArray_;
02444
02445
02446
// clause_[c][i] is the i-th literal of clause c: a signed atom index whose
// sign gives the polarity.
02447 Array<Array<int> > clause_;
02448
// Cost per clause; the sign separates positive from negative clauses.
02449 Array<long double> clauseCost_;
02450
// NOTE(review): use not visible in this part of the file.
02451 long double highestCost_;
02452
// NOTE(review): use not visible in this part of the file.
02453 bool eqHighest_;
02454
// NOTE(review): use not visible in this part of the file.
02455 int numHighest_;
02456
// NOTE(review): use not visible in this part of the file.
02457 Array<int> falseClause_;
02458
// NOTE(review): use not visible in this part of the file.
02459 Array<int> whereFalse_;
02460
// Number of true literals per clause (read by clauseGoodInPrevious()).
02461 Array<int> numTrueLits_;
02462
// NOTE(review): use not visible in this part of the file.
02463 Array<int> watch1_;
02464
// NOTE(review): use not visible in this part of the file.
02465 Array<int> watch2_;
02466
// Per clause: set by simplifyClauseFromFixedAtoms(), cleared by
// resetFixedAtoms().
02467 Array<bool> isSatisfied_;
02468
// Per clause: true if the clause has been marked dead (eliminateSoftClauses(),
// killClauses()); cleared by resetDeadClauses().
02469 Array<bool> deadClause_;
02470
// Flag accessed through setUseThreshold() / getUseThreshold().
02471 bool useThreshold_;
02472
// Per clause: value a random() draw is compared against in killClauses().
02473 Array<long double> threshold_;
02474
02475
02476
// NOTE(review): presumably the clause occurrence lists behind
// getNegOccurenceArray() / getPosOccurenceArray() - confirm.
02477 Array<Array<int> > occurence_;
02478
02479
// NOTE(review): presumably the current truth value per atom - confirm.
02480 Array<bool> atom_;
02481
// NOTE(review): use not visible in this part of the file.
02482 Array<long double> makeCost_;
02483
// NOTE(review): use not visible in this part of the file.
02484 Array<long double> breakCost_;
02485
02486
// Per atom (1-based): 1 = fixed to true, -1 = fixed to false, 0 = not fixed.
02487 Array<int> fixedAtom_;
02488
02489
// Per atom (1-based): truth value in the low state.
02490 Array<bool> lowAtom_;
02491
// Initialised to LDBL_MAX by the constructor.
02492 long double lowCost_;
02493
// Initialised to INT_MAX by the constructor.
02494 int lowBad_;
02495
02496
// Initialised to 0 by the constructor.
02497 int numFalseClauses_;
02498
// Initialised to 0.0 by the constructor.
02499 long double costOfFalseClauses_;
02500
02501
02502
// One array per block holding gndPredHashArray_ indices of the block's atoms;
// allocated by initLazyBlocks().
02503 Array<Array<int> >* blocks_;
02504
02505
// Per block flag, copied from the domain's block-evidence array by
// initLazyBlocks().
02506 Array<bool >* blockEvidence_;
02507
02508
// Counter returned by getNumActiveAtoms(); initialised to 0 by the
// constructor.
02509 int activeAtoms_;
02510
02511 };
02512
02513 #endif