00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include "clause.h"
00068 #include "mrf.h"
00069 #include "variable.h"
00070 #include "superclause.h"
00071
00072 ClauseSampler* Clause::clauseSampler_ = NULL;
00073 double Clause::fixedSizeB_ = -1;
00074 double AuxClauseData::fixedSizeB_ = -1;
00075
00076
00077
00078
00079
00080
00081 inline bool isRepresentativePartialTuple(Array<int>* const & constants,
00082 int & implicitIndex,
00083 Array<Variable*>* const & eqVars,
00084 int varId)
00085 {
00086 IntHashArray * seenConstants = new IntHashArray();
00087 Variable *var = (*eqVars)[-varId];
00088 for (int i = 0; i < constants->size(); i++)
00089 {
00090 int constantId = (*constants)[i];
00091 if(((*eqVars)[i] == var) && (constantId >= 0) &&
00092 var->isImplicit(constantId))
00093 seenConstants->append(constantId);
00094 }
00095
00096
00097
00098 bool representative = (implicitIndex <= seenConstants->size());
00099 delete seenConstants;
00100 return representative;
00101 }
00102
00103 template <typename Type>
00104 inline Array<Type>* getCanonicalArray(Array<Type>* const & arr,
00105 Array<int> * const & varIdToCanonicalVarId)
00106 {
00107 Array<Type> * canonicalArr = new Array<Type>();
00108 int canonicalVarId;
00109 Type val;
00110 for (int varId = 0; varId < arr->size(); varId++)
00111 {
00112 val = (*arr)[varId];
00113 canonicalVarId = (*varIdToCanonicalVarId)[varId];
00114 if (canonicalVarId < 0)
00115 continue;
00116 if (canonicalArr->size() < canonicalVarId + 1)
00117 canonicalArr->growToSize(canonicalVarId + 1);
00118 (*canonicalArr)[canonicalVarId] = val;
00119 }
00120 return canonicalArr;
00121 }
00122
00123
00124
00125 bool Clause::createAndAddUnknownClause(
00126 Array<GroundClause*>* const& unknownGndClauses,
00127 Array<Clause*>* const& unknownClauses,
00128 double* const & numUnknownClauses,
00129 const AddGroundClauseStruct* const & agcs,
00130 const Database* const & db)
00131 {
00132 PredicateSet predSet;
00133 PredicateSet::iterator iter;
00134
00135 Clause* clause = NULL;
00136 for (int i = 0; i < predicates_->size(); i++)
00137 {
00138 Predicate* predicate = (*predicates_)[i];
00139 assert(predicate->isGrounded());
00140 if (db->getValue(predicate) == UNKNOWN)
00141 {
00142 if ( (iter=predSet.find(predicate)) != predSet.end() )
00143 {
00144
00145 if ((*iter)->getSense() != predicate->getSense())
00146 {
00147 if (clause) delete clause;
00148 return true;
00149 }
00150
00151 continue;
00152 }
00153 else
00154 predSet.insert(predicate);
00155
00156 if (clause == NULL) clause = new Clause();
00157 Predicate* pred = new Predicate(*predicate, clause);
00158 clause->appendPredicate(pred);
00159 }
00160 }
00161
00162 if (clause)
00163 {
00164 if (numUnknownClauses) (*numUnknownClauses)++;
00165
00166 clause->setWt(wt_);
00167 clause->canonicalizeWithoutVariables();
00168
00169 if (agcs)
00170 {
00171 if (clausedebug >= 2)
00172 {
00173 cout << "Appending unknown clause to MRF ";
00174 clause->print(cout, db->getDomain());
00175 cout << endl;
00176 }
00177 MRF::addUnknownGndClause(agcs, this, clause, isHardClause_);
00178 }
00179
00180
00181 if (unknownGndClauses)
00182 {
00183 if (clausedebug >= 2)
00184 {
00185 cout << "Appending unknown ground clause ";
00186 clause->print(cout, db->getDomain());
00187 cout << endl;
00188 }
00189 unknownGndClauses->append(new GroundClause(clause, agcs->gndPreds));
00190 if (isHardClause_) unknownGndClauses->lastItem()->setWtToHardWt();
00191 }
00192 else if (unknownClauses)
00193 {
00194 if (clausedebug >= 2)
00195 {
00196 cout << "Appending unknown clause ";
00197 clause->print(cout, db->getDomain());
00198 cout << endl;
00199 }
00200 unknownClauses->append(clause);
00201 if (isHardClause_) clause->setIsHardClause(true);
00202 }
00203 if (unknownClauses == NULL) delete clause;
00204 }
00205 return false;
00206 }
00207
00208 void addPredicateToHash(const Clause* const & c,
00209 PredicateHashArray* const & predHashArray)
00210 {
00211 int numPreds = c->getNumPredicates();
00212
00213 for (int i = 0; i < numPreds; i++)
00214 {
00215 Predicate* pred = new Predicate(*(c->getPredicate(i)));
00216 int index = predHashArray->find(pred);
00217 if(index < 0 )
00218 {
00219 index = predHashArray->append(pred) + 1;
00220 }
00221 else
00222 {
00223 delete pred;
00224 index++;
00225 }
00226 }
00227 }
00228
00229
00241 bool Clause::createAndAddActiveClause(
00242 Array<GroundClause *> * const & activeGroundClauses,
00243 GroundPredicateHashArray* const& seenGndPreds,
00244 const Database* const & db,
00245 bool const & getSatisfied)
00246 {
00247 bool accumulateClauses = activeGroundClauses;
00248 Predicate *cpred;
00249 PredicateSet predSet;
00250 PredicateSet::iterator iter;
00251
00252 GroundClause *groundClause;
00253
00254 Clause* clause = NULL;
00255 bool isEmpty = true;
00256 for (int i = 0; i < predicates_->size(); i++)
00257 {
00258 Predicate* predicate = (*predicates_)[i];
00259 assert(predicate);
00260 assert(predicate->isGrounded());
00261 if ( (iter = predSet.find(predicate)) != predSet.end() )
00262 {
00263
00264
00265 if (wt_ >= 0 && !getSatisfied &&
00266 (*iter)->getSense() != predicate->getSense())
00267 {
00268 if (clause) delete clause;
00269 return false;
00270 }
00271
00272
00273 continue;
00274 }
00275 else
00276 predSet.insert(predicate);
00277
00278 bool isEvidence = db->getEvidenceStatus(predicate);
00279
00280 if (clausedebug >= 2)
00281 {
00282 cout << "isEvidence " << isEvidence << endl;
00283 predicate->printWithStrVar(cout, db->getDomain());
00284 cout << endl;
00285 }
00286 if (!isEvidence)
00287 isEmpty = false;
00288
00289
00290 if (wt_ < 0 && isEvidence && !getSatisfied &&
00291 db->sameTruthValueAndSense(db->getValue(predicate),
00292 predicate->getSense()))
00293 {
00294 if (clause) delete clause;
00295 return false;
00296 }
00297
00298
00299 if (accumulateClauses && !isEvidence)
00300 {
00301 if (!clause) clause = new Clause();
00302
00303 cpred = new Predicate(*predicate, clause);
00304 assert(cpred);
00305 if (clausedebug >= 2)
00306 {
00307 cout << "Appending pred ";
00308 predicate->printWithStrVar(cout, db->getDomain());
00309 cout << " to clause ";
00310 clause->print(cout, db->getDomain());
00311 cout << endl;
00312 }
00313 clause->appendPredicate(cpred);
00314 if (clausedebug >= 2) cout << "Appended pred to clause" << endl;
00315 }
00316 }
00317
00318
00319
00320 if (isEmpty)
00321 {
00322 if (clausedebug >= 2) cout << "Clause is empty" << endl;
00323 assert(!clause);
00324 return false;
00325 }
00326
00327 else
00328 {
00329
00330 if (accumulateClauses)
00331 {
00332 assert(clause);
00333 if (clausedebug >= 2) cout << "Canonicalizing clause" << endl;
00334 clause->canonicalizeWithoutVariables();
00335
00336 groundClause = new GroundClause(clause, seenGndPreds);
00337 if (isHardClause_)
00338 groundClause->setWtToHardWt();
00339 if (clausedebug >= 2) cout << "Appending ground clause to active set" << endl;
00340 activeGroundClauses->appendUnique(groundClause);
00341 delete clause;
00342 }
00343 return true;
00344 }
00345 }
00346
00347
00348 double Clause::getConstantTuples(const Domain* const & domain,
00349 const Database* const & db,
00350 Array<int>* const & mlnClauseTermIds,
00351 const Clause* const & varClause,
00352 PredicateTermToVariable * const & ptermToVar,
00353 ClauseToSuperClauseMap* const & clauseToSuperClause,
00354 bool useImplicit)
00355 {
00356 bool ignoreActivePreds = true;
00357 double numTrueGndings = 0;
00358
00359
00360
00361 Array<int> *constants = new Array<int>(*mlnClauseTermIds);
00362
00363
00364
00365
00366 cout<<"***************************************************************"<<endl;
00367 cout<<"Came to process the clause : "<<endl;
00368 print(cout, domain);
00369 cout<<endl;
00370
00371 createVarIdToVarsGroundedType(domain);
00372
00373 Array<Predicate*>* origClauseLits = new Array<Predicate*>(*predicates_);
00374
00375
00376
00377 Array<Array<Predicate*>* > partGroundedClauses;
00378
00379
00380
00381
00382 PredicateTermToVariable::iterator itr;
00383 PredicateTerm *pterm;
00384 Predicate *pred;
00385 const Term *term;
00386 Array<Variable *> *eqVars = new Array<Variable *>();
00387 eqVars->growToSize(mlnClauseTermIds->size());
00388
00389
00390 for (int i = 0; i < eqVars->size(); i++)
00391 (*eqVars)[i] = NULL;
00392
00393 if (useImplicit)
00394 {
00395 for (int predno = 0; predno < varClause->getNumPredicates(); predno++)
00396 {
00397 pred = varClause->getPredicate(predno);
00398 int predId = pred->getId();
00399 for (int termno = 0; termno < pred->getNumTerms(); termno++)
00400 {
00401 term = pred->getTerm(termno);
00402 assert(!term->isConstant());
00403 int termId = term->getId();
00404 pterm = new PredicateTerm(predId, termno);
00405 itr = ptermToVar->find(pterm);
00406 assert(itr != ptermToVar->end());
00407 assert(-termId < eqVars->size());
00408 (*eqVars)[-termId] = itr->second;
00409 delete pterm;
00410 }
00411 }
00412 }
00413
00414 if (useInverseIndex)
00415 {
00416
00417 sortLiteralsByNegationAndArity(*origClauseLits, ignoreActivePreds, db);
00418 groundIndexableLiterals(domain, db, *origClauseLits, partGroundedClauses,
00419 ignoreActivePreds);
00420 }
00421 else
00422 {
00423
00424
00425
00426
00427
00428
00429 sortLiteralsByTrueDivTotalGroundings(*origClauseLits, domain, db);
00430
00431 Array<Predicate*>* clauseLitsCopy = new Array<Predicate*>;
00432 clauseLitsCopy->growToSize(origClauseLits->size());
00433 for (int i = 0; i < origClauseLits->size(); i++)
00434 (*clauseLitsCopy)[i] = new Predicate(*(*origClauseLits)[i]);
00435 partGroundedClauses.append(clauseLitsCopy);
00436 }
00437
00438
00439
00440
00441 if (clausedebug)
00442 {
00443 cout << "Partially grounded clauses to be completed: " << endl;
00444 for (int pgcIdx = 0; pgcIdx < partGroundedClauses.size(); pgcIdx++)
00445 {
00446 cout << "\t";
00447 for (int i = 0; i < partGroundedClauses[pgcIdx]->size(); i++)
00448 {
00449 (*partGroundedClauses[pgcIdx])[i]->printWithStrVar(cout, domain);
00450 cout << " ";
00451 }
00452 cout << endl;
00453 }
00454 }
00455
00456 bool skip;
00457
00458
00459
00460 for (int pgcIdx = 0; pgcIdx < partGroundedClauses.size(); pgcIdx++)
00461 {
00462
00463 constants->copyFrom(*mlnClauseTermIds);
00464
00465 skip = false;
00466
00467 Array<Predicate*> clauseLits = *(partGroundedClauses[pgcIdx]);
00468 assert(clauseLits.size() == origClauseLits->size());
00469
00470 Array<int>* origVarIds = new Array<int>;
00471
00472 for (int i = 0; i < clauseLits.size(); i++)
00473 {
00474 assert(clauseLits[i]->getNumTerms() ==
00475 (*origClauseLits)[i]->getNumTerms());
00476
00477 for (int j = 0; j < (*origClauseLits)[i]->getNumTerms(); j++)
00478 {
00479 const Term* oldTerm = (*origClauseLits)[i]->getTerm(j);
00480 const Term* newTerm = clauseLits[i]->getTerm(j);
00481
00482 if (oldTerm->getType() == Term::VARIABLE)
00483 {
00484 int varId = oldTerm->getId();
00485 origVarIds->append(varId);
00486 if (newTerm->getType() == Term::CONSTANT)
00487 {
00488 int constId = newTerm->getId();
00489 assert(constId >= 0);
00490 Array<Term*>& vars = (*varIdToVarsGroundedType_)[-varId]->vars;
00491 assert(constants->size() >= (-varId+1));
00492
00493 if (useImplicit)
00494 {
00495 int implicitIndex =
00496 (*eqVars)[-varId]->getImplicitIndex(constId);
00497 if (implicitIndex < 0)
00498 {
00499 (*constants)[-varId] = constId;
00500 }
00501 else
00502 {
00503 if (isRepresentativePartialTuple(constants, implicitIndex,
00504 eqVars, varId))
00505 {
00506 (*constants)[-varId] = constId;
00507 }
00508 else
00509 {
00510 skip = true;
00511 }
00512 }
00513 }
00514 else
00515 {
00516 (*constants)[-varId] = constId;
00517 }
00518
00519 for (int k = 0; k < vars.size(); k++) vars[k]->setId(constId);
00520 }
00521 }
00522 }
00523
00524 delete clauseLits[i];
00525 clauseLits[i] = (*origClauseLits)[i];
00526 }
00527
00528 if (!skip)
00529 {
00530
00531
00532 Array<LitIdxVarIdsGndings*> ivgArr;
00533 createAllLitIdxVarsGndings(clauseLits, ivgArr, domain, true);
00534 int ivgArrIdx = 0;
00535 bool lookAtNextLit = false;
00536
00537
00538 while (ivgArrIdx >= 0)
00539 {
00540
00541 LitIdxVarIdsGndings* ivg = ivgArr[ivgArrIdx];
00542 Predicate* lit = (*origClauseLits)[ivg->litIdx];
00543
00544 Array<int>& varIds = ivg->varIds;
00545 ArraysAccessor<int>& varGndings = ivg->varGndings;
00546 bool& litUnseen = ivg->litUnseen;
00547 bool hasComb;
00548
00549 if (clausedebug)
00550 {
00551 cout << "Looking at lit: ";
00552 lit->printWithStrVar(cout, domain);
00553 cout << endl;
00554 }
00555
00556 bool gotoNextComb = false;
00557
00558 while ((hasComb = varGndings.hasNextCombination()) || litUnseen)
00559 {
00560
00561 if (litUnseen) litUnseen = false;
00562
00563 if (hasComb)
00564 {
00565
00566 for (int v = 0; v < varIds.size(); v++)
00567 {
00568 (*constants)[-varIds[v]] = varIds[v];
00569 }
00570
00571 int constId;
00572 int v = 0;
00573
00574 gotoNextComb = false;
00575 while (varGndings.nextItemInCombination(constId))
00576 {
00577 int varId = varIds[v];
00578 Array<Term*>& vars = (*varIdToVarsGroundedType_)[-varId]->vars;
00579
00580
00581 assert(constants->size() >= (-varId+1));
00582
00583 if (useImplicit)
00584 {
00585 int implicitIndex =
00586 (*eqVars)[-varId]->getImplicitIndex(constId);
00587 if (implicitIndex < 0)
00588 {
00589 (*constants)[-varId] = constId;
00590 }
00591 else
00592 {
00593
00594
00595
00596
00597
00598 if (isRepresentativePartialTuple(constants, implicitIndex,
00599 eqVars, varId))
00600 {
00601 (*constants)[-varId] = constId;
00602 }
00603 else
00604 {
00605 gotoNextComb = true;
00606 }
00607 }
00608 }
00609 else
00610 {
00611 (*constants)[-varId] = constId;
00612 }
00613 v++;
00614 for (int i = 0; i < vars.size(); i++) vars[i]->setId(constId);
00615 }
00616
00617
00618 assert(varIds.size() == v);
00619 }
00620
00621 removeRedundantPredicates();
00622
00623 if (clausedebug)
00624 {
00625 cout << "Clause is now: ";
00626 printWithWtAndStrVar(cout, domain);
00627 cout << endl;
00628 }
00629
00630 if (gotoNextComb)
00631 continue;
00632
00633 if (literalOrSubsequentLiteralsAreTrue(lit, ivg->subseqGndLits, db))
00634 {
00635 if (clausedebug)
00636 cout << "Clause satisfied" << endl;
00637
00638
00639 double numComb = 1;
00640 for (int i = ivgArrIdx + 1; i < ivgArr.size(); i++)
00641 {
00642 int numVar = ivgArr[i]->varGndings.getNumArrays();
00643 for (int j = 0; j < numVar; j++)
00644 numComb *= ivgArr[i]->varGndings.getArray(j)->size();
00645 }
00646 numTrueGndings += numComb;
00647 }
00648 else
00649 {
00650
00651 if (ivgArrIdx + 1 < ivgArr.size())
00652 {
00653 if (clausedebug) cout << "Moving to next literal" << endl;
00654 lookAtNextLit = true;
00655 ivgArrIdx++;
00656 break;
00657 }
00658
00659
00660
00661 if (hasTwoLiteralsWithOppSense(db))
00662 {
00663 ++numTrueGndings;
00664 }
00665 else
00666 {
00667
00668 addConstantTuple(domain, db, varClause, constants, eqVars,
00669 clauseToSuperClause, useImplicit);
00670 }
00671 }
00672 }
00673
00674
00675
00676 if (lookAtNextLit) { lookAtNextLit = false; }
00677
00678 else
00679 {
00680 varGndings.reset();
00681 litUnseen = true;
00682 ivgArrIdx--;
00683
00684 for (int v = 0; v < varIds.size(); v++)
00685 {
00686 (*constants)[-varIds[v]] = varIds[v];
00687 }
00688 }
00689 }
00690 deleteAllLitIdxVarsGndings(ivgArr);
00691 }
00692
00693
00694 for (int i = 0; i < origVarIds->size(); i++)
00695 {
00696 int varId = (*origVarIds)[i];
00697 assert(varId < 0);
00698 Array<Term*>& vars = (*varIdToVarsGroundedType_)[-varId]->vars;
00699 for (int j = 0; j < vars.size(); j++) vars[j]->setId(varId);
00700 (*varIdToVarsGroundedType_)[-varId]->isGrounded = false;
00701 }
00702
00703 delete origVarIds;
00704 delete partGroundedClauses[pgcIdx];
00705 }
00706 delete origClauseLits;
00707 delete constants;
00708 return numTrueGndings;
00709 }
00710
00711
00712
00713 inline void Clause::addConstantTuple(const Domain* const & domain,
00714 const Database* const & db,
00715 const Clause * const & varClause,
00716 Array<int> * const & constants,
00717 Array<Variable *> * const & eqVars,
00718 ClauseToSuperClauseMap* const & clauseToSuperClause,
00719 bool useImplicit)
00720 {
00721
00722 PredicateSet predSet;
00723
00724 PredicateSet::iterator iter;
00725 Predicate *predicate;
00726
00727 Clause *clause = new Clause();
00728
00729
00730
00731 for (int i = 0; i < predicates_->size(); i++)
00732 {
00733 predicate = (*predicates_)[i];
00734 assert(predicate->isGrounded());
00735 if (db->getValue(predicate) == UNKNOWN)
00736 {
00737 clause->appendPredicate(varClause->getPredicate(i));
00738 if ( (iter=predSet.find(predicate)) != predSet.end() )
00739 {
00740
00741 if ((*iter)->getSense() != predicate->getSense())
00742 {
00743 clause->removeAllPredicates();
00744 delete clause;
00745 return;
00746 }
00747
00748 continue;
00749 }
00750 else
00751 predSet.insert(predicate);
00752 }
00753 }
00754
00755 SuperClause *superClause;
00756 ClauseToSuperClauseMap::iterator itr;
00757 Clause *keyClause;
00758 Array<int> * varIdToCanonicalVarId;
00759 Array<int> * canonicalConstants;
00760 Array<Variable *> * canonicalEqVars;
00761
00762
00763 if (clause->getNumPredicates() == 0)
00764 {
00765 delete clause;
00766 return;
00767 }
00768
00769
00770
00771 itr = clauseToSuperClause->find(clause);
00772 if (itr == clauseToSuperClause->end())
00773 {
00774
00775
00776 keyClause = new Clause(*clause);
00777 keyClause->setWt(1);
00778
00779
00780 int varCnt = constants->size();
00781 varIdToCanonicalVarId = new Array<int>(varCnt, -1);
00782 keyClause->canonicalize(varIdToCanonicalVarId);
00783 canonicalEqVars = getCanonicalArray(eqVars, varIdToCanonicalVarId);
00784 superClause = new SuperClause(keyClause, canonicalEqVars,
00785 varIdToCanonicalVarId, useImplicit, wt_);
00786 (*clauseToSuperClause)[clause] = superClause;
00787 delete canonicalEqVars;
00788 delete varIdToCanonicalVarId;
00789 }
00790 else
00791 {
00792 superClause = itr->second;
00793
00794
00795
00796
00797
00798 int varCnt = constants->size();
00799 varIdToCanonicalVarId = superClause->getVarIdToCanonicalVarId();
00800 varIdToCanonicalVarId->growToSize(varCnt,-1);
00801
00802
00803 clause->removeAllPredicates();
00804 delete clause;
00805 }
00806
00807 varIdToCanonicalVarId = superClause->getVarIdToCanonicalVarId();
00808 canonicalConstants = getCanonicalArray(constants, varIdToCanonicalVarId);
00809 superClause->addNewConstantsAndIncrementCount(canonicalConstants,
00810
00811 this->getWt());
00812
00813
00814 delete canonicalConstants;
00815 }
00816
00817
00818
00819
00820
00821 Array<int> * Clause::updateToVarClause()
00822 {
00823 Array<int> * termIds = new Array<int>();
00824
00825 termIds->append(0);
00826
00827 const Predicate *pred;
00828
00829
00830 for (int i = 0; i < predicates_->size(); i++)
00831 {
00832 pred = (*predicates_)[i];
00833 for (int j = 0; j < pred->getNumTerms(); j++)
00834 {
00835 const Term* t = pred->getTerm(j);
00836 if (t->getType() == Term::VARIABLE)
00837 {
00838 int id = t->getId();
00839 assert(id < 0);
00840 termIds->growToSize(-id+1);
00841 (*termIds)[-id] = id;
00842 }
00843 }
00844 }
00845
00846
00847
00848 for (int i = 0; i < predicates_->size(); i++)
00849 {
00850 pred = (*predicates_)[i];
00851 for (int j = 0; j < pred->getNumTerms(); j++)
00852 {
00853 Term* t = (Term *) pred->getTerm(j);
00854
00855 if (t->getType() == Term::VARIABLE)
00856 continue;
00857 int constantId = t->getId();
00858 assert(constantId >= 0);
00859 termIds->append(constantId);
00860
00861
00862 int varId = termIds->size()-1;
00863 t->setId(-varId);
00864 }
00865 }
00866 return termIds;
00867 }
00868