00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef STRUCTLEARN_H_NOV_2_2005
00068 #define STRUCTLEARN_H_NOV_2_2005
00069
00070 #include "clausefactory.h"
00071 #include "mln.h"
00072
00073 #include "lbfgsb.h"
00074 #include "discriminativelearner.h"
00075
00076
00078 struct ExistFormula
00079 {
00080 ExistFormula(const string& fformula)
00081 : formula(fformula), gain(0), wt(0), newScore(0), numPreds(0) {}
00082 ~ExistFormula()
00083 {
00084 for (int i = 0; i < cnfClausesForDomains.size(); i++)
00085 cnfClausesForDomains[i].deleteItemsAndClear();
00086 }
00087
00088 string formula;
00089
00090
00091 Array<Array<Clause*> > cnfClausesForDomains;
00092 double gain;
00093 double wt;
00094 Array<double> wts;
00095 double newScore;
00096 int numPreds;
00097 };
00098
00099 struct IndexCountDomainIdx { IndexAndCount* iac; int domainIdx; };
00100 struct ClauseAndICDArray {Clause* clause;Array<IndexCountDomainIdx> icdArray;};
00101
00102
00103 class StructLearn
00104 {
00105 public:
00106
00107
00108 StructLearn(Array<MLN*>* const & mlns, const bool& startFromEmptyMLN,
00109 const string& outMLNFileName, Array<Domain*>* const & domains,
00110 const Array<string>* const & nonEvidPredNames,
00111 const int& maxVars, const int& maxNumPredicates,
00112 const bool& cacheClauses, const double& maxCacheSizeMB,
00113 const bool& tryAllFlips, const bool& sampleClauses,
00114 const double& delta, const double& epsilon,
00115 const int& minClauseSamples, const int& maxClauseSamples,
00116 const bool& hasPrior, const double& priorMean,
00117 const double& priorStdDev, const bool& wtPredsEqually,
00118 const int& lbMaxIter, const double& lbConvThresh,
00119 const int& looseMaxIter, const double& looseConvThresh,
00120 const int& beamSize, const int& bestGainUnchangedLimit,
00121 const int& numEstBestClauses,
00122 const double& minWt, const double& penalty,
00123 const bool& sampleGndPreds, const double& fraction,
00124 const int& minGndPredSamples, const int& maxGndPredSamples,
00125 const bool& reEvaluateBestCandidatesWithTightParams,
00126 const bool& structGradDescent, const bool& withEM)
00127 : mln0_((*mlns)[0]), mlns_(mlns), startFromEmptyMLN_(startFromEmptyMLN),
00128 outMLNFileName_(outMLNFileName), domains_(domains),
00129 preds_(new Array<Predicate*>), areNonEvidPreds_(new Array<bool>),
00130 clauseFactory_(new ClauseFactory(maxVars, maxNumPredicates,
00131 (*domains_)[0])),
00132 cacheClauses_(cacheClauses), origCacheClauses_(cacheClauses),
00133 cachedClauses_((cacheClauses) ? (new ClauseHashArray) : NULL),
00134 cacheSizeMB_(0), maxCacheSizeMB_(maxCacheSizeMB),
00135 tryAllFlips_(tryAllFlips),
00136 sampleClauses_(sampleClauses), origSampleClauses_(sampleClauses),
00137 pll_(NULL), hasPrior_(hasPrior), priorMean_(priorMean),
00138 priorStdDev_(priorStdDev), wtPredsEqually_(wtPredsEqually),
00139 lbfgsb_(NULL), lbMaxIter_(lbMaxIter), lbConvThresh_(lbConvThresh),
00140 looseMaxIter_(looseMaxIter), looseConvThresh_(looseConvThresh),
00141 beamSize_(beamSize), bestGainUnchangedLimit_(bestGainUnchangedLimit),
00142 numEstBestClauses_(numEstBestClauses),
00143 minGain_(0), minWt_(minWt), penalty_(penalty),
00144 sampleGndPreds_(sampleGndPreds), fraction_(fraction),
00145 minGndPredSamples_(minGndPredSamples),
00146 maxGndPredSamples_(maxGndPredSamples),
00147 reEvalBestCandsWithTightParams_(reEvaluateBestCandidatesWithTightParams),
00148 candCnt_(0), iter_(-1), bsiter_(-1), startSec_(-1), indexTrans_(NULL),
00149 structGradDescent_(structGradDescent), withEM_(withEM)
00150 {
00151 assert(minWt_ >= 0);
00152 assert(domains_->size() == mlns_->size());
00153
00154 areNonEvidPreds_->growToSize((*domains_)[0]->getNumPredicates(), false);
00155 for (int i = 0; i < nonEvidPredNames->size(); i++)
00156 {
00157 int predId=(*domains_)[0]->getPredicateId((*nonEvidPredNames)[i].c_str());
00158 if (predId < 0)
00159 {
00160 cout << "ERROR: in StructLearn::StructLearn(). Predicate "
00161 << (*nonEvidPredNames)[i] << " undefined." << endl;
00162 exit(-1);
00163 }
00164 (*areNonEvidPreds_)[predId] = true;
00165 }
00166
00167 (*domains_)[0]->createPredicates(preds_, true);
00168
00169 if (origSampleClauses_)
00170 {
00171 ClauseSampler* cs = new ClauseSampler(delta, epsilon, minClauseSamples,
00172 maxClauseSamples);
00173 Clause::setClauseSampler(cs);
00174 for (int i = 0; i < domains_->size(); i++)
00175 (*domains_)[i]->newTrueFalseGroundingsStore();
00176 }
00177 }
00178
00179
00180 ~StructLearn()
00181 {
00182 if (pll_) delete pll_;
00183 if (lbfgsb_) delete lbfgsb_;
00184 preds_->deleteItemsAndClear();
00185 delete preds_;
00186 delete areNonEvidPreds_;
00187 delete clauseFactory_;
00188 if (cachedClauses_)
00189 {
00190 cachedClauses_->deleteItemsAndClear();
00191 delete cachedClauses_;
00192 }
00193 if (origSampleClauses_) delete Clause::getClauseSampler();
00194 if (indexTrans_) delete indexTrans_;
00195 }
00196
00197 void run()
00198 {
00199 startSec_ = timer_.time();
00200
00201 bool needIndexTrans = IndexTranslator::needIndexTranslator(*mlns_,*domains_);
00202
00203
00204
00205 Array<Clause*> initialMLNClauses;
00206 Array<ExistFormula*> existFormulas;
00207 if (startFromEmptyMLN_)
00208 {
00209 getMLNClauses(initialMLNClauses, existFormulas);
00210 removeClausesFromMLNs();
00211 for (int i = 0; i < initialMLNClauses.size(); i++)
00212 {
00213 Clause* c = initialMLNClauses[i];
00214 c->newAuxClauseData();
00215 c->trackConstants();
00216 c->getAuxClauseData()->gain = 0;
00217 c->getAuxClauseData()->op = OP_ADD;
00218 c->getAuxClauseData()->removedClauseIdx = -1;
00219 c->getAuxClauseData()->hasBeenExpanded = false;
00220 c->getAuxClauseData()->lastStepExpanded = -1;
00221 c->getAuxClauseData()->lastStepOverMinWeight = -1;
00222 }
00223 }
00224
00225
00226 cout << "adding unit clauses to MLN..." << endl << endl;
00227 addUnitClausesToMLNs();
00228
00229
00230 for (int i = 0; i < mln0_->getNumClauses(); i++)
00231 {
00232 Clause* c = (Clause*) mln0_->getClause(i);
00233 c->newAuxClauseData();
00234
00235
00236 if (isModifiableClause(i)) c->trackConstants();
00237 }
00238
00239 indexTrans_ = (needIndexTrans)? new IndexTranslator(mlns_, domains_) : NULL;
00240 if (indexTrans_)
00241 cout << "The weights of clauses in the CNFs of existential formulas wiil "
00242 << "be tied" << endl;
00243
00244
00245 runStructLearning(initialMLNClauses, existFormulas);
00246 }
00247
00248
00249 void runStructLearning(Array<Clause*> initialMLNClauses,
00250 Array<ExistFormula*> existFormulas)
00251 {
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296 pll_ = new PseudoLogLikelihood(areNonEvidPreds_, domains_, wtPredsEqually_,
00297 sampleGndPreds_, fraction_,
00298 minGndPredSamples_, maxGndPredSamples_);
00299 pll_->setIndexTranslator(indexTrans_);
00300
00301 int numClausesFormulas = getNumClausesFormulas();
00302 lbfgsb_ = new LBFGSB(-1, -1, pll_, numClausesFormulas);
00303
00304 useTightParams();
00305
00306
00307 cout << "computing counts for initial MLN clauses..." << endl;
00308 double begSec = timer_.time();
00309 pllComputeCountsForInitialMLN();
00310 cout << "computing counts for initial MLN clauses took ";
00311 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00312
00313
00314
00315 cout << "learning the initial weights and score of MLN..." << endl << endl;
00316 begSec = timer_.time();
00317 double score;
00318 Array<double> wts;
00319 if (!learnAndSetMLNWeights(score)) return;
00320 printMLNClausesWithWeightsAndScore(score, -1);
00321 cout << "learning the initial weights and score of MLN took ";
00322 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00323
00324
00325
00326 cout <<"trying to add unit clause with diff variable combinations to MLN..."
00327 << endl << endl;
00328 begSec = timer_.time();
00329 appendUnitClausesWithDiffCombOfVar(score);
00330 printMLNClausesWithWeightsAndScore(score, -1);
00331 cout <<"adding unit clause with diff variable combinations to MLN took ";
00332 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00333
00334
00335
00336 if (tryAllFlips_)
00337 {
00338 cout << "trying to flip the senses of MLN clauses..." << endl << endl;
00339 begSec = timer_.time();
00340 flipMLNClauses(score);
00341 printMLNClausesWithWeightsAndScore(score, -1);
00342 cout << "trying to flip the senses of MLN clauses took ";;
00343 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00344 }
00345
00346 iter_ = -1;
00347
00348
00349
00350 double bsec;
00351 Array<Clause*> initialClauses;
00352 Array<Clause*> bestCandidates;
00353 while (true)
00354 {
00355 begSec = timer_.time();
00356 iter_++;
00357
00358
00359 cout << "Iteration " << iter_ << endl << endl;
00360
00361 minGain_ = 0;
00362
00363 Array<ExistFormula*> highGainWtExistFormulas;
00364 if (startFromEmptyMLN_ && !existFormulas.empty())
00365 {
00366 useTightParams();
00367 cout << "evaluating the gains of existential formulas..." << endl<<endl;
00368 bsec = timer_.time();
00369
00370 minGain_ = evaluateExistFormulas(existFormulas, highGainWtExistFormulas,
00371 score);
00372 cout << "evaluating the gains of existential formulas took ";
00373 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00374 cout << "setting minGain to min gain among existential formulas. "
00375 << "minGain = " << minGain_ << endl;
00376 }
00377
00378 useLooseParams();
00379
00380
00381 initialClauses.clear();
00382 for (int i = 0; i < mln0_->getNumClauses(); i++)
00383 {
00384 Clause* c = (Clause*) mln0_->getClause(i);
00385 c->getAuxClauseData()->reset();
00386 if (isModifiableClause(i))
00387 {
00388 c->trackConstants();
00389 initialClauses.append(new Clause(*c));
00390 }
00391 }
00392
00393 bestCandidates.clear();
00394 beamSearch(initialClauses, initialMLNClauses, score, bestCandidates);
00395 bool noCand = (startFromEmptyMLN_ && !existFormulas.empty()) ?
00396 (bestCandidates.empty() && highGainWtExistFormulas.empty())
00397 : bestCandidates.empty();
00398 if (noCand)
00399 {
00400 cout << "Beam is empty. Ending search for MLN clauses." << endl;
00401 printIterAndTimeElapsed(begSec);
00402 break;
00403 }
00404
00405 useTightParams();
00406
00407 if (reEvalBestCandsWithTightParams_)
00408 {
00409 cout << "reevaluating top candidates... " << endl << endl;
00410 bsec = timer_.time();
00411 reEvaluateCandidates(bestCandidates, score);
00412 cout << "reevaluating top candidates took ";
00413 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00414 }
00415
00416
00417 bsec = timer_.time();
00418 bool ok;
00419 if (startFromEmptyMLN_ && !existFormulas.empty())
00420 ok = effectBestCandidateOnMLNs(bestCandidates, existFormulas,
00421 highGainWtExistFormulas, score);
00422 else
00423 ok = effectBestCandidateOnMLNs(bestCandidates, score);
00424
00425 cout << "effecting best candidates took ";
00426 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00427
00428 if (!ok)
00429 {
00430 cout << "failed to effect any of the best candidates on MLN" << endl
00431 << "stopping search for MLN clauses..." << endl;
00432 break;
00433 }
00434
00435 printIterAndTimeElapsed(begSec);
00436 }
00437
00438 cout << "done searching for MLN clauses" << endl << endl;
00439 int numIterTaken = iter_+1;
00440 iter_= -1;
00441
00442 useTightParams();
00443
00444
00445 cout << "pruning clauses from MLN..." << endl << endl;
00446 begSec = timer_.time();
00447 pruneMLN(score);
00448 cout << "pruning clauses from MLN took ";
00449 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00450
00451 printMLNClausesWithWeightsAndScore(score, -1);
00452 printMLNToFile(NULL, -2);
00453 cout << "num of iterations taken = " << numIterTaken << endl;
00454
00455 cout << "time taken for structure learning = ";
00456 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
00457
00458 initialMLNClauses.deleteItemsAndClear();
00459 deleteExistFormulas(existFormulas);
00460 }
00461
00462
00463 bool learnAndSetMLNWeights(double& score)
00464 {
00465 Array<double> priorMeans, priorStdDevs;
00466 double tmpScore = score;
00467 pllSetPriorMeansStdDevs(priorMeans, priorStdDevs, 0, NULL);
00468 int numClausesFormulas = getNumClausesFormulas();
00469 Array<double> wts;
00470 wts.growToSize(numClausesFormulas + 1);
00471
00472 int iter; bool error; double elapsedSec;
00473 tmpScore = maximizeScore(numClausesFormulas, 0, &wts, NULL, NULL,
00474 iter, error, elapsedSec);
00475 if (error)
00476 {
00477 cout << "LBFGSB failed to find wts" << endl;
00478 return false;
00479 }
00480 else
00481 {
00482 score = tmpScore;
00483 printNewScore((Clause*)NULL, NULL, iter, elapsedSec, score, 0, 0);
00484 }
00485
00486 updateWts(wts, NULL, NULL);
00487
00488 return true;
00489 }
00490
00492 private:
00493
00494
00495
00496 void beamSearch(const Array<Clause*>& initClauses,
00497 const Array<Clause*>& initMLNClauses, const double& prevScore,
00498 Array<Clause*>& bestClauses)
00499 {
00500 int iterBestGainUnchanged = 0; bsiter_ = -1;
00501 Array<double> priorMeans, priorStdDevs;
00502 ClauseOpHashArray* beam = new ClauseOpHashArray;
00503 for (int i = 0; i < initClauses.size(); i++) beam->append(initClauses[i]);
00504
00505 int numClausesFormulas = getNumClausesFormulas();
00506 Array<double> wts;
00507 wts.growToSize(numClausesFormulas + 2);
00508 Array<double> origWts(numClausesFormulas);
00509 if (indexTrans_) indexTrans_->getClauseFormulaWts(origWts);
00510 else mln0_->getClauseWts(origWts);
00511
00512
00513 setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
00514 if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, 1);
00515
00516 bool error;
00517 double begIterSec, bsec;
00518 Array<Clause*> candidates;
00519 while (!beam->empty() && iterBestGainUnchanged < bestGainUnchangedLimit_)
00520 {
00521 begIterSec = timer_.time();
00522 bsiter_++;
00523
00524
00525 cout << endl << "BEAM SEARCH ITERATION " << bsiter_ << endl << endl;
00526
00527 cout << "creating candidate clauses..." << endl;
00528 candidates.clear();
00529 bsec = timer_.time();
00530
00531 if (bsiter_==0) createCandidateClauses(beam, candidates, &initMLNClauses);
00532 else createCandidateClauses(beam, candidates, NULL);
00533
00534 cout << "num of candidates created = " << candidates.size() << "; took ";
00535 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00536
00537 cout << "evaluating gain of candidates..." << endl << endl;
00538 bsec = timer_.time();
00539 for (int i = 0; i < candidates.size(); i++)
00540 countAndMaxScoreEffectCandidate(candidates[i], &wts, &origWts,prevScore,
00541 true, priorMeans, priorStdDevs, error,
00542 NULL, NULL);
00543 cout << "evaluating gain of candidates took ";
00544 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00545
00546
00547 cout << "finding best candidates..." << endl;
00548 bsec = timer_.time();
00549 ClauseOpHashArray* newBeam = new ClauseOpHashArray(beamSize_);
00550 bool newBestClause = findBestClausesAmongCandidates(candidates, newBeam,
00551 bestClauses);
00552 cout << "finding best candidates took ";
00553 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00554
00555 beam->deleteItemsAndClear();
00556 delete beam;
00557 beam = newBeam;
00558
00559 if (newBestClause)
00560 {
00561 iterBestGainUnchanged = 0;
00562 cout << "found new best clause in beam search iter " << bsiter_ << endl;
00563 }
00564 else
00565 iterBestGainUnchanged++;
00566
00567
00568
00569 cout << "best clauses found in beam search iter " << bsiter_ << endl;
00570 cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << endl;
00571 for (int i = 0; i < bestClauses.size(); i++)
00572 {
00573 cout << i << "\t";
00574 bestClauses[i]->printWithoutWtWithStrVar(cout,(*domains_)[0]);
00575 cout << endl
00576 << "\tgain = " << bestClauses[i]->getAuxClauseData()->gain
00577 << ", op = "
00578 << Clause::getOpAsString(bestClauses[i]->getAuxClauseData()->op);
00579 if (bestClauses[i]->getAuxClauseData()->op != OP_REMOVE)
00580 cout << ", wt = " << bestClauses[i]->getWt();
00581 cout << endl;
00582 }
00583 cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << endl << endl;
00584
00585
00586 cout << "BEAM SEARCH ITERATION " << bsiter_ << " took ";
00587 timer_.printTime(cout, timer_.time()-begIterSec); cout << endl << endl;
00588 cout << "Time elapsed = ";
00589 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
00590 }
00591
00592 if (beam->empty()) cout << "Beam search ended because beam is empty"
00593 << endl << endl;
00594 else cout << "Beam search ended because best clause did not change in "
00595 << iterBestGainUnchanged << " iterations" << endl << endl;
00596
00597 beam->deleteItemsAndClear();
00598 delete beam;
00599 bsiter_ = -1;
00600
00601 if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, 1);
00602 }
00603
00604
00606 private:
00607
00608
00609
00610 double maximizeScore(const int& numClausesFormulas, const int& numExtraWts,
00611 Array<double>* const & wts,
00612 const Array<double>* const & origWts,
00613 const Array<int>* const & removedClauseFormulaIdxs,
00614 int& iter, bool& error, double& elapsedSec)
00615 {
00616 if (origWts)
00617 { for (int i=1 ; i<=numClausesFormulas; i++) (*wts)[i] = (*origWts)[i-1];}
00618 else
00619 { for (int i=1; i<=numClausesFormulas; i++) (*wts)[i] = 0; }
00620
00621 for (int i = 1; i <= numExtraWts; i++) (*wts)[numClausesFormulas+i] = 0;
00622
00623 if (removedClauseFormulaIdxs)
00624 for (int i = 0; i < removedClauseFormulaIdxs->size(); i++)
00625 (*wts)[ (*removedClauseFormulaIdxs)[i]+1 ] = 0;
00626
00627
00628
00629
00630
00631 double* wwts = (double*) wts->getItems();
00632 double begSec = timer_.time();
00633 double newScore
00634 = lbfgsb_->minimize(numClausesFormulas + numExtraWts, wwts, iter, error);
00635 newScore = -newScore;
00636 elapsedSec = timer_.time() - begSec;
00637
00638
00639
00640
00641
00642 return newScore;
00643 }
00644
00645
00646
00647
00648
00649
00650
00651
00652
// Evaluates a set of candidates jointly.
// Applies every candidate's op (OP_ADD / OP_REMOVE / OP_REPLACE /
// OP_REPLACE_ADDPRED / OP_REPLACE_REMPRED) to the pseudo-log-likelihood
// counts of all domains, re-optimizes the weights with maximizeScore(), and
// stores newScore - prevScore - penalty in each candidate's
// AuxClauseData::gain. No clause is appended to or removed from the MLNs here.
//  - wts: grown to numClausesFormulas + #appended + 1 and filled with the
//    optimized weights (entry 0 unused).
//  - origWts: starting weights for the existing clauses/formulas (NULL = 0).
//  - resetPriors: true  -> priors are rebuilt with setPriorMeansStdDevs();
//                 false -> the caller's priors (already sized
//                 numClausesFormulas + candidates.size()) are patched via
//                 setRemoveAppendedPriorMeans() and the removed entries are
//                 restored before returning.
//  - error: set by maximizeScore(). If true, prevScore is used as newScore.
//  - uundoInfos: NULL -> all count changes are undone before returning.
//                non-NULL -> undo records are appended to it and the counts
//                are left in place.
//  - appendedClauseInfos: if non-NULL, receives one heap-allocated
//    ClauseAndICDArray per appended candidate.
// Returns the new score (prevScore if LBFGSB failed).
double countAndMaxScoreEffectAllCandidates(const Array<Clause*>& candidates,
                                           Array<double>* const & wts,
                                           const Array<double>*const& origWts,
                                           const double& prevScore,
                                           const bool& resetPriors,
                                           Array<double>& priorMeans,
                                           Array<double>& priorStdDevs,
                                           bool& error,
                                           Array<UndoInfo*>* const& uundoInfos,
                                           Array<ClauseAndICDArray*>* const & appendedClauseInfos)
{
  // Use the caller's undo array if given, otherwise a private one that is
  // rolled back and freed below.
  Array<UndoInfo*>* undoInfos = (uundoInfos)? uundoInfos:new Array<UndoInfo*>;
  Array<int> remClauseFormulaIdxs;
  Array<int*> idxPtrs;
  Array<Clause*> toBeRemovedClauses;
  Array<Clause*> toBeAppendedClauses;
  int numClausesFormulas = getNumClausesFormulas();

  for (int i = 0; i < candidates.size(); i++)
  {
    Clause* cand = candidates[i];
    AuxClauseData* acd = cand->getAuxClauseData();
    int op = acd->op;

    // Removal half of remove/replace ops: subtract the counts of the MLN
    // clause at removedClauseIdx in every domain.
    if (op == OP_REMOVE || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
        op == OP_REPLACE_REMPRED)
    {
      Clause* remClause = (mlns_->size() > 1) ?
        (Clause*) mln0_->getClause(acd->removedClauseIdx) : NULL;
      Array<int*> idxs; idxs.growToSize(mlns_->size());
      Array<Clause*> remClauses; remClauses.growToSize(mlns_->size());
      for (int d = 0; d < mlns_->size(); d++)
      {
        int remIdx;
        if (d == 0)
        {
          remIdx = acd->removedClauseIdx;
          remClauseFormulaIdxs.append(remIdx);
        }
        else
        {
          // Translate constants into domain d to find the clause in MLN d.
          if (remClause->containsConstants())
            remClause->translateConstants((*domains_)[d-1], (*domains_)[d]);
          remIdx = (*mlns_)[d]->findClauseIdx(remClause);
        }
        idxs[d] = (*mlns_)[d]->getMLNClauseInfoIndexPtr(remIdx);
        remClauses[d] = (Clause*) (*mlns_)[d]->getClause(remIdx);
      }
      // remClause is mln0_'s own clause: restore its domain-0 constants.
      if (mlns_->size() > 1 && remClause->containsConstants())
        remClause->translateConstants((*domains_)[mlns_->size()-1],
                                      (*domains_)[0]);
      pllRemoveCountsForClause(remClauses, idxs, undoInfos);

      if (op == OP_REMOVE) toBeRemovedClauses.append(cand);
    }

    // Addition half of add/replace ops: compute counts as if the candidate
    // were appended after the existing clauses and after the candidates
    // already appended in this call.
    if (op == OP_ADD || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
        op == OP_REPLACE_REMPRED)
    {
      Array<int*> tmpIdxs; tmpIdxs.growToSize(mlns_->size());
      for (int d = 0; d < mlns_->size(); d++)
      {
        // Heap-allocated temporary index; freed via idxPtrs at the end.
        int* tmpClauseIdxInMLN = new int( (*mlns_)[d]->getNumClauses() +
                                          toBeAppendedClauses.size() );
        tmpIdxs[d] = tmpClauseIdxInMLN;
        idxPtrs.append(tmpClauseIdxInMLN);
      }

      // When appendedClauseInfos is requested, collect this candidate's undo
      // records separately so they can be tied to the candidate below.
      Array<UndoInfo*>* tmpInfos = (appendedClauseInfos)?
        new Array<UndoInfo*> : undoInfos;
      pllComputeCountsForClause(cand, tmpIdxs, tmpInfos);
      toBeAppendedClauses.append(cand);

      if (appendedClauseInfos)
      {
        undoInfos->append(*tmpInfos);
        ClauseAndICDArray* ca = new ClauseAndICDArray;
        ca->clause = cand;
        appendedClauseInfos->append(ca);
        // For each undo record, remember the last IndexAndCount it affected
        // and its domain (the caller can later re-point iac->index).
        for (int j = 0; j < tmpInfos->size(); j++)
        {
          ca->icdArray.append(IndexCountDomainIdx());
          ca->icdArray.lastItem().iac=(*tmpInfos)[j]->affectedArr->lastItem();
          ca->icdArray.lastItem().domainIdx = (*tmpInfos)[j]->domainIdx;
        }
        delete tmpInfos;
      }
    }
  }

  // Convert removed MLN clause indices to clause/formula indices
  // (presumably needed because weights are tied through indexTrans_).
  if (indexTrans_)
    indexTrans_->getClauseFormulaIndexes(remClauseFormulaIdxs, 0);

  Array<double> removedValues;
  if (resetPriors)
  {
    setPriorMeansStdDevs(priorMeans, priorStdDevs, toBeAppendedClauses.size(),
                         &remClauseFormulaIdxs);
    if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(
                                               toBeAppendedClauses.size(), 1);
  }
  else
  {
    assert(priorMeans.size() == numClausesFormulas + candidates.size());
    assert(priorStdDevs.size() == numClausesFormulas + candidates.size());
    if (indexTrans_)
      assert(indexTrans_->checkCIdxWtsGradsSize(candidates.size()));

    // Adjusts prior means for removed/appended entries. The old values of
    // the removed entries are saved in removedValues and restored below.
    setRemoveAppendedPriorMeans(numClausesFormulas, priorMeans,
                                remClauseFormulaIdxs,
                                toBeAppendedClauses.size(), removedValues);
  }

  if (hasPrior_)
    pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
                          priorStdDevs.getItems());
  else
    pll_->setMeansStdDevs(-1, NULL, NULL);

  // One extra weight slot per appended candidate (entry 0 unused).
  wts->growToSize(numClausesFormulas + toBeAppendedClauses.size() + 1);

  int iter; double elapsedSec;
  double newScore = maximizeScore(numClausesFormulas,
                                  toBeAppendedClauses.size(),
                                  wts, origWts, &remClauseFormulaIdxs,
                                  iter, error, elapsedSec);

  // Appended candidates take their learned weights from the tail of wts.
  for (int i = 0; i < toBeAppendedClauses.size(); i++)
    toBeAppendedClauses[i]->setWt( (*wts)[numClausesFormulas+i+1] );

  // NOTE(review): 111111111 looks like a sentinel weight marking removed
  // candidates; confirm against readers of getWt().
  for (int i = 0; i < toBeRemovedClauses.size(); i++)
    toBeRemovedClauses[i]->setWt(111111111);

  Array<double> penalties;
  for (int i = 0; i < candidates.size(); i++)
    penalties.append(getPenalty(candidates[i]));

  // On optimizer failure, report an unchanged score.
  if (error) { newScore=prevScore; cout<<"LBFGSB failed to find wts"<<endl; }
  printNewScore(candidates, (*domains_)[0], iter, elapsedSec,
                newScore, newScore-prevScore, penalties);

  for (int i = 0; i < candidates.size(); i++)
    candidates[i]->getAuxClauseData()->gain = newScore-prevScore-penalties[i];

  // No caller-supplied undo array: roll back all count changes now.
  if (uundoInfos == NULL)
  {
    pll_->undoAppendRemoveCounts(undoInfos);
    delete undoInfos;
  }

  // Undo the temporary prior / index-translator adjustments made above.
  if (resetPriors)
  {
    if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(
                                               toBeAppendedClauses.size(), 1);
  }
  else
  {
    for (int i = 0; i < remClauseFormulaIdxs.size(); i++)
      priorMeans[remClauseFormulaIdxs[i]] = removedValues[i];
  }

  idxPtrs.deleteItemsAndClear();

  return newScore;
}
00824
00825
00826 double countAndMaxScoreEffectCandidate(Clause* const & candidate,
00827 Array<double>* const & wts,
00828 const Array<double>* const & origWts,
00829 const double& prevScore,
00830 const bool& resetPriors,
00831 Array<double>& priorMeans,
00832 Array<double>& priorStdDevs,
00833 bool& error,
00834 Array<UndoInfo*>* const& undoInfos,
00835 Array<ClauseAndICDArray*>* const & appendedClauseInfos)
00836 {
00837 Array<Clause*> candidates; candidates.append(candidate);
00838 return countAndMaxScoreEffectAllCandidates(candidates, wts, origWts,
00839 prevScore, resetPriors,
00840 priorMeans, priorStdDevs, error,
00841 undoInfos, appendedClauseInfos);
00842 }
00843
00844
00845
00846
00847 bool appendToAndRemoveFromMLNs(const Array<Clause*>& candidates,double& score,
00848 const bool& makeChangeForEqualScore=false)
00849 {
00850 Array<double> wts;
00851 Array<UndoInfo*> undoInfos;
00852 Array<ClauseAndICDArray*> appendedClauseInfos;
00853 Array<double> priorMeans, priorStdDevs;
00854 bool error;
00855
00856 double newScore
00857 = countAndMaxScoreEffectAllCandidates(candidates, &wts, NULL, score,
00858 true, priorMeans,priorStdDevs,error,
00859 &undoInfos, &appendedClauseInfos);
00860
00861 bool improve = (makeChangeForEqualScore) ? (newScore >= score)
00862 : (newScore > score);
00863
00864 if (!error && improve)
00865 {
00866 score = newScore;
00867
00868
00869
00870 updateWts(wts, NULL, NULL);
00871
00872 int numClausesFormulas = getNumClausesFormulas();
00873
00874 for (int i = 0; i < appendedClauseInfos.size(); i++)
00875 appendedClauseInfos[i]->clause->setWt(wts[numClausesFormulas+i+1]);
00876
00877
00878 int appClauseIdx = 0;
00879 for (int i = 0; i < candidates.size(); i++)
00880 {
00881 Clause* cand = candidates[i];
00882 AuxClauseData* acd = cand->getAuxClauseData();
00883 int op = acd->op;
00884
00885
00886
00887
00888 if (op == OP_REPLACE) cand->canonicalize();
00889
00890 if (op == OP_REMOVE || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00891 op == OP_REPLACE_REMPRED)
00892 {
00893 Clause* remClause = (Clause*) mln0_->getClause(acd->removedClauseIdx);
00894
00895 for (int d = 0; d < mlns_->size(); d++)
00896 {
00897 if (d == 0)
00898 {
00899 Clause* r = removeClauseFromMLN(acd->removedClauseIdx, d);
00900 cout << "Modified MLN: Removed clause from MLN: ";
00901 r->printWithoutWtWithStrVar(cout,(*domains_)[0]); cout << endl;
00902 if (op == OP_REMOVE && cand != r) delete cand;
00903
00904 }
00905 else
00906 {
00907 if (remClause->containsConstants())
00908 remClause->translateConstants((*domains_)[d-1], (*domains_)[d]);
00909 int remIdx = (*mlns_)[d]->findClauseIdx(remClause);
00910 delete removeClauseFromMLN(remIdx, d);
00911 }
00912 }
00913 delete remClause;
00914 }
00915
00916 if (op == OP_ADD || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00917 op == OP_REPLACE_REMPRED)
00918 {
00919 Array<int*> idxPtrs; idxPtrs.growToSize(mlns_->size());
00920 for (int d = 0; d < mlns_->size(); d++)
00921 {
00922 Clause* c = cand;
00923 if (d > 0)
00924 {
00925 if (cand->containsConstants())
00926 cand->translateConstants((*domains_)[d-1], (*domains_)[d]);
00927 c = new Clause(*cand);
00928 }
00929 int idx = appendClauseToMLN(c, d);
00930 idxPtrs[d] = (*mlns_)[d]->getMLNClauseInfoIndexPtr(idx);
00931 }
00932
00933 Array<IndexCountDomainIdx>& icds
00934 = appendedClauseInfos[appClauseIdx++]->icdArray;
00935 for (int j = 0; j < icds.size(); j++)
00936 icds[j].iac->index = idxPtrs[ icds[j].domainIdx ];
00937
00938 cout << "Modified MLN: Appended clause to MLN: ";
00939 cand->printWithoutWtWithStrVar(cout,(*domains_)[0]);
00940 cout << endl;
00941 }
00942 }
00943
00944 assert(pll_->checkNoRepeatedIndex());
00945 assert(appClauseIdx == appendedClauseInfos.size());
00946
00947 undoInfos.deleteItemsAndClear();
00948
00949
00950 if (indexTrans_) indexTrans_->createClauseIdxToClauseFormulaIdxsMap();
00951 }
00952 else
00953 {
00954 cout << "undoing candidates because score did not improve..."<<endl<<endl;
00955 pll_->undoAppendRemoveCounts(&undoInfos);
00956 }
00957
00958 appendedClauseInfos.deleteItemsAndClear();
00959 return improve;
00960 }
00961
00962
00963
00964 bool appendToAndRemoveFromMLNs(Clause* const & candidate, double& score,
00965 const bool& makeChangeForEqualScore=false)
00966 {
00967 Array<Clause*> candidates; candidates.append(candidate);
00968 return appendToAndRemoveFromMLNs(candidates, score,makeChangeForEqualScore);
00969 }
00970
00971
00972
00973 void pllComputeCountsForClause(Clause* const & c,
00974 const Array<int*>& clauseIdxInMLNs,
00975 Array<UndoInfo*>* const & undoInfos)
00976 {
00977 assert(c->getAuxClauseData()->cache == NULL);
00978 assert(clauseIdxInMLNs.size() == domains_->size());
00979 double begSec = timer_.time();
00980
00981 if (cacheClauses_)
00982 {
00983 int i;
00984 if ((i = cachedClauses_->find(c)) >= 0)
00985 {
00986
00987
00988
00989 pll_->insertCounts(clauseIdxInMLNs, undoInfos,
00990 (*cachedClauses_)[i]->getAuxClauseData()->cache);
00991 cout << "using cached counts took ";
00992 timer_.printTime(cout, timer_.time()-begSec); cout << endl;
00993 return;
00994 }
00995 else
00996 {
00997 assert(c->getAuxClauseData()->cache == NULL);
00998 if (cacheSizeMB_ < maxCacheSizeMB_)
00999 c->newCache(domains_->size(), (*domains_)[0]->getNumPredicates());
01000 else
01001 {
01002 static bool printCacheFull = true;
01003 if (printCacheFull)
01004 {
01005 cout << "Cache is full, approximate size = " << cacheSizeMB_
01006 << " MB" << endl;
01007 printCacheFull = false;
01008 }
01009 }
01010
01011 }
01012 }
01013
01014
01015
01016
01017 if (!c->containsConstants())
01018 {
01019 for (int i = 0; i < domains_->size(); i++)
01020 {
01021 int* clauseIdxInMLN = clauseIdxInMLNs[i];
01022 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01023 undoInfos, sampleClauses_,
01024 c->getAuxClauseData()->cache);
01025 }
01026 }
01027 else
01028 {
01029 int i;
01030 for (i = 0; i < domains_->size(); i++)
01031 {
01032 if (i > 0) c->translateConstants((*domains_)[i-1], (*domains_)[i]);
01033 int* clauseIdxInMLN = clauseIdxInMLNs[i];
01034 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01035 undoInfos, sampleClauses_,
01036 c->getAuxClauseData()->cache);
01037 }
01038 if (i > 1) c->translateConstants((*domains_)[i-1], (*domains_)[0]);
01039 }
01040
01041
01042 if (c->getAuxClauseData()->cache)
01043 {
01044 if (cacheSizeMB_ < maxCacheSizeMB_)
01045 {
01046 cacheSizeMB_ += c->sizeMB();
01047 Array<Array<Array<CacheCount*>*>*>* cache =c->getAuxClauseData()->cache;
01048 c->getAuxClauseData()->cache = NULL;
01049 Clause* copyClause = new Clause(*c);
01050 copyClause->getAuxClauseData()->cache = cache;
01051 copyClause->compress();
01052 cachedClauses_->append(copyClause);
01053 }
01054 else
01055 {
01056 c->getAuxClauseData()->deleteCache();
01057 c->getAuxClauseData()->cache = NULL;
01058 }
01059 }
01060
01061 cout << "Computing counts took ";
01062 timer_.printTime(cout, timer_.time()-begSec);
01063
01064
01065 cout << endl;
01066 }
01067
01068
01069 void pllRemoveCountsForClause(const Array<Clause*>& remClauses,
01070 const Array<int*>& clauseIdxInMLNs,
01071 Array<UndoInfo*>* const & undoInfos)
01072 {
01073 assert(clauseIdxInMLNs.size() == domains_->size());
01074 double begSec = timer_.time();
01075 for (int i = 0; i < domains_->size(); i++)
01076 pll_->removeCountsForClause(remClauses[i],clauseIdxInMLNs[i],i,undoInfos);
01077 cout << "Removing counts took ";
01078 timer_.printTime(cout, timer_.time()-begSec);
01079 cout << endl;
01080 }
01081
01082
01083
01084 void pllComputeCountsForInitialMLN()
01085 {
01086 for (int i = 0; i < mlns_->size(); i++)
01087 {
01088 cout << "computing counts for clauses in domain " << i << "..." << endl;
01089 MLN* mln = (*mlns_)[i];
01090 for (int j = 0; j < mln->getNumClauses(); j++)
01091 {
01092 Clause* c = (Clause*)mln->getClause(j);
01093 cout << "Clause " << j << ": ";
01094 c->printWithoutWtWithStrVar(cout, (*domains_)[i]); cout << endl;
01095 int* clauseIdxInMLN = mln->getMLNClauseInfoIndexPtr(j);
01096 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01097 NULL, sampleClauses_, NULL);
01098 }
01099 }
01100 }
01101
01102
01104 private:
01105 void addUnitClausesToMLNs()
01106 {
01107 Array<Predicate*> nonEvidPreds;
01108 for (int i = 0; i < preds_->size(); i++)
01109 if ((*areNonEvidPreds_)[(*preds_)[i]->getId()])
01110 nonEvidPreds.append((*preds_)[i]);
01111
01112 Array<Clause*> unitClauses;
01113 bool allowEqualPreds = true;
01114 ClauseFactory::createUnitClauses(unitClauses,nonEvidPreds,allowEqualPreds);
01115
01116 for (int i = 0; i < unitClauses.size(); i++)
01117 {
01118 if (mln0_->containsClause(unitClauses[i]))
01119 { delete unitClauses[i]; continue; }
01120 ostringstream oss; int idx;
01121 unitClauses[i]->printWithoutWtWithStrVar(oss, (*domains_)[0]);
01122
01123 for (int j = 0; j < mlns_->size(); j++)
01124 {
01125 Clause* c = (j == 0) ? unitClauses[i] : new Clause(*unitClauses[i]);
01126 (*mlns_)[j]->appendClause(oss.str(), false, c, priorMean_, false, idx,
01127 false);
01128 ((MLNClauseInfo*)(*mlns_)[j]->getMLNClauseInfo(idx))->priorMean
01129 = priorMean_;
01130 }
01131 }
01132 }
01133
01134
01135 void appendUnitClausesWithDiffCombOfVar(double& score)
01136 {
01137 bool allowEqualPreds = false;
01138 for (int i = 0; i < preds_->size(); i++)
01139 {
01140 if (!(*areNonEvidPreds_)[(*preds_)[i]->getId()]) continue;
01141
01142 Clause* origClause = ClauseFactory::createUnitClause((*preds_)[i],
01143 allowEqualPreds);
01144 if (origClause == NULL) continue;
01145 assert(origClause->getAuxClauseData() == NULL);
01146 origClause->setAuxClauseData(new AuxClauseData);
01147 origClause->trackConstants();
01148
01149 ClauseOpHashArray newUnitClauses;
01150 clauseFactory_->createUnitClausesWithDiffCombOfVar((*preds_)[i], OP_ADD,
01151 -1, newUnitClauses);
01152 for (int j = 0; j < newUnitClauses.size(); j++)
01153 {
01154 Clause* newClause = newUnitClauses[j];
01155
01156 if (origClause->same(newClause) ||
01157 !clauseFactory_->validClause(newClause) ||
01158 mln0_->containsClause(newClause))
01159 {
01160 newUnitClauses.removeItemFastDisorder(j);
01161 delete newClause;
01162 j--;
01163 continue;
01164 }
01165
01166 if (!appendToAndRemoveFromMLNs(newClause, score)) delete newClause;
01167 }
01168 delete origClause;
01169 }
01170 }
01171
01172
// Local search over predicate senses. Each modifiable clause of mln0_ with
// more than one predicate is taken in turn. Every variant produced by
// clauseFactory_->flipSensesInClause (as an OP_REPLACE of that clause) is
// scored with countAndMaxScoreEffectCandidate. The single variant with the
// highest score, if it beats the current score, is effected on the MLNs via
// appendToAndRemoveFromMLNs, which may update score. Every other variant is
// deleted.
void flipMLNClauses(double& score)
{
  // Snapshot the clauses to try before any replacement modifies mln0_.
  Array<Clause*> mlnClauses;
  for (int i = 0; i < mln0_->getNumClauses(); i++)
  {
    // Skip unit clauses and clauses that isModifiableClause rejects
    // (exist / exist-unique clauses).
    if (mln0_->getClause(i)->getNumPredicates()==1 || !isModifiableClause(i))
      continue;
    mlnClauses.append((Clause*)mln0_->getClause(i));
  }

  for (int i = 0; i < mlnClauses.size(); i++)
  {
    Clause* origClause = mlnClauses[i];
    // The index is looked up on every iteration instead of being cached,
    // presumably because earlier replacements can shift clause indices.
    int origIdx = mln0_->findClauseIdx(origClause);

    bool canonicalizeNewClauses = false;
    ClauseOpHashArray newClauses;
    clauseFactory_->flipSensesInClause(origClause, OP_REPLACE, origIdx,
                                       newClauses, canonicalizeNewClauses);

    // Priors for the existing clauses plus one appended slot for the
    // candidate being scored.
    Array<double> priorMeans, priorStdDevs;
    setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
    // NOTE(review): +2 appears to be one leading offset slot (updateWts reads
    // wts[j+1] for clause j) plus one slot for the candidate. Confirm.
    Array<double> wts;
    wts.growToSize(getNumClausesFormulas()+2);
    Clause* bestClause = NULL;
    double bestScore = score;
    bool error;

    for (int j = 0; j < newClauses.size(); j++)
    {
      Clause* newClause = newClauses[j];
      // Discard no-op flips, clauses with redundant predicates, and clauses
      // already present in the MLN.
      if (origClause->same(newClause) || newClause->hasRedundantPredicates()
          || mln0_->containsClause(newClause)) { delete newClause; continue; }

      double newScore
        = countAndMaxScoreEffectCandidate(newClause, &wts, NULL, score, true,
                                          priorMeans, priorStdDevs, error,
                                          NULL, NULL);
      // Keep only the best-scoring variant seen so far; free the others.
      if (newScore > bestScore)
      {
        bestScore = newScore;
        if (bestClause) delete bestClause;
        bestClause = newClause;
      }
      else
        delete newClause;
    }

    // When appendToAndRemoveFromMLNs returns false, the clause is freed here.
    if (bestClause && !appendToAndRemoveFromMLNs(bestClause, score))
      delete bestClause;
  }
}
01228
01229
01230
01231
01232 void pruneMLN(double& score)
01233 {
01234 Array<Clause*> mlnClauses;
01235 for (int i = 0; i < mln0_->getNumClauses(); i++)
01236 {
01237
01238 if (mln0_->getClause(i)->getNumPredicates() == 1 ||
01239 !isModifiableClause(i)) continue;
01240 mlnClauses.append((Clause*)mln0_->getClause(i));
01241 }
01242
01243 for (int i = 0; i < mlnClauses.size(); i++)
01244 {
01245 Clause* origClause = mlnClauses[i];
01246 int origIdx = mln0_->findClauseIdx(origClause);
01247 Clause* copy = new Clause(*origClause);
01248 copy->setAuxClauseData(new AuxClauseData(0, OP_REMOVE, origIdx));
01249 copy->trackConstants();
01250 if (!appendToAndRemoveFromMLNs(copy, score, true)) delete copy;
01251 }
01252 }
01253
01254
01255 int appendClauseToMLN(Clause* const c, const int& domainIdx)
01256 {
01257 ostringstream oss; int idx;
01258 c->printWithoutWtWithStrVar(oss, (*domains_)[domainIdx]);
01259 MLN* mln = (*mlns_)[domainIdx];
01260 bool ok = mln->appendClause(oss.str(), false, c, c->getWt(), false, idx,
01261 false);
01262 if (!ok) { cout << "ERROR: failed to insert " << oss.str() <<" into MLN"
01263 << endl; exit(-1); }
01264 ((MLNClauseInfo*)mln->getMLNClauseInfo(idx))->priorMean = priorMean_;
01265 return idx;
01266 }
01267
01268
01269 Clause* removeClauseFromMLN(const int& remClauseIdx, const int& domainIdx)
01270 {
01271 Clause* remClause = (*mlns_)[domainIdx]->removeClause(remClauseIdx);
01272 if (remClause == NULL)
01273 {
01274 cout << "ERROR: failed to remove ";
01275 remClause->printWithoutWtWithStrVar(cout, (*domains_)[0]);
01276 cout << " from MLN" << endl;
01277 exit(-1);
01278 }
01279 return remClause;
01280 }
01281
01282
01284 private:
01285 void addPredicateToClause(Clause* const & beamClause, const int& op,
01286 const int& removeClauseIdx,
01287 ClauseOpHashArray& newClauses)
01288 {
01289
01290 clauseFactory_->addPredicateToClause(*preds_,beamClause,op,removeClauseIdx,
01291 true,newClauses,false);
01292
01293
01294 if (beamClause->getNumPredicates() == 1)
01295 {
01296 beamClause->getPredicate(0)->invertSense();
01297 clauseFactory_->addPredicateToClause(*preds_, beamClause, op,
01298 removeClauseIdx, true, newClauses,
01299 false);
01300 beamClause->getPredicate(0)->invertSense();
01301 }
01302 }
01303
01304
01305
01306 void removePredicateFromClause(Clause* const & beamClause, const int& op,
01307 const int& removeClauseIdx,
01308 ClauseOpHashArray& newClauses)
01309 {
01310 if (beamClause->getNumPredicates() > 2)
01311 clauseFactory_->removePredicateFromClause(beamClause, op, removeClauseIdx,
01312 newClauses);
01313 }
01314
01315
01316 void addNewClauseToCandidates(Clause* const & newClause,
01317 Array<Clause*>& candidates,
01318 Array<Clause*>* const & thrownOut)
01319 {
01320
01321 if (thrownOut && mln0_->containsClause(newClause))
01322 { thrownOut->append(newClause); return;}
01323 newClause->setWt(0);
01324 candidates.append(newClause);
01325 }
01326
01327
// Expands every beam clause into new candidate clauses and appends them
// (weight reset to 0) to candidates. What is generated depends on the beam
// clause's op:
//   OP_ADD             : add-a-predicate and remove-a-predicate variants,
//                        keeping OP_ADD and the beam clause's removedClauseIdx
//   OP_REPLACE_ADDPRED : add-a-predicate variants only
//   OP_REPLACE_REMPRED : remove-a-predicate variants only
//   OP_NONE            : if the clause is not in mln0_ or is modifiable,
//                        generates add/remove-a-predicate variants as OP_ADD.
//                        If it is in mln0_ and has more than one predicate,
//                        also generates OP_REPLACE_ADDPRED and
//                        OP_REPLACE_REMPRED variants plus an OP_REMOVE copy
//                        of the clause itself.
//   OP_REMOVE, OP_REPLACE : nothing (asserted)
// If initMLNClauses is non-NULL, copies of those clauses are offered as
// candidates too. Generated clauses that mln0_ already contains are deleted
// instead of becoming candidates.
void createCandidateClauses(const ClauseOpHashArray* const & beam,
                            Array<Clause*>& candidates,
                            const Array<Clause*>* const & initMLNClauses)
{
  // Generated clauses already in mln0_; deleted at the end of this function.
  Array<Clause*> thrownOutClauses;
  // Accumulates every clause generated for the whole beam. newClausesBegIdx
  // marks where the current step's output starts. The hash array presumably
  // suppresses duplicates (see the append() < 0 check below).
  ClauseOpHashArray newClauses;
  for (int i = 0; i < beam->size(); i++)
  {
    Clause* beamClause = (*beam)[i];
    AuxClauseData* beamacd = beamClause->getAuxClauseData();
    int op = beamacd->op;
    int newClausesBegIdx = newClauses.size();

    if (op == OP_ADD)
    {
      int remIdx = beamacd->removedClauseIdx;
      addPredicateToClause(beamClause, op, remIdx, newClauses);
      removePredicateFromClause(beamClause, op, remIdx, newClauses);
      for (int j = newClausesBegIdx; j < newClauses.size(); j++)
        addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
    }
    else
    if (op == OP_REPLACE_ADDPRED)
    {
      addPredicateToClause(beamClause, op, beamacd->removedClauseIdx,
                           newClauses);
      for (int j = newClausesBegIdx; j < newClauses.size(); j++)
        addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
    }
    else
    if (op == OP_REPLACE_REMPRED)
    {
      removePredicateFromClause(beamClause, op, beamacd->removedClauseIdx,
                                newClauses);
      for (int j = newClausesBegIdx; j < newClauses.size(); j++)
        addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
    }
    else
    if (op == OP_NONE)
    {
      int idx = mln0_->findClauseIdx(beamClause);
      bool beamClauseInMLN = (idx >= 0);
      // Replace/remove variants apply only to multi-predicate clauses that
      // are already in the MLN.
      bool removeBeamClause = (beamClauseInMLN &&
                               beamClause->getNumPredicates() > 1);
      bool isModClause = (!beamClauseInMLN || isModifiableClause(idx));

      if (isModClause)
      {
        // Add a predicate, appending the result as a new clause.
        addPredicateToClause(beamClause, OP_ADD, -1, newClauses);
        for (int j = newClausesBegIdx; j < newClauses.size(); j++)
          addNewClauseToCandidates(newClauses[j], candidates,
                                   &thrownOutClauses);

        // Add a predicate, replacing the MLN clause.
        if (removeBeamClause)
        {
          newClausesBegIdx = newClauses.size();
          addPredicateToClause(beamClause, OP_REPLACE_ADDPRED, idx,
                               newClauses);
          for (int j = newClausesBegIdx; j < newClauses.size(); j++)
            addNewClauseToCandidates(newClauses[j], candidates,
                                     &thrownOutClauses);
        }

        // Remove a predicate, appending the result as a new clause.
        newClausesBegIdx = newClauses.size();
        removePredicateFromClause(beamClause, OP_ADD, -1, newClauses);
        for (int j = newClausesBegIdx; j < newClauses.size(); j++)
          addNewClauseToCandidates(newClauses[j], candidates,
                                   &thrownOutClauses);

        // Remove a predicate, replacing the MLN clause.
        if (removeBeamClause)
        {
          newClausesBegIdx = newClauses.size();
          removePredicateFromClause(beamClause, OP_REPLACE_REMPRED, idx,
                                    newClauses);
          for (int j = newClausesBegIdx; j < newClauses.size(); j++)
            addNewClauseToCandidates(newClauses[j], candidates,
                                     &thrownOutClauses);
        }

        // Candidate that deletes the MLN clause outright. It is added with a
        // NULL thrownOut, so it is never filtered as already in the MLN.
        // NOTE(review): relies on the Clause copy constructor copying the
        // AuxClauseData (it is dereferenced right away). Confirm.
        if (removeBeamClause)
        {
          Clause* c = new Clause(*beamClause);
          c->getAuxClauseData()->op = OP_REMOVE;
          c->getAuxClauseData()->removedClauseIdx = idx;
          addNewClauseToCandidates(c, candidates, NULL);
        }
      }
    }
    else
      assert(op == OP_REMOVE || op == OP_REPLACE);
  }

  if (initMLNClauses)
  {
    // Offer copies of the initial MLN clauses as candidates. A copy that
    // newClauses rejects (append() < 0) is deleted immediately.
    int newClausesBegIdx = newClauses.size();
    for (int i = 0; i < initMLNClauses->size(); i++)
    {
      Clause* c = new Clause(*((*initMLNClauses)[i]));
      if (newClauses.append(c) < 0) delete c;
    }
    for (int i = newClausesBegIdx; i < newClauses.size(); i++)
      addNewClauseToCandidates(newClauses[i], candidates, &thrownOutClauses);
  }

  for (int i = 0; i < thrownOutClauses.size(); i++)
    delete thrownOutClauses[i];
}
01439
01440
01442
01443 private:
01444 void useTightParams()
01445 {
01446 sampleClauses_ = false;
01447 pll_->setSampleGndPreds(false);
01448 lbfgsb_->setMaxIter(lbMaxIter_);
01449 lbfgsb_->setFtol(lbConvThresh_);
01450 cacheClauses_ = false;
01451 }
01452
01453
01454 void useLooseParams()
01455 {
01456 sampleClauses_ = origSampleClauses_;
01457 pll_->setSampleGndPreds(sampleGndPreds_);
01458 if (looseMaxIter_ >= 0) lbfgsb_->setMaxIter(looseMaxIter_);
01459 if (looseConvThresh_ >= 0) lbfgsb_->setFtol(looseConvThresh_);
01460 cacheClauses_ = origCacheClauses_;
01461 }
01462
01463
01464
01465 bool isModifiableClause(const int& i) const
01466 { return (!mln0_->isExistClause(i) && !mln0_->isExistUniqueClause(i)); }
01467
01468
01469 bool isNonModifiableFormula(const FormulaAndClauses* const & fnc) const
01470 { return (fnc->hasExist || fnc->isExistUnique); }
01471
01472
01473 void printIterAndTimeElapsed(const double& begSec)
01474 {
01475 cout << "Iteration " << iter_ << " took ";
01476 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
01477 cout << "Time elapsed = ";
01478 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
01479 }
01480
01481
01482 void removeClausesFromMLNs()
01483 {
01484 for (int i = 0; i < mlns_->size(); i++)
01485 (*mlns_)[i]->removeAllClauses(NULL);
01486 }
01487
01488
01489 void reEvaluateCandidates(Array<Clause*>& candidates,
01490 const double& prevScore)
01491 {
01492 int numClausesFormulas = getNumClausesFormulas();
01493 Array<double> wts;
01494 wts.growToSize(numClausesFormulas + 2);
01495
01496 Array<double> priorMeans, priorStdDevs;
01497 setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
01498 if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, 1);
01499
01500 bool error;
01501 for (int i = 0; i < candidates.size(); i++)
01502 {
01503 candidates[i]->getAuxClauseData()->gain = 0;
01504 countAndMaxScoreEffectCandidate(candidates[i], &wts, NULL, prevScore,
01505 true, priorMeans, priorStdDevs, error,
01506 NULL, NULL);
01507 }
01508
01509 if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, 1);
01510
01511 Array<Clause*> tmpCand(candidates);
01512 rquicksort(tmpCand);
01513 candidates.clear();
01514 for (int i = 0; i < tmpCand.size(); i++)
01515 {
01516 if (tmpCand[i]->getAuxClauseData()->gain > minGain_ &&
01517 fabs(tmpCand[i]->getWt()) >= minWt_)
01518 candidates.append(tmpCand[i]);
01519 else
01520 delete tmpCand[i];
01521 }
01522
01523 cout << "reevaluated top " << candidates.size()
01524 << " candidates with tight params:" << endl;
01525 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl;
01526 for (int i = 0; i < candidates.size(); i++)
01527 {
01528 cout << i << "\t";
01529 candidates[i]->printWithoutWtWithStrVar(cout,(*domains_)[0]);
01530 cout << endl
01531 << "\tgain = " << candidates[i]->getAuxClauseData()->gain
01532 << ", wt = " << candidates[i]->getWt()
01533 << ", op = "
01534 << Clause::getOpAsString(candidates[i]->getAuxClauseData()->op)
01535 << endl;
01536 }
01537 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl << endl;;
01538 }
01539
01540
01541
01542
01543
// Picks the next beam and updates the running list of best clauses.
//   1. Sorts candidates by descending gain and copies the leading qualifying
//      ones (gain > minGain_ and |wt| >= minWt_), up to beamSize_, into the
//      beam, which must be empty on entry.
//   2. If no candidate qualifies, deletes all candidates and returns false.
//   3. Otherwise merges the previous bestClauses into candidates. A previous
//      best clause that equals an existing candidate is deleted. The merged
//      list is re-sorted, and bestClauses is rebuilt from the first
//      numEstBestClauses_ qualifying clauses. Every clause not kept is
//      deleted.
// Returns true iff the top best clause changed. That means there was no
// previous best, or the new top differs from the previous top (as a clause
// or in its op) and has strictly larger gain.
// NOTE(review): candidates is not cleared; afterwards each of its entries is
// either deleted or aliased by bestClauses, so callers must not use it.
bool findBestClausesAmongCandidates(Array<Clause*>& candidates,
                                    ClauseOpHashArray* const & beam,
                                    Array<Clause*>& bestClauses)
{
  assert(beam->size() == 0);

  // Descending order of gain.
  rquicksort(candidates);

  for (int i = 0; i < candidates.size(); i++)
  {
    if (beam->size() >= beamSize_) break;
    double candGain = candidates[i]->getAuxClauseData()->gain;
    double candAbsWt = fabs(candidates[i]->getWt());
    if (candGain > minGain_ && candAbsWt >= minWt_)
    {
      // The beam holds its own copies.
      int a = beam->append(new Clause(*(candidates[i])));
      // a = 0 keeps 'a' used when assert() compiles away.
      assert(a >= 0); a = 0;
    }
  }

  if (beam->size() == 0)
  {
    for (int i = 0; i < candidates.size(); i++) delete candidates[i];
    return false;
  }

  // Copy of the previous top clause, used to detect a change below.
  Clause* prevBestClause = (bestClauses.size() > 0) ?
                           new Clause(*(bestClauses[0])) : NULL;

  // Merge the previous best clauses into the candidates without duplicates.
  ClauseOpSet cset;
  ClauseOpSet::iterator it;
  for (int i = 0; i < candidates.size(); i++) cset.insert(candidates[i]);

  for (int i = 0; i < bestClauses.size(); i++)
  {
    if ((it=cset.find(bestClauses[i])) == cset.end())
    {
      candidates.append(bestClauses[i]);
      cset.insert(bestClauses[i]);
    }
    else
    {
      // Equal clauses are expected to carry identical gains.
      assert((*it)->getAuxClauseData()->gain ==
             bestClauses[i]->getAuxClauseData()->gain);
      delete bestClauses[i];
    }
  }

  rquicksort(candidates);

  // Keep up to numEstBestClauses_ qualifying clauses; delete everything else.
  bestClauses.clear();
  for (int i = 0; i < candidates.size(); i++)
  {
    if (bestClauses.size() < numEstBestClauses_)
    {
      double candGain = candidates[i]->getAuxClauseData()->gain;
      double candAbsWt = fabs(candidates[i]->getWt());
      if (candGain > minGain_ && candAbsWt >= minWt_)
        bestClauses.append(candidates[i]);
      else
      {
        delete candidates[i];
      }
    }
    else
      delete candidates[i];
  }

  bool bestClauseChanged;
  if (bestClauses.size() > 0)
  {
    if (prevBestClause == NULL) bestClauseChanged = true;
    else
    {
      // Same clause with the same op: no change.
      if (bestClauses[0]->same(prevBestClause) &&
          bestClauses[0]->getAuxClauseData()->op ==
          prevBestClause->getAuxClauseData()->op)
        bestClauseChanged = false;
      else
      {
        // A different top clause counts as a change only if it is strictly
        // better.
        if (bestClauses[0]->getAuxClauseData()->gain >
            prevBestClause->getAuxClauseData()->gain)
          bestClauseChanged = true;
        else
          bestClauseChanged = false;
      }
    }
  }
  else
    bestClauseChanged = false;

  if (prevBestClause) delete prevBestClause;
  return bestClauseChanged;
}
01653
01654
01655 bool effectBestCandidateOnMLNs(Clause* const & cand, double& score)
01656 {
01657 if (appendToAndRemoveFromMLNs(cand, score))
01658 {
01659 printMLNClausesWithWeightsAndScore(score, iter_);
01660
01661 return true;
01662 }
01663 return false;
01664 }
01665
01666
01667 bool effectBestCandidateOnMLNs(Array<Clause*>& bestCandidates, double& score)
01668 {
01669 cout << "effecting best candidate on MLN..." << endl << endl;
01670 bool ok = false;
01671 int i;
01672 for (i = 0; i < bestCandidates.size(); i++)
01673 {
01674 cout << "effecting best candidate " << i << " on MLN..." << endl << endl;
01675 if (ok=effectBestCandidateOnMLNs(bestCandidates[i], score)) break;
01676 cout << "failed to effect candidate on MLN." << endl;
01677 delete bestCandidates[i];
01678 }
01679 for (int j = i+1; j < bestCandidates.size();j++) delete bestCandidates[j];
01680 return ok;
01681 }
01682
01683
01684 double getPenalty(const Clause* const & cand)
01685 {
01686 int op = cand->getAuxClauseData()->op;
01687
01688 if (op == OP_ADD) return cand->getNumPredicates() * penalty_;
01689
01690 if (op == OP_REPLACE_ADDPRED)
01691 {
01692 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01693 int origlen = mln0_->getClause(remIdx)->getNumPredicates();
01694 int diff = cand->getNumPredicates() - origlen;
01695 assert(diff > 0);
01696 return diff * penalty_;
01697 }
01698
01699 if (op == OP_REPLACE_REMPRED)
01700 {
01701 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01702 int origlen = mln0_->getClause(remIdx)->getNumPredicates();
01703 int diff = origlen - cand->getNumPredicates();
01704 assert(diff > 0);
01705 return diff * penalty_;
01706 }
01707
01708 if (op == OP_REPLACE)
01709 {
01710
01711
01712 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01713 const Clause* mlnClause = mln0_->getClause(remIdx);
01714 assert(cand->getNumPredicates() == mlnClause->getNumPredicates());
01715 int diff = 0;
01716 for (int i = 0; i < cand->getNumPredicates(); i++)
01717 {
01718 Predicate* cpred = (Predicate*) cand->getPredicate(i);
01719 Predicate* mpred = (Predicate*) mlnClause->getPredicate(i);
01720 assert(cpred->same(mpred));
01721 if (cpred->getSense() != mpred->getSense()) diff++;
01722 }
01723 return diff * penalty_;
01724 }
01725
01726 if (op == OP_REMOVE) return cand->getNumPredicates() * penalty_;
01727
01728 assert(false);
01729 return 88888;
01730 }
01731
01732
01733 void printNewScore(const Array<Clause*>& carray, const Domain* const & domain,
01734 const int& lbfgsbIter, const double& lbfgsbSec,
01735 const double& newScore, const double& gain,
01736 const Array<double>& penalties)
01737 {
01738 cout << "*************************** " << candCnt_++ << ", iter " << iter_
01739 << ", beam search iter " << bsiter_ << endl;
01740 for (int i = 0; i < carray.size(); i++)
01741 {
01742 if (carray[i])
01743 {
01744 cout << "candidate : ";
01745 carray[i]->printWithoutWtWithStrVar(cout,domain); cout << endl;
01746 cout << "op : ";
01747 cout << Clause::getOpAsString(carray[i]->getAuxClauseData()->op) <<endl;
01748 cout << "removed clause: ";
01749 int remIdx = carray[i]->getAuxClauseData()->removedClauseIdx;
01750 if (remIdx < 0) cout << "NULL";
01751 else { mln0_->getClause(remIdx)->printWithoutWtWithStrVar(cout,domain);}
01752 cout << endl;
01753 if (carray[i]->getAuxClauseData()->prevClauseStr.length() > 0)
01754 {
01755 cout << "prevClause : ";
01756 cout << carray[i]->getAuxClauseData()->prevClauseStr << endl;
01757 }
01758 if (carray[i]->getAuxClauseData()->addedPredStr.length() > 0)
01759 {
01760 cout << "addedPred : ";
01761 cout << carray[i]->getAuxClauseData()->addedPredStr << endl;
01762 }
01763 if (carray[i]->getAuxClauseData()->removedPredIdx >= 0)
01764 {
01765 cout << "removedPredIdx: ";
01766 cout << carray[i]->getAuxClauseData()->removedPredIdx << endl;
01767 }
01768
01769 cout << "score : " << newScore << endl;
01770 cout << "gain : " << gain << endl;
01771 cout << "penalty : " << penalties[i] << endl;
01772 cout << "net gain : " << gain - penalties[i] << endl;
01773 if (carray[i]->getAuxClauseData()->op != OP_REMOVE)
01774 cout << "wt : " << carray[i]->getWt() << endl;
01775 }
01776 }
01777 cout << "num LBFGSB iter = " << lbfgsbIter << endl;
01778 cout << "time taken by LBFGSB = "; Timer::printTime(cout, lbfgsbSec);
01779 cout << endl;
01780 cout << "*************************** " << endl << endl;;
01781 }
01782
01783
01784 void printNewScore(const Clause* const & c, const Domain* const & domain,
01785 const int& lbfgsbIter, const double& lbfgsbSec,
01786 const double& newScore, const double& gain,
01787 const double& penalty)
01788 {
01789 Array<Clause*> carray;
01790 Array<double> parray;
01791 if (c) { carray.append((Clause*)c); parray.append(penalty); }
01792 printNewScore(carray, domain, lbfgsbIter, lbfgsbSec, newScore, gain,parray);
01793 }
01794
01795
01796 void printMLNClausesWithWeightsAndScore(const double& score, const int& iter)
01797 {
01798 if (iter >= 0) cout << "MLN in iteration " << iter << ":" << endl;
01799 else cout << "MLN:" << endl;
01800 cout << "------------------------------------" << endl;
01801 if (indexTrans_) indexTrans_->printClauseFormulaWts(cout, true);
01802 else mln0_->printMLNClausesFormulas(cout, (*domains_)[0], true);
01803 cout << "------------------------------------" << endl;
01804 cout << "score = "<< score << endl << endl;
01805 }
01806
01807
01808 void printMLNToFile(const char* const & appendStr, const int& iter)
01809 {
01810 string fname = outMLNFileName_;
01811
01812 if (appendStr) fname.append(".").append(appendStr);
01813
01814 if (iter >= -1)
01815 {
01816 char buf[100];
01817 sprintf(buf, "%d", iter);
01818 fname.append(".iter").append(buf);
01819 }
01820
01821 if (appendStr || iter >= -1) fname.append(".mln");
01822
01823 ofstream out(fname.c_str());
01824 if (!out.good()) { cout << "ERROR: failed to open " <<fname<<endl;exit(-1);}
01825
01826
01827 out << "//predicate declarations" << endl;
01828 (*domains_)[0]->printPredicateTemplates(out);
01829 out << endl;
01830
01831
01832 if ((*domains_)[0]->getNumFunctions() > 0)
01833 {
01834 out << "//function declarations" << endl;
01835 (*domains_)[0]->printFunctionTemplates(out);
01836 out << endl;
01837 }
01838
01839 if (indexTrans_) indexTrans_->printClauseFormulaWts(out, false);
01840 else mln0_->printMLNClausesFormulas(out, (*domains_)[0], false);
01841
01842 out << endl;
01843 out.close();
01844 }
01845
01846
01847 void printClausesInBeam(const ClauseOpHashArray* const & beam)
01848 {
01849 cout.setf(ios_base::left, ios_base::adjustfield);
01850 cout << "^^^^^^^^^^^^^^^^^^^ beam ^^^^^^^^^^^^^^^^^^^" << endl;
01851 for (int i = 0; i < beam->size(); i++)
01852 {
01853 cout << i << ": ";
01854 cout.width(10); cout << (*beam)[i]->getWt(); cout.width(0); cout << " ";
01855 (*beam)[i]->printWithoutWtWithStrVar(cout, (*domains_)[0]); cout << endl;
01856 }
01857 cout.width(0);
01858 cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << endl;
01859
01860 }
01861
01862 void setPriorMeansStdDevs(Array<double>& priorMeans,
01863 Array<double>& priorStdDevs, const int& addSlots,
01864 const Array<int>* const& removedSlotIndexes)
01865 {
01866 priorMeans.clear(); priorStdDevs.clear();
01867
01868 if (indexTrans_)
01869 {
01870 indexTrans_->setPriorMeans(priorMeans);
01871 priorStdDevs.growToSize(priorMeans.size());
01872 for (int i = 0; i < priorMeans.size(); i++)
01873 priorStdDevs[i] = priorStdDev_;
01874 }
01875 else
01876 {
01877 for (int i = 0; i < mln0_->getNumClauses(); i++)
01878 {
01879 priorMeans.append(mln0_->getMLNClauseInfo(i)->priorMean);
01880 priorStdDevs.append(priorStdDev_);
01881 }
01882 }
01883
01884 if (removedSlotIndexes)
01885 {
01886 for (int i = 0; i < removedSlotIndexes->size(); i++)
01887 priorMeans[ (*removedSlotIndexes)[i] ] = 0;
01888 }
01889
01890 for (int i = 0; i < addSlots; i++)
01891 {
01892 priorMeans.append(priorMean_);
01893 priorStdDevs.append(priorStdDev_);
01894 }
01895 }
01896
01897
01898 void pllSetPriorMeansStdDevs(Array<double>& priorMeans,
01899 Array<double>& priorStdDevs, const int& addSlots,
01900 const Array<int>* const & removedSlotIndexes)
01901 {
01902 if (hasPrior_)
01903 {
01904 setPriorMeansStdDevs(priorMeans, priorStdDevs,
01905 addSlots, removedSlotIndexes);
01906 pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
01907 priorStdDevs.getItems());
01908 }
01909 else
01910 pll_->setMeansStdDevs(-1, NULL, NULL);
01911 }
01912
01913
01914 void setRemoveAppendedPriorMeans(const int& numClausesFormulas,
01915 Array<double>& priorMeans,
01916 const Array<int>& removedSlotIndexes,
01917 const int& addSlots,
01918 Array<double>& removedValues)
01919 {
01920 for (int i = 0; i < removedSlotIndexes.size(); i++)
01921 {
01922 removedValues.append(priorMeans[ removedSlotIndexes[i] ]);
01923 priorMeans[ removedSlotIndexes[i] ] = 0;
01924 }
01925
01926 for (int i = 0; i < addSlots; i++)
01927 priorMeans[numClausesFormulas+i] = priorMean_;
01928 }
01929
01930
01931 int getNumClausesFormulas()
01932 {
01933 if (indexTrans_) return indexTrans_->getNumClausesAndExistFormulas();
01934 return mln0_->getNumClauses();
01935 }
01936
01937
01938 void updateWts(const Array<double>& wts,
01939 const Array<Clause*>* const & appendedClauses,
01940 const Array<string>* const & appendedFormulas)
01941 {
01942 if (indexTrans_)
01943 {
01944 Array<double> tmpWts;
01945 tmpWts.growToSize(wts.size()-1);
01946 for (int i = 1; i < wts.size(); i++) tmpWts[i-1] = wts[i];
01947 indexTrans_->updateClauseFormulaWtsInMLNs(tmpWts, appendedClauses,
01948 appendedFormulas);
01949 }
01950 else
01951 {
01952 for (int i = 0; i < mlns_->size(); i++)
01953 for (int j = 0; j < (*mlns_)[i]->getNumClauses(); j++)
01954 ((Clause*)(*mlns_)[i]->getClause(j))->setWt(wts[j+1]);
01955 }
01956 }
01957
01958
01960
01961 private:
01962
01963 void rquicksort(Array<Clause*>& clauses, const int& l, const int& r)
01964 {
01965 Clause** items = (Clause**) clauses.getItems();
01966 if (l >= r) return;
01967 Clause* tmp = items[l];
01968 items[l] = items[(l+r)/2];
01969 items[(l+r)/2] = tmp;
01970
01971 int last = l;
01972 for (int i = l+1; i <= r; i++)
01973 if (items[i]->getAuxClauseData()->gain >
01974 items[l]->getAuxClauseData()->gain)
01975 {
01976 ++last;
01977 Clause* tmp = items[last];
01978 items[last] = items[i];
01979 items[i] = tmp;
01980 }
01981
01982 tmp = items[l];
01983 items[l] = items[last];
01984 items[last] = tmp;
01985 rquicksort(clauses, l, last-1);
01986 rquicksort(clauses, last+1, r);
01987 }
01988
01989 void rquicksort(Array<Clause*>& ca) { rquicksort(ca, 0, ca.size()-1); }
01990
01991
01992
01994 private:
01995
01996 void deleteExistFormulas(Array<ExistFormula*>& existFormulas)
01997 { existFormulas.deleteItemsAndClear(); }
01998
01999
02000 void getMLNClauses(Array<Clause*>& initialMLNClauses,
02001 Array<ExistFormula*>& existFormulas)
02002 {
02003 for (int i = 0; i < mln0_->getNumClauses(); i++)
02004 if (isModifiableClause(i))
02005 {
02006 Clause* c = (Clause*) mln0_->getClause(i);
02007 initialMLNClauses.append(new Clause(*c));
02008 }
02009
02010 const FormulaAndClausesArray* fnca = mln0_->getFormulaAndClausesArray();
02011 for (int i = 0; i < fnca->size(); i++)
02012 if (isNonModifiableFormula((*fnca)[i]))
02013 existFormulas.append(new ExistFormula((*fnca)[i]->formula));
02014
02015 for (int i = 0; i < existFormulas.size(); i++)
02016 {
02017 string formula = existFormulas[i]->formula;
02018 FormulaAndClauses tmp(formula, 0, false, false);
02019 const FormulaAndClausesArray* fnca
02020 = (*mlns_)[0]->getFormulaAndClausesArray();
02021 int a = fnca->find(&tmp);
02022 existFormulas[i]->numPreds = (*fnca)[a]->numPreds;
02023
02024 Array<Array<Clause*> >& cnfClausesForDomains
02025 = existFormulas[i]->cnfClausesForDomains;
02026 cnfClausesForDomains.growToSize(mlns_->size());
02027 for (int j = 0; j < mlns_->size(); j++)
02028 {
02029 Array<Clause*>& cnfClauses = cnfClausesForDomains[j];
02030 fnca = (*mlns_)[j]->getFormulaAndClausesArray();
02031 a = fnca->find(&tmp);
02032 assert(a >= 0);
02033 IndexClauseHashArray* indexClauses = (*fnca)[a]->indexClauses;
02034 for (int k = 0; k < indexClauses->size(); k++)
02035 {
02036 Clause* c = new Clause(*((*indexClauses)[k]->clause));
02037 c->newAuxClauseData();
02038 cnfClauses.append(c);
02039 }
02040 cnfClauses.compress();
02041 }
02042 }
02043 }
02044
02045
02046
// Computes the pseudo-log-likelihood counts that one CNF clause of an
// existential formula contributes in domain domainIdx, using the clause-count
// cache when cacheClauses_ is set.
//   - If cachedClauses_ already holds counts for this clause in this domain,
//     they are inserted directly via pll_->insertCounts and the function
//     returns.
//   - Otherwise, while cacheSizeMB_ < maxCacheSizeMB_, cnfClause is given a
//     cache to fill. That is the cached clause's existing cache if
//     cachedClauses_ has the clause, or a fresh one otherwise.
//   - Counts are then computed. A filled cache is either kept in
//     cachedClauses_ (a shared cache adds only its size delta; a fresh cache
//     moves to a compressed copy of cnfClause) or deleted.
// cnfClause's cache pointer is NULL on entry (asserted) and NULL on return.
inline void pllCountsForExistFormula(Clause* cnfClause,
                                     const int& domainIdx,
                                     int* clauseIdxInMln,
                                     Array<UndoInfo*>* const & undoInfos)
{
  assert(cnfClause->getAuxClauseData()->cache == NULL);
  bool inCache = false;
  bool hasDomainCounts = false;
  // Size of the shared cache before this computation, so only the growth is
  // added to cacheSizeMB_.
  double prevCNFClauseSize = 0;

  if (cacheClauses_)
  {
    int i;
    if ((i = cachedClauses_->find(cnfClause)) >= 0)
    {
      inCache = true;
      Array<Array<Array<CacheCount*>*>*>* cache =
        (*cachedClauses_)[i]->getAuxClauseData()->cache;
      Array<Array<CacheCount*>*>* domainCache = (*cache)[domainIdx];
      // Counts exist for this domain if any entry of its cache is non-NULL.
      for (int j = 0; j < domainCache->size(); j++)
        if ((*domainCache)[j] != NULL) { hasDomainCounts = true; break; }
    }

    if (hasDomainCounts)
    {
      // Reuse the cached counts; nothing needs to be computed.
      pll_->insertCounts(clauseIdxInMln, undoInfos,
                         (*cachedClauses_)[i]->getAuxClauseData()->cache,
                         domainIdx);
      return;
    }
    else
    {
      if (cacheSizeMB_ < maxCacheSizeMB_)
      {
        if (inCache)
        {
          // Fill the existing shared cache (for this domain).
          cnfClause->getAuxClauseData()->cache
            = (*cachedClauses_)[i]->getAuxClauseData()->cache;
          prevCNFClauseSize = cnfClause->sizeMB();
        }
        else
          cnfClause->newCache(domains_->size(),
                              (*domains_)[0]->getNumPredicates());
      }
      else
      {
        // Function-local static: the message is printed once per process.
        static bool printCacheFull = true;
        if (printCacheFull)
        {
          cout << "Cache is full, approximate size = " << cacheSizeMB_
               <<" MB" << endl;
          printCacheFull = false;
        }
      }
    }
  }

  // A NULL cache here means counts are computed without being cached.
  pll_->computeCountsForNewAppendedClause(cnfClause,clauseIdxInMln,domainIdx,
                                          undoInfos, sampleClauses_,
                                          cnfClause->getAuxClauseData()->cache);

  if (cnfClause->getAuxClauseData()->cache)
  {
    if (inCache)
    {
      // The cache belongs to the entry in cachedClauses_; account for its
      // growth and detach it from cnfClause.
      cacheSizeMB_ += cnfClause->sizeMB() - prevCNFClauseSize;
      cnfClause->getAuxClauseData()->cache = NULL;
    }
    else
    {
      if (cacheSizeMB_ < maxCacheSizeMB_)
      {
        // Move the fresh cache to a compressed copy of cnfClause that is
        // stored in cachedClauses_; cnfClause itself keeps no cache.
        cacheSizeMB_ += cnfClause->sizeMB();
        Array<Array<Array<CacheCount*>*>*>* cache
          = cnfClause->getAuxClauseData()->cache;
        cnfClause->getAuxClauseData()->cache = NULL;
        Clause* copyClause = new Clause(*cnfClause);
        copyClause->getAuxClauseData()->cache = cache;
        copyClause->compress();
        cachedClauses_->append(copyClause);
      }
      else
      {
        cnfClause->getAuxClauseData()->deleteCache();
        cnfClause->getAuxClauseData()->cache = NULL;
      }
    }
  }
}
02137
02138
02139 inline void printNewScore(const string existFormStr,const int& lbfgsbIter,
02140 const double& lbfgsbSec, const double& newScore,
02141 const double& gain, const double& penalty,
02142 const double& wt)
02143 {
02144 cout << "*************************** " << candCnt_++ << ", iter " << iter_
02145 << ", beam search iter " << bsiter_ << endl;
02146
02147 cout << "exist formula : " << existFormStr << endl;
02148 cout << "op : OP_ADD" << endl;
02149 cout << "score : " << newScore << endl;
02150 cout << "gain : " << gain << endl;
02151 cout << "penalty : " << penalty << endl;
02152 cout << "net gain : " << gain - penalty << endl;
02153 cout << "wt : " << wt << endl;
02154 cout << "num LBFGSB iter = " << lbfgsbIter << endl;
02155 cout << "time taken by LBFGSB = "; Timer::printTime(cout, lbfgsbSec);
02156 cout << endl;
02157 cout << "*************************** " << endl << endl;;
02158 }
02159
02160
02161 inline void rquicksort(Array<ExistFormula*>& existFormulas, const int& l,
02162 const int& r)
02163 {
02164 ExistFormula** items = (ExistFormula**) existFormulas.getItems();
02165 if (l >= r) return;
02166 ExistFormula* tmp = items[l];
02167 items[l] = items[(l+r)/2];
02168 items[(l+r)/2] = tmp;
02169
02170 int last = l;
02171 for (int i = l+1; i <= r; i++)
02172 if (items[i]->gain > items[l]->gain)
02173 {
02174 ++last;
02175 ExistFormula* tmp = items[last];
02176 items[last] = items[i];
02177 items[i] = tmp;
02178 }
02179
02180 tmp = items[l];
02181 items[l] = items[last];
02182 items[last] = tmp;
02183 rquicksort(existFormulas, l, last-1);
02184 rquicksort(existFormulas, last+1, r);
02185 }
02186
02187
02188 inline void rquicksort(Array<ExistFormula*>& ef)
02189 { rquicksort(ef, 0, ef.size()-1); }
02190
02191
02192
02193
02194
02195 inline void evaluateExistFormula(ExistFormula* const& ef,
02196 const bool& computeCountsOnly,
02197 Array<Array<Array<IndexAndCount*> > >* const & iacArraysPerDomain,
02198 const double& prevScore)
02199 {
02200 bool undo = !computeCountsOnly;
02201 bool evalGainLearnWts = !computeCountsOnly;
02202
02203 Array<int*> idxPtrs;
02204 Array<UndoInfo*>* undoInfos = (undo) ? new Array<UndoInfo*> : NULL;
02205
02206
02207
02208 Array<Array<Clause*> >& cnfClausesForDomains = ef->cnfClausesForDomains;
02209 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02210 {
02211 Array<Clause*>& cnfClauses = cnfClausesForDomains[d];
02212 MLN* mln = (*mlns_)[d];
02213 for (int k = 0; k < cnfClauses.size(); k++)
02214 {
02215 Clause* c = cnfClauses[k];
02216
02217
02218
02219 if (!undo && mln->containsClause(c)) continue;
02220 int* tmpClauseIdxInMLN = new int(mln->getNumClauses() + k);
02221 idxPtrs.append(tmpClauseIdxInMLN);
02222 if (iacArraysPerDomain)
02223 {
02224 Array<Array<IndexAndCount*> >& iacArrays = (*iacArraysPerDomain)[d];
02225 assert(iacArrays.size() == cnfClauses.size());
02226 Array<IndexAndCount*>& iacArray = iacArrays[k];
02227
02228 Array<UndoInfo*> tmpUndoInfos;
02229 pllCountsForExistFormula(c, d, tmpClauseIdxInMLN, &tmpUndoInfos);
02230 for (int i = 0; i < tmpUndoInfos.size(); i++)
02231 iacArray.append(tmpUndoInfos[i]->affectedArr->lastItem());
02232 if (undoInfos) undoInfos->append(tmpUndoInfos);
02233 else tmpUndoInfos.deleteItemsAndClear();
02234 }
02235 else
02236 pllCountsForExistFormula(c, d, tmpClauseIdxInMLN, undoInfos);
02237 }
02238 }
02239
02240
02241
02242 if (evalGainLearnWts)
02243 {
02244 Array<double> priorMeans, priorStdDevs;
02245
02246
02247
02248 int numAdded = (indexTrans_) ? 1 : cnfClausesForDomains[0].size();
02249 setPriorMeansStdDevs(priorMeans, priorStdDevs, numAdded, NULL);
02250
02251 if (hasPrior_)
02252 pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
02253 priorStdDevs.getItems());
02254 else
02255 pll_->setMeansStdDevs(-1, NULL, NULL);
02256
02257 if (indexTrans_)
02258 {
02259 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02260 {
02261 int numCNFClauses = cnfClausesForDomains[d].size();
02262 indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, numCNFClauses);
02263 }
02264 }
02265
02266 int iter; bool error; double elapsedSec;
02267 int numClausesFormulas = getNumClausesFormulas();
02268 Array<double>* wts = &(ef->wts);
02269 wts->growToSize(numClausesFormulas + numAdded + 1);
02270
02271 double newScore = maximizeScore(numClausesFormulas, numAdded, wts,
02272 NULL, NULL, iter, error, elapsedSec);
02273
02274 if (indexTrans_)
02275 {
02276 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02277 {
02278 int numCNFClauses = cnfClausesForDomains[d].size();
02279 indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, numCNFClauses);
02280 }
02281 }
02282
02283 double totalWt = 0;
02284 for (int i = 0; i < numAdded; i++)
02285 totalWt += (*wts)[numClausesFormulas+i+1];
02286
02287 double penalty = ef->numPreds * penalty_;
02288
02289 ef->gain = newScore-prevScore-penalty;
02290 ef->wt = totalWt;
02291 ef->newScore = newScore;
02292
02293 if (error) {newScore=prevScore;cout<<"LBFGSB failed to find wts"<<endl;}
02294 printNewScore(ef->formula, iter, elapsedSec, newScore,
02295 newScore-prevScore, penalty, totalWt);
02296 }
02297
02298 if (undo) { pll_->undoAppendRemoveCounts(undoInfos); delete undoInfos; }
02299 idxPtrs.deleteItemsAndClear();
02300 }
02301
02302
02303
02304 double evaluateExistFormulas(Array<ExistFormula*>& existFormulas,
02305 Array<ExistFormula*>& highGainWtFormulas,
02306 const double& prevScore)
02307 {
02308 if (existFormulas.size() == 0) return 0;
02309 for (int i = 0; i < existFormulas.size(); i++)
02310 evaluateExistFormula(existFormulas[i], false, NULL, prevScore);
02311
02312 Array<ExistFormula*> tmp(existFormulas);
02313 rquicksort(tmp);
02314 double minGain = DBL_MAX;
02315 cout << "evaluated existential formulas " << endl;
02316 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl;
02317 for (int i = 0; i < tmp.size(); i++)
02318 {
02319 existFormulas[i] = tmp[i];
02320 double gain = existFormulas[i]->gain;
02321 double wt = existFormulas[i]->wt;
02322 cout << i << "\t" << existFormulas[i]->formula << endl
02323 << "\tgain = " << gain << ", wt = " << wt << ", op = OP_ADD"
02324 << endl;
02325 if (gain > minGain_ && wt >= minWt_)
02326 {
02327 highGainWtFormulas.append(tmp[i]);
02328 if (gain < minGain) minGain = gain;
02329 }
02330 }
02331 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl << endl;
02332
02333 if (minGain == DBL_MAX) minGain = 0;
02334 return minGain;
02335 }
02336
02337
02338 inline void appendExistFormulaToMLNs(ExistFormula* const & ef)
02339 {
02340 Array<Array<Array<IndexAndCount*> > > iacsPerDomain;
02341 Array<Array<Clause*> >& cnfClausesForDomains = ef->cnfClausesForDomains;
02342 iacsPerDomain.growToSize(cnfClausesForDomains.size());
02343 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02344 {
02345 Array<Array<IndexAndCount*> >& iacArrays = iacsPerDomain[d];
02346 iacArrays.growToSize( cnfClausesForDomains[d].size() );
02347 }
02348 evaluateExistFormula(ef, true, &iacsPerDomain, 0);
02349
02350 Array<double>& wts = ef->wts;
02351 int numClausesFormulas = getNumClausesFormulas();
02352
02353
02354
02355 if (indexTrans_ == NULL) updateWts(wts, NULL, NULL);
02356
02357
02358 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02359 {
02360 MLN* mln = (*mlns_)[d];
02361 Array<Array<IndexAndCount*> >& iacArrays = iacsPerDomain[d];
02362 Array<Clause*>& cnfClauses = cnfClausesForDomains[d];
02363
02364 for (int i = 0; i < cnfClauses.size(); i++)
02365 {
02366 int idx;
02367
02368
02369 double wt = (indexTrans_) ? 0 : wts[numClausesFormulas+i+1];
02370
02371
02372 mln->appendClause(ef->formula, true, new Clause(*cnfClauses[i]),
02373 wt, false, idx, false);
02374 mln->setFormulaPriorMean(ef->formula, priorMean_);
02375 ((MLNClauseInfo*)mln->getMLNClauseInfo(idx))->priorMean
02376 += priorMean_/cnfClauses.size();
02377
02378 int* idxPtr = mln->getMLNClauseInfoIndexPtr(idx);
02379 Array<IndexAndCount*>& iacs = iacArrays[i];
02380 for (int j = 0; j < iacs.size(); j++) iacs[j]->index = idxPtr;
02381 }
02382 }
02383 assert(pll_->checkNoRepeatedIndex());
02384
02385
02386 if (indexTrans_)
02387 {
02388
02389 Array<string> appendedFormula;
02390 appendedFormula.append(ef->formula);
02391 updateWts(wts, NULL, &appendedFormula);
02392
02393 indexTrans_->createClauseIdxToClauseFormulaIdxsMap();
02394 }
02395
02396 cout << "Modified MLN: Appended formula to MLN: " << ef->formula << endl;
02397 }
02398
02399
02400 inline bool effectExistFormulaOnMLNs(ExistFormula* ef,
02401 Array<ExistFormula*>& existFormulas,
02402 double& score)
02403 {
02404 cout << "effecting existentially quantified formula " << ef->formula
02405 << " on MLN..." << endl;
02406 appendExistFormulaToMLNs(ef);
02407 score = ef->newScore;
02408 printMLNClausesWithWeightsAndScore(score, iter_);
02409 printMLNToFile(NULL, iter_);
02410
02411 int r = existFormulas.find(ef);
02412 assert(r >= 0);
02413 ExistFormula* rf = existFormulas.removeItemFastDisorder(r);
02414 assert(rf == ef);
02415 delete rf;
02416 return true;
02417 }
02418
02419
02420 bool effectBestCandidateOnMLNs(Array<Clause*>& bestCandidates,
02421 Array<ExistFormula*>& existFormulas,
02422 Array<ExistFormula*>& highGainWtFormulas,
02423 double& score)
02424 {
02425 cout << "effecting best candidate among existential formulas and "
02426 << "best candidates on MLN..." << endl << endl;
02427
02428 int a = 0, b = 0;
02429 bool ok = false;
02430 int numCands = bestCandidates.size() + highGainWtFormulas.size();
02431 for (int i = 0; i < numCands; i++)
02432 {
02433 if (a >= bestCandidates.size())
02434 {
02435 if (ok=effectExistFormulaOnMLNs(highGainWtFormulas[b++],
02436 existFormulas, score)) break;
02437 }
02438 else
02439 if (b >= highGainWtFormulas.size())
02440 {
02441 cout << "effecting best candidate " << a << " on MLN..." << endl;
02442 if (ok=effectBestCandidateOnMLNs(bestCandidates[a++], score)) break;
02443 cout << "failed to effect candidate on MLN." << endl;
02444 delete bestCandidates[a-1];
02445 }
02446 else
02447 if (highGainWtFormulas[b]->gain >
02448 bestCandidates[a]->getAuxClauseData()->gain)
02449 {
02450 if (ok=effectExistFormulaOnMLNs(highGainWtFormulas[b++],
02451 existFormulas, score)) break;
02452 }
02453 else
02454 {
02455 cout << "effecting best candidate " << a << " on MLN..." << endl;
02456 if (ok=effectBestCandidateOnMLNs(bestCandidates[a++], score)) break;
02457 cout << "failed to effect candidate on MLN." << endl;
02458 delete bestCandidates[a-1];
02459 }
02460 }
02461
02462 for (int i = a; i < bestCandidates.size(); i++) delete bestCandidates[i];
02463 return ok;
02464 }
02465
02466
02467
02468
02469
02470
02471
02472
02473
02474
02475
02476
02477
02478
02479
02480
02481
02482
02483
02484
02485
02486
02487
02488
02489
02490
02491
02492
02493
02494
02495
02496
02497
02498
02499
02500
02501
02502
02503
02504
02505
02506
02507
02508
02509
02511 private:
02512
02513
02514 MLN* mln0_;
02515 Array<MLN*>* mlns_;
02516 bool startFromEmptyMLN_;
02517 string outMLNFileName_;
02518 Array<Domain*>* domains_;
02519 Array<Predicate*>* preds_;
02520 Array<bool>* areNonEvidPreds_;
02521 ClauseFactory* clauseFactory_;
02522
02523 bool cacheClauses_;
02524 bool origCacheClauses_;
02525 ClauseHashArray* cachedClauses_;
02526 double cacheSizeMB_;
02527 double maxCacheSizeMB_;
02528
02529 bool tryAllFlips_;
02530 bool sampleClauses_;
02531 bool origSampleClauses_;
02532
02533 PseudoLogLikelihood* pll_;
02534 bool hasPrior_;
02535 double priorMean_;
02536 double priorStdDev_;
02537 bool wtPredsEqually_;
02538
02539 LBFGSB* lbfgsb_;
02540 int lbMaxIter_;
02541 double lbConvThresh_;
02542 int looseMaxIter_;
02543 double looseConvThresh_;
02544
02545 int beamSize_;
02546 int bestGainUnchangedLimit_;
02547 int numEstBestClauses_;
02548 double minGain_;
02549 double minWt_;
02550 double penalty_;
02551
02552 bool sampleGndPreds_;
02553 double fraction_;
02554 int minGndPredSamples_;
02555 int maxGndPredSamples_;
02556
02557 bool reEvalBestCandsWithTightParams_;
02558
02559 Timer timer_;
02560 int candCnt_;
02561
02562 int iter_;
02563 int bsiter_;
02564 double startSec_;
02565
02566
02567
02568
02569
02570 IndexTranslator* indexTrans_;
02571
02572 bool structGradDescent_;
02573 bool withEM_;
02574 };
02575
02576
02577 #endif