00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef STRUCTLEARN_H_NOV_2_2005
00067 #define STRUCTLEARN_H_NOV_2_2005
00068
00069 #include "clausefactory.h"
00070 #include "mln.h"
00071
00072 #include "lbfgsb.h"
00073 #include "votedperceptron.h"
00074
00075
00077 struct ExistFormula
00078 {
00079 ExistFormula(const string& fformula)
00080 : formula(fformula), gain(0), wt(0), newScore(0), numPreds(0) {}
00081 ~ExistFormula()
00082 {
00083 for (int i = 0; i < cnfClausesForDomains.size(); i++)
00084 cnfClausesForDomains[i].deleteItemsAndClear();
00085 }
00086
00087 string formula;
00088
00089
00090 Array<Array<Clause*> > cnfClausesForDomains;
00091 double gain;
00092 double wt;
00093 Array<double> wts;
00094 double newScore;
00095 int numPreds;
00096 };
00097
00098 struct IndexCountDomainIdx { IndexAndCount* iac; int domainIdx; };
00099 struct ClauseAndICDArray {Clause* clause;Array<IndexCountDomainIdx> icdArray;};
00100
00101
00102 class StructLearn
00103 {
00104 public:
00105
00106
00107 StructLearn(Array<MLN*>* const & mlns, const bool& startFromEmptyMLN,
00108 const string& outMLNFileName, Array<Domain*>* const & domains,
00109 const Array<string>* const & nonEvidPredNames,
00110 const int& maxVars, const int& maxNumPredicates,
00111 const bool& cacheClauses, const double& maxCacheSizeMB,
00112 const bool& tryAllFlips, const bool& sampleClauses,
00113 const double& delta, const double& epsilon,
00114 const int& minClauseSamples, const int& maxClauseSamples,
00115 const bool& hasPrior, const double& priorMean,
00116 const double& priorStdDev, const bool& wtPredsEqually,
00117 const int& lbMaxIter, const double& lbConvThresh,
00118 const int& looseMaxIter, const double& looseConvThresh,
00119 const int& beamSize, const int& bestGainUnchangedLimit,
00120 const int& numEstBestClauses,
00121 const double& minWt, const double& penalty,
00122 const bool& sampleGndPreds, const double& fraction,
00123 const int& minGndPredSamples, const int& maxGndPredSamples,
00124 const bool& reEvaluateBestCandidatesWithTightParams,
00125 const bool& structGradDescent, const bool& withEM)
00126 : mln0_((*mlns)[0]), mlns_(mlns), startFromEmptyMLN_(startFromEmptyMLN),
00127 outMLNFileName_(outMLNFileName), domains_(domains),
00128 preds_(new Array<Predicate*>), areNonEvidPreds_(new Array<bool>),
00129 clauseFactory_(new ClauseFactory(maxVars, maxNumPredicates,
00130 (*domains_)[0])),
00131 cacheClauses_(cacheClauses), origCacheClauses_(cacheClauses),
00132 cachedClauses_((cacheClauses) ? (new ClauseHashArray) : NULL),
00133 cacheSizeMB_(0), maxCacheSizeMB_(maxCacheSizeMB),
00134 tryAllFlips_(tryAllFlips),
00135 sampleClauses_(sampleClauses), origSampleClauses_(sampleClauses),
00136 pll_(NULL), hasPrior_(hasPrior), priorMean_(priorMean),
00137 priorStdDev_(priorStdDev), wtPredsEqually_(wtPredsEqually),
00138 lbfgsb_(NULL), lbMaxIter_(lbMaxIter), lbConvThresh_(lbConvThresh),
00139 looseMaxIter_(looseMaxIter), looseConvThresh_(looseConvThresh),
00140 beamSize_(beamSize), bestGainUnchangedLimit_(bestGainUnchangedLimit),
00141 numEstBestClauses_(numEstBestClauses),
00142 minGain_(0), minWt_(minWt), penalty_(penalty),
00143 sampleGndPreds_(sampleGndPreds), fraction_(fraction),
00144 minGndPredSamples_(minGndPredSamples),
00145 maxGndPredSamples_(maxGndPredSamples),
00146 reEvalBestCandsWithTightParams_(reEvaluateBestCandidatesWithTightParams),
00147 candCnt_(0), iter_(-1), bsiter_(-1), startSec_(-1), indexTrans_(NULL),
00148 structGradDescent_(structGradDescent), withEM_(withEM)
00149 {
00150 assert(minWt_ >= 0);
00151 assert(domains_->size() == mlns_->size());
00152
00153 areNonEvidPreds_->growToSize((*domains_)[0]->getNumPredicates(), false);
00154 for (int i = 0; i < nonEvidPredNames->size(); i++)
00155 {
00156 int predId=(*domains_)[0]->getPredicateId((*nonEvidPredNames)[i].c_str());
00157 if (predId < 0)
00158 {
00159 cout << "ERROR: in StructLearn::StructLearn(). Predicate "
00160 << (*nonEvidPredNames)[i] << " undefined." << endl;
00161 exit(-1);
00162 }
00163 (*areNonEvidPreds_)[predId] = true;
00164 }
00165
00166 (*domains_)[0]->createPredicates(preds_, true);
00167
00168 if (origSampleClauses_)
00169 {
00170 ClauseSampler* cs = new ClauseSampler(delta, epsilon, minClauseSamples,
00171 maxClauseSamples);
00172 Clause::setClauseSampler(cs);
00173 for (int i = 0; i < domains_->size(); i++)
00174 (*domains_)[i]->newTrueFalseGroundingsStore();
00175 }
00176 }
00177
00178
00179 ~StructLearn()
00180 {
00181 if (pll_) delete pll_;
00182 if (lbfgsb_) delete lbfgsb_;
00183 preds_->deleteItemsAndClear();
00184 delete preds_;
00185 delete areNonEvidPreds_;
00186 delete clauseFactory_;
00187 if (cachedClauses_)
00188 {
00189 cachedClauses_->deleteItemsAndClear();
00190 delete cachedClauses_;
00191 }
00192 if (origSampleClauses_) delete Clause::getClauseSampler();
00193 if (indexTrans_) delete indexTrans_;
00194 }
00195
00196 void run()
00197 {
00198 startSec_ = timer_.time();
00199
00200 bool needIndexTrans = IndexTranslator::needIndexTranslator(*mlns_,*domains_);
00201
00202
00203
00204 Array<Clause*> initialMLNClauses;
00205 Array<ExistFormula*> existFormulas;
00206 if (startFromEmptyMLN_)
00207 {
00208 getMLNClauses(initialMLNClauses, existFormulas);
00209 removeClausesFromMLNs();
00210 for (int i = 0; i < initialMLNClauses.size(); i++)
00211 {
00212 Clause* c = initialMLNClauses[i];
00213 c->newAuxClauseData();
00214 c->trackConstants();
00215 c->getAuxClauseData()->gain = 0;
00216 c->getAuxClauseData()->op = OP_ADD;
00217 c->getAuxClauseData()->removedClauseIdx = -1;
00218 c->getAuxClauseData()->hasBeenExpanded = false;
00219 c->getAuxClauseData()->lastStepExpanded = -1;
00220 c->getAuxClauseData()->lastStepOverMinWeight = -1;
00221 }
00222 }
00223
00224
00225 cout << "adding unit clauses to MLN..." << endl << endl;
00226 addUnitClausesToMLNs();
00227
00228
00229 for (int i = 0; i < mln0_->getNumClauses(); i++)
00230 {
00231 Clause* c = (Clause*) mln0_->getClause(i);
00232 c->newAuxClauseData();
00233
00234
00235 if (isModifiableClause(i)) c->trackConstants();
00236 }
00237
00238 indexTrans_ = (needIndexTrans)? new IndexTranslator(mlns_, domains_) : NULL;
00239 if (indexTrans_)
00240 cout << "The weights of clauses in the CNFs of existential formulas wiil "
00241 << "be tied" << endl;
00242
00243
00244 runStructLearning(initialMLNClauses, existFormulas);
00245 }
00246
00247
00248 void runStructLearning(Array<Clause*> initialMLNClauses,
00249 Array<ExistFormula*> existFormulas)
00250 {
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295 pll_ = new PseudoLogLikelihood(areNonEvidPreds_, domains_, wtPredsEqually_,
00296 sampleGndPreds_, fraction_,
00297 minGndPredSamples_, maxGndPredSamples_);
00298 pll_->setIndexTranslator(indexTrans_);
00299
00300 int numClausesFormulas = getNumClausesFormulas();
00301 lbfgsb_ = new LBFGSB(-1, -1, pll_, numClausesFormulas);
00302
00303 useTightParams();
00304
00305
00306 cout << "computing counts for initial MLN clauses..." << endl;
00307 double begSec = timer_.time();
00308 pllComputeCountsForInitialMLN();
00309 cout << "computing counts for initial MLN clauses took ";
00310 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00311
00312
00313
00314 cout << "learning the initial weights and score of MLN..." << endl << endl;
00315 begSec = timer_.time();
00316 double score;
00317 Array<double> wts;
00318 if (!learnAndSetMLNWeights(score)) return;
00319 printMLNClausesWithWeightsAndScore(score, -1);
00320 cout << "learning the initial weights and score of MLN took ";
00321 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00322
00323
00324
00325 cout <<"trying to add unit clause with diff variable combinations to MLN..."
00326 << endl << endl;
00327 begSec = timer_.time();
00328 appendUnitClausesWithDiffCombOfVar(score);
00329 printMLNClausesWithWeightsAndScore(score, -1);
00330 cout <<"adding unit clause with diff variable combinations to MLN took ";
00331 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00332
00333
00334
00335 if (tryAllFlips_)
00336 {
00337 cout << "trying to flip the senses of MLN clauses..." << endl << endl;
00338 begSec = timer_.time();
00339 flipMLNClauses(score);
00340 printMLNClausesWithWeightsAndScore(score, -1);
00341 cout << "trying to flip the senses of MLN clauses took ";;
00342 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00343 }
00344
00345 iter_ = -1;
00346
00347
00348
00349 double bsec;
00350 Array<Clause*> initialClauses;
00351 Array<Clause*> bestCandidates;
00352 while (true)
00353 {
00354 begSec = timer_.time();
00355 iter_++;
00356
00357
00358 cout << "Iteration " << iter_ << endl << endl;
00359
00360 minGain_ = 0;
00361
00362 Array<ExistFormula*> highGainWtExistFormulas;
00363 if (startFromEmptyMLN_ && !existFormulas.empty())
00364 {
00365 useTightParams();
00366 cout << "evaluating the gains of existential formulas..." << endl<<endl;
00367 bsec = timer_.time();
00368
00369 minGain_ = evaluateExistFormulas(existFormulas, highGainWtExistFormulas,
00370 score);
00371 cout << "evaluating the gains of existential formulas took ";
00372 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00373 cout << "setting minGain to min gain among existential formulas. "
00374 << "minGain = " << minGain_ << endl;
00375 }
00376
00377 useLooseParams();
00378
00379
00380 initialClauses.clear();
00381 for (int i = 0; i < mln0_->getNumClauses(); i++)
00382 {
00383 Clause* c = (Clause*) mln0_->getClause(i);
00384 c->getAuxClauseData()->reset();
00385 if (isModifiableClause(i))
00386 {
00387 c->trackConstants();
00388 initialClauses.append(new Clause(*c));
00389 }
00390 }
00391
00392 bestCandidates.clear();
00393 beamSearch(initialClauses, initialMLNClauses, score, bestCandidates);
00394 bool noCand = (startFromEmptyMLN_ && !existFormulas.empty()) ?
00395 (bestCandidates.empty() && highGainWtExistFormulas.empty())
00396 : bestCandidates.empty();
00397 if (noCand)
00398 {
00399 cout << "Beam is empty. Ending search for MLN clauses." << endl;
00400 printIterAndTimeElapsed(begSec);
00401 break;
00402 }
00403
00404 useTightParams();
00405
00406 if (reEvalBestCandsWithTightParams_)
00407 {
00408 cout << "reevaluating top candidates... " << endl << endl;
00409 bsec = timer_.time();
00410 reEvaluateCandidates(bestCandidates, score);
00411 cout << "reevaluating top candidates took ";
00412 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00413 }
00414
00415
00416 bsec = timer_.time();
00417 bool ok;
00418 if (startFromEmptyMLN_ && !existFormulas.empty())
00419 ok = effectBestCandidateOnMLNs(bestCandidates, existFormulas,
00420 highGainWtExistFormulas, score);
00421 else
00422 ok = effectBestCandidateOnMLNs(bestCandidates, score);
00423
00424 cout << "effecting best candidates took ";
00425 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00426
00427 if (!ok)
00428 {
00429 cout << "failed to effect any of the best candidates on MLN" << endl
00430 << "stopping search for MLN clauses..." << endl;
00431 break;
00432 }
00433
00434 printIterAndTimeElapsed(begSec);
00435 }
00436
00437 cout << "done searching for MLN clauses" << endl << endl;
00438 int numIterTaken = iter_+1;
00439 iter_= -1;
00440
00441 useTightParams();
00442
00443
00444 cout << "pruning clauses from MLN..." << endl << endl;
00445 begSec = timer_.time();
00446 pruneMLN(score);
00447 cout << "pruning clauses from MLN took ";
00448 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
00449
00450 printMLNClausesWithWeightsAndScore(score, -1);
00451 printMLNToFile(NULL, -2);
00452 cout << "num of iterations taken = " << numIterTaken << endl;
00453
00454 cout << "time taken for structure learning = ";
00455 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
00456
00457 initialMLNClauses.deleteItemsAndClear();
00458 deleteExistFormulas(existFormulas);
00459 }
00460
00461
00462 bool learnAndSetMLNWeights(double& score)
00463 {
00464 Array<double> priorMeans, priorStdDevs;
00465 double tmpScore = score;
00466 pllSetPriorMeansStdDevs(priorMeans, priorStdDevs, 0, NULL);
00467 int numClausesFormulas = getNumClausesFormulas();
00468 Array<double> wts;
00469 wts.growToSize(numClausesFormulas + 1);
00470
00471 int iter; bool error; double elapsedSec;
00472 tmpScore = maximizeScore(numClausesFormulas, 0, &wts, NULL, NULL,
00473 iter, error, elapsedSec);
00474 if (error)
00475 {
00476 cout << "LBFGSB failed to find wts" << endl;
00477 return false;
00478 }
00479 else
00480 {
00481 score = tmpScore;
00482 printNewScore((Clause*)NULL, NULL, iter, elapsedSec, score, 0, 0);
00483 }
00484
00485 updateWts(wts, NULL, NULL);
00486
00487 return true;
00488 }
00489
00491 private:
00492
00493
00494
00495 void beamSearch(const Array<Clause*>& initClauses,
00496 const Array<Clause*>& initMLNClauses, const double& prevScore,
00497 Array<Clause*>& bestClauses)
00498 {
00499 int iterBestGainUnchanged = 0; bsiter_ = -1;
00500 Array<double> priorMeans, priorStdDevs;
00501 ClauseOpHashArray* beam = new ClauseOpHashArray;
00502 for (int i = 0; i < initClauses.size(); i++) beam->append(initClauses[i]);
00503
00504 int numClausesFormulas = getNumClausesFormulas();
00505 Array<double> wts;
00506 wts.growToSize(numClausesFormulas + 2);
00507 Array<double> origWts(numClausesFormulas);
00508 if (indexTrans_) indexTrans_->getClauseFormulaWts(origWts);
00509 else mln0_->getClauseWts(origWts);
00510
00511
00512 setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
00513 if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, 1);
00514
00515 bool error;
00516 double begIterSec, bsec;
00517 Array<Clause*> candidates;
00518 while (!beam->empty() && iterBestGainUnchanged < bestGainUnchangedLimit_)
00519 {
00520 begIterSec = timer_.time();
00521 bsiter_++;
00522
00523
00524 cout << endl << "BEAM SEARCH ITERATION " << bsiter_ << endl << endl;
00525
00526 cout << "creating candidate clauses..." << endl;
00527 candidates.clear();
00528 bsec = timer_.time();
00529
00530 if (bsiter_==0) createCandidateClauses(beam, candidates, &initMLNClauses);
00531 else createCandidateClauses(beam, candidates, NULL);
00532
00533 cout << "num of candidates created = " << candidates.size() << "; took ";
00534 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00535
00536 cout << "evaluating gain of candidates..." << endl << endl;
00537 bsec = timer_.time();
00538 for (int i = 0; i < candidates.size(); i++)
00539 countAndMaxScoreEffectCandidate(candidates[i], &wts, &origWts,prevScore,
00540 false, priorMeans, priorStdDevs, error,
00541 NULL, NULL);
00542 cout << "evaluating gain of candidates took ";
00543 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00544
00545
00546 cout << "finding best candidates..." << endl;
00547 bsec = timer_.time();
00548 ClauseOpHashArray* newBeam = new ClauseOpHashArray(beamSize_);
00549 bool newBestClause = findBestClausesAmongCandidates(candidates, newBeam,
00550 bestClauses);
00551 cout << "finding best candidates took ";
00552 timer_.printTime(cout, timer_.time()-bsec); cout << endl << endl;
00553
00554 beam->deleteItemsAndClear();
00555 delete beam;
00556 beam = newBeam;
00557
00558 if (newBestClause)
00559 {
00560 iterBestGainUnchanged = 0;
00561 cout << "found new best clause in beam search iter " << bsiter_ << endl;
00562 }
00563 else
00564 iterBestGainUnchanged++;
00565
00566
00567
00568 cout << "best clauses found in beam search iter " << bsiter_ << endl;
00569 cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << endl;
00570 for (int i = 0; i < bestClauses.size(); i++)
00571 {
00572 cout << i << "\t";
00573 bestClauses[i]->printWithoutWtWithStrVar(cout,(*domains_)[0]);
00574 cout << endl
00575 << "\tgain = " << bestClauses[i]->getAuxClauseData()->gain
00576 << ", op = "
00577 << Clause::getOpAsString(bestClauses[i]->getAuxClauseData()->op);
00578 if (bestClauses[i]->getAuxClauseData()->op != OP_REMOVE)
00579 cout << ", wt = " << bestClauses[i]->getWt();
00580 cout << endl;
00581 }
00582 cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << endl << endl;
00583
00584
00585 cout << "BEAM SEARCH ITERATION " << bsiter_ << " took ";
00586 timer_.printTime(cout, timer_.time()-begIterSec); cout << endl << endl;
00587 cout << "Time elapsed = ";
00588 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
00589 }
00590
00591 if (beam->empty()) cout << "Beam search ended because beam is empty"
00592 << endl << endl;
00593 else cout << "Beam search ended because best clause did not change in "
00594 << iterBestGainUnchanged << " iterations" << endl << endl;
00595
00596 beam->deleteItemsAndClear();
00597 delete beam;
00598 bsiter_ = -1;
00599
00600 if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, 1);
00601 }
00602
00603
00605 private:
00606
00607
00608
00609 double maximizeScore(const int& numClausesFormulas, const int& numExtraWts,
00610 Array<double>* const & wts,
00611 const Array<double>* const & origWts,
00612 const Array<int>* const & removedClauseFormulaIdxs,
00613 int& iter, bool& error, double& elapsedSec)
00614 {
00615 if (origWts)
00616 { for (int i=1 ; i<=numClausesFormulas; i++) (*wts)[i] = (*origWts)[i-1];}
00617 else
00618 { for (int i=1; i<=numClausesFormulas; i++) (*wts)[i] = 0; }
00619
00620 for (int i = 1; i <= numExtraWts; i++) (*wts)[numClausesFormulas+i] = 0;
00621
00622 if (removedClauseFormulaIdxs)
00623 for (int i = 0; i < removedClauseFormulaIdxs->size(); i++)
00624 (*wts)[ (*removedClauseFormulaIdxs)[i]+1 ] = 0;
00625
00626
00627
00628
00629
00630 double* wwts = (double*) wts->getItems();
00631 double begSec = timer_.time();
00632 double newScore
00633 = lbfgsb_->minimize(numClausesFormulas + numExtraWts, wwts, iter, error);
00634 newScore = -newScore;
00635 elapsedSec = timer_.time() - begSec;
00636
00637
00638
00639
00640
00641 return newScore;
00642 }
00643
00644
00645
00646
00647
00648
00649
00650
00651
// Evaluates the joint effect of all candidates on the score.
//
// For every candidate the PLL counts are updated according to its op (counts
// of a removed clause are taken out, counts of an added clause are computed
// under a temporary clause index), then the weights are re-optimized with
// maximizeScore(). Each candidate's gain is set to
// newScore - prevScore - penalty, appended candidates get their learned
// weight, and the new score is returned (prevScore if LBFGSB reported an
// error).
//
// wts:        scratch/output weight array; grown here as needed.
// origWts:    starting weights handed to maximizeScore(); may be NULL.
// resetPriors: when true the prior arrays are rebuilt here; when false they
//             must already have numClausesFormulas + candidates.size()
//             entries (asserted) and are patched and restored in place.
// uundoInfos: when NULL, all count changes are undone before returning;
//             otherwise the undo records are left in *uundoInfos for the
//             caller.
// appendedClauseInfos: when non-NULL, receives one ClauseAndICDArray per
//             appended candidate (caller owns the entries).
00652 double countAndMaxScoreEffectAllCandidates(const Array<Clause*>& candidates,
00653 Array<double>* const & wts,
00654 const Array<double>*const& origWts,
00655 const double& prevScore,
00656 const bool& resetPriors,
00657 Array<double>& priorMeans,
00658 Array<double>& priorStdDevs,
00659 bool& error,
00660 Array<UndoInfo*>* const& uundoInfos,
00661 Array<ClauseAndICDArray*>* const & appendedClauseInfos)
00662 {
00663 Array<UndoInfo*>* undoInfos = (uundoInfos)? uundoInfos:new Array<UndoInfo*>;
00664 Array<int> remClauseFormulaIdxs;
00665 Array<int*> idxPtrs;
00666 Array<Clause*> toBeRemovedClauses;
00667 Array<Clause*> toBeAppendedClauses;
00668 int numClausesFormulas = getNumClausesFormulas();
00669
00670 for (int i = 0; i < candidates.size(); i++)
00671 {
00672 Clause* cand = candidates[i];
00673 AuxClauseData* acd = cand->getAuxClauseData();
00674 int op = acd->op;
00675
00676 // ops that take a clause out of the MLN: remove its counts in every domain
00677 if (op == OP_REMOVE || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00678 op == OP_REPLACE_REMPRED)
00679 {
// remClause is only needed (and only dereferenced) when there is more than
// one MLN, to locate the same clause in the other MLNs.
00680 Clause* remClause = (mlns_->size() > 1) ?
00681 (Clause*) mln0_->getClause(acd->removedClauseIdx) : NULL;
00682 Array<int*> idxs; idxs.growToSize(mlns_->size());
00683 Array<Clause*> remClauses; remClauses.growToSize(mlns_->size());
00684 for (int d = 0; d < mlns_->size(); d++)
00685 {
00686 int remIdx;
00687 if (d == 0)
00688 {
00689 remIdx = acd->removedClauseIdx;
00690 remClauseFormulaIdxs.append(remIdx);
00691 }
00692 else
00693 {
// translate the clause's constants from domain d-1 to domain d before
// looking it up in MLN d
00694 if (remClause->containsConstants())
00695 remClause->translateConstants((*domains_)[d-1], (*domains_)[d]);
00696 remIdx = (*mlns_)[d]->findClauseIdx(remClause);
00697 }
00698 idxs[d] = (*mlns_)[d]->getMLNClauseInfoIndexPtr(remIdx);
00699 remClauses[d] = (Clause*) (*mlns_)[d]->getClause(remIdx);
00700 }
// translate the constants back from the last domain to the first one
00701 if (mlns_->size() > 1 && remClause->containsConstants())
00702 remClause->translateConstants((*domains_)[mlns_->size()-1],
00703 (*domains_)[0]);
00704 pllRemoveCountsForClause(remClauses, idxs, undoInfos);
00705
00706 if (op == OP_REMOVE) toBeRemovedClauses.append(cand);
00707 }
00708
00709 // ops that put a clause into the MLN: compute its counts
00710 if (op == OP_ADD || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00711 op == OP_REPLACE_REMPRED)
00712 {
00713 Array<int*> tmpIdxs; tmpIdxs.growToSize(mlns_->size());
00714 for (int d = 0; d < mlns_->size(); d++)
00715 {
// temporary clause index: one past the clauses of MLN d plus the candidates
// queued so far; the int is freed through idxPtrs at the end
00716 int* tmpClauseIdxInMLN = new int( (*mlns_)[d]->getNumClauses() +
00717 toBeAppendedClauses.size() );
00718 tmpIdxs[d] = tmpClauseIdxInMLN;
00719 idxPtrs.append(tmpClauseIdxInMLN);
00720 }
00721
// when the caller wants per-clause info, collect this clause's undo records
// separately so they can be inspected before being merged into undoInfos
00722 Array<UndoInfo*>* tmpInfos = (appendedClauseInfos)?
00723 new Array<UndoInfo*> : undoInfos;
00724 pllComputeCountsForClause(cand, tmpIdxs, tmpInfos);
00725 toBeAppendedClauses.append(cand);
00726
00727 if (appendedClauseInfos)
00728 {
00729 undoInfos->append(*tmpInfos);
00730 ClauseAndICDArray* ca = new ClauseAndICDArray;
00731 ca->clause = cand;
00732 appendedClauseInfos->append(ca);
00733 for (int j = 0; j < tmpInfos->size(); j++)
00734 {
00735 ca->icdArray.append(IndexCountDomainIdx());
00736 ca->icdArray.lastItem().iac=(*tmpInfos)[j]->affectedArr->lastItem();
00737 ca->icdArray.lastItem().domainIdx = (*tmpInfos)[j]->domainIdx;
00738 }
00739 delete tmpInfos;
00740 }
00741 }
00742 }
00743
00744
00745 // NOTE(review): presumably converts the removed clause indexes into
00746 // clause/formula indexes when an index translator is in use -- confirm
00747 if (indexTrans_)
00748 indexTrans_->getClauseFormulaIndexes(remClauseFormulaIdxs, 0);
00749
00750 Array<double> removedValues;
00751 if (resetPriors)
00752 {
00753 setPriorMeansStdDevs(priorMeans, priorStdDevs, toBeAppendedClauses.size(),
00754 &remClauseFormulaIdxs);
00755 if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(
00756 toBeAppendedClauses.size(), 1);
00757 }
00758 else
00759 {
00760 assert(priorMeans.size() == numClausesFormulas + candidates.size());
00761 assert(priorStdDevs.size() == numClausesFormulas + candidates.size());
00762 if (indexTrans_)
00763 assert(indexTrans_->checkCIdxWtsGradsSize(candidates.size()));
00764
// removedValues receives what is needed to restore priorMeans further below
00765 setRemoveAppendedPriorMeans(numClausesFormulas, priorMeans,
00766 remClauseFormulaIdxs,
00767 toBeAppendedClauses.size(), removedValues);
00768 }
00769
00770 if (hasPrior_)
00771 pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
00772 priorStdDevs.getItems());
00773 else
00774 pll_->setMeansStdDevs(-1, NULL, NULL);
00775
// maximizeScore() uses slots 1..N, hence the +1
00776 wts->growToSize(numClausesFormulas + toBeAppendedClauses.size() + 1);
00777
00778 int iter; double elapsedSec;
00779 double newScore = maximizeScore(numClausesFormulas,
00780 toBeAppendedClauses.size(),
00781 wts, origWts, &remClauseFormulaIdxs,
00782 iter, error, elapsedSec);
00783
00784 // store the learned weight on each appended candidate
00785 for (int i = 0; i < toBeAppendedClauses.size(); i++)
00786 toBeAppendedClauses[i]->setWt( (*wts)[numClausesFormulas+i+1] );
00787
// NOTE(review): 111111111 looks like a placeholder weight for removed
// clauses -- confirm with callers
00788 for (int i = 0; i < toBeRemovedClauses.size(); i++)
00789 toBeRemovedClauses[i]->setWt(111111111);
00790
00791 Array<double> penalties;
00792 for (int i = 0; i < candidates.size(); i++)
00793 penalties.append(getPenalty(candidates[i]));
00794
// on optimizer failure the score is reported as unchanged
00795 if (error) { newScore=prevScore; cout<<"LBFGSB failed to find wts"<<endl; }
00796 printNewScore(candidates, (*domains_)[0], iter, elapsedSec,
00797 newScore, newScore-prevScore, penalties);
00798
00799 for (int i = 0; i < candidates.size(); i++)
00800 candidates[i]->getAuxClauseData()->gain = newScore-prevScore-penalties[i];
00801
// no caller-supplied undo list: this was a trial only, so roll the counts back
00802 if (uundoInfos == NULL)
00803 {
00804 pll_->undoAppendRemoveCounts(undoInfos);
00805 delete undoInfos;
00806 }
00807
00808 if (resetPriors)
00809 {
00810 if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(
00811 toBeAppendedClauses.size(), 1);
00812 }
00813 else
00814 {
// put back the prior means that were overwritten for the removed clauses
00815 for (int i = 0; i < remClauseFormulaIdxs.size(); i++)
00816 priorMeans[remClauseFormulaIdxs[i]] = removedValues[i];
00817 }
00818
// free the temporary clause-index ints allocated above
00819 idxPtrs.deleteItemsAndClear();
00820
00821 return newScore;
00822 }
00823
00824
00825 double countAndMaxScoreEffectCandidate(Clause* const & candidate,
00826 Array<double>* const & wts,
00827 const Array<double>* const & origWts,
00828 const double& prevScore,
00829 const bool& resetPriors,
00830 Array<double>& priorMeans,
00831 Array<double>& priorStdDevs,
00832 bool& error,
00833 Array<UndoInfo*>* const& undoInfos,
00834 Array<ClauseAndICDArray*>* const & appendedClauseInfos)
00835 {
00836 Array<Clause*> candidates; candidates.append(candidate);
00837 return countAndMaxScoreEffectAllCandidates(candidates, wts, origWts,
00838 prevScore, resetPriors,
00839 priorMeans, priorStdDevs, error,
00840 undoInfos, appendedClauseInfos);
00841 }
00842
00843
00844
00845
00846 bool appendToAndRemoveFromMLNs(const Array<Clause*>& candidates,double& score,
00847 const bool& makeChangeForEqualScore=false)
00848 {
00849 Array<double> wts;
00850 Array<UndoInfo*> undoInfos;
00851 Array<ClauseAndICDArray*> appendedClauseInfos;
00852 Array<double> priorMeans, priorStdDevs;
00853 bool error;
00854
00855 double newScore
00856 = countAndMaxScoreEffectAllCandidates(candidates, &wts, NULL, score,
00857 true, priorMeans,priorStdDevs,error,
00858 &undoInfos, &appendedClauseInfos);
00859
00860 bool improve = (makeChangeForEqualScore) ? (newScore >= score)
00861 : (newScore > score);
00862
00863 if (!error && improve)
00864 {
00865 score = newScore;
00866
00867
00868
00869 updateWts(wts, NULL, NULL);
00870
00871 int numClausesFormulas = getNumClausesFormulas();
00872
00873 for (int i = 0; i < appendedClauseInfos.size(); i++)
00874 appendedClauseInfos[i]->clause->setWt(wts[numClausesFormulas+i+1]);
00875
00876
00877 int appClauseIdx = 0;
00878 for (int i = 0; i < candidates.size(); i++)
00879 {
00880 Clause* cand = candidates[i];
00881 AuxClauseData* acd = cand->getAuxClauseData();
00882 int op = acd->op;
00883
00884
00885
00886
00887 if (op == OP_REPLACE) cand->canonicalize();
00888
00889 if (op == OP_REMOVE || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00890 op == OP_REPLACE_REMPRED)
00891 {
00892 Clause* remClause = (Clause*) mln0_->getClause(acd->removedClauseIdx);
00893
00894 for (int d = 0; d < mlns_->size(); d++)
00895 {
00896 if (d == 0)
00897 {
00898 Clause* r = removeClauseFromMLN(acd->removedClauseIdx, d);
00899 cout << "Modified MLN: Removed clause from MLN: ";
00900 r->printWithoutWtWithStrVar(cout,(*domains_)[0]); cout << endl;
00901 if (op == OP_REMOVE && cand != r) delete cand;
00902
00903 }
00904 else
00905 {
00906 if (remClause->containsConstants())
00907 remClause->translateConstants((*domains_)[d-1], (*domains_)[d]);
00908 int remIdx = (*mlns_)[d]->findClauseIdx(remClause);
00909 delete removeClauseFromMLN(remIdx, d);
00910 }
00911 }
00912 delete remClause;
00913 }
00914
00915 if (op == OP_ADD || op == OP_REPLACE || op == OP_REPLACE_ADDPRED ||
00916 op == OP_REPLACE_REMPRED)
00917 {
00918 Array<int*> idxPtrs; idxPtrs.growToSize(mlns_->size());
00919 for (int d = 0; d < mlns_->size(); d++)
00920 {
00921 Clause* c = cand;
00922 if (d > 0)
00923 {
00924 if (cand->containsConstants())
00925 cand->translateConstants((*domains_)[d-1], (*domains_)[d]);
00926 c = new Clause(*cand);
00927 }
00928 int idx = appendClauseToMLN(c, d);
00929 idxPtrs[d] = (*mlns_)[d]->getMLNClauseInfoIndexPtr(idx);
00930 }
00931
00932 Array<IndexCountDomainIdx>& icds
00933 = appendedClauseInfos[appClauseIdx++]->icdArray;
00934 for (int j = 0; j < icds.size(); j++)
00935 icds[j].iac->index = idxPtrs[ icds[j].domainIdx ];
00936
00937 cout << "Modified MLN: Appended clause to MLN: ";
00938 cand->printWithoutWtWithStrVar(cout,(*domains_)[0]);
00939 cout << endl;
00940 }
00941 }
00942
00943 assert(pll_->checkNoRepeatedIndex());
00944 assert(appClauseIdx == appendedClauseInfos.size());
00945
00946 undoInfos.deleteItemsAndClear();
00947
00948
00949 if (indexTrans_) indexTrans_->createClauseIdxToClauseFormulaIdxsMap();
00950 }
00951 else
00952 {
00953 cout << "undoing candidates because score did not improve..."<<endl<<endl;
00954 pll_->undoAppendRemoveCounts(&undoInfos);
00955 }
00956
00957 appendedClauseInfos.deleteItemsAndClear();
00958 return improve;
00959 }
00960
00961
00962
00963 bool appendToAndRemoveFromMLNs(Clause* const & candidate, double& score,
00964 const bool& makeChangeForEqualScore=false)
00965 {
00966 Array<Clause*> candidates; candidates.append(candidate);
00967 return appendToAndRemoveFromMLNs(candidates, score,makeChangeForEqualScore);
00968 }
00969
00970
00971
00972 void pllComputeCountsForClause(Clause* const & c,
00973 const Array<int*>& clauseIdxInMLNs,
00974 Array<UndoInfo*>* const & undoInfos)
00975 {
00976 assert(c->getAuxClauseData()->cache == NULL);
00977 assert(clauseIdxInMLNs.size() == domains_->size());
00978 double begSec = timer_.time();
00979
00980 if (cacheClauses_)
00981 {
00982 int i;
00983 if ((i = cachedClauses_->find(c)) >= 0)
00984 {
00985
00986
00987
00988 pll_->insertCounts(clauseIdxInMLNs, undoInfos,
00989 (*cachedClauses_)[i]->getAuxClauseData()->cache);
00990 cout << "using cached counts took ";
00991 timer_.printTime(cout, timer_.time()-begSec); cout << endl;
00992 return;
00993 }
00994 else
00995 {
00996 assert(c->getAuxClauseData()->cache == NULL);
00997 if (cacheSizeMB_ < maxCacheSizeMB_)
00998 c->newCache(domains_->size(), (*domains_)[0]->getNumPredicates());
00999 else
01000 {
01001 static bool printCacheFull = true;
01002 if (printCacheFull)
01003 {
01004 cout << "Cache is full, approximate size = " << cacheSizeMB_
01005 << " MB" << endl;
01006 printCacheFull = false;
01007 }
01008 }
01009
01010 }
01011 }
01012
01013
01014
01015
01016 if (!c->containsConstants())
01017 {
01018 for (int i = 0; i < domains_->size(); i++)
01019 {
01020 int* clauseIdxInMLN = clauseIdxInMLNs[i];
01021 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01022 undoInfos, sampleClauses_,
01023 c->getAuxClauseData()->cache);
01024 }
01025 }
01026 else
01027 {
01028 int i;
01029 for (i = 0; i < domains_->size(); i++)
01030 {
01031 if (i > 0) c->translateConstants((*domains_)[i-1], (*domains_)[i]);
01032 int* clauseIdxInMLN = clauseIdxInMLNs[i];
01033 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01034 undoInfos, sampleClauses_,
01035 c->getAuxClauseData()->cache);
01036 }
01037 if (i > 1) c->translateConstants((*domains_)[i-1], (*domains_)[0]);
01038 }
01039
01040
01041 if (c->getAuxClauseData()->cache)
01042 {
01043 if (cacheSizeMB_ < maxCacheSizeMB_)
01044 {
01045 cacheSizeMB_ += c->sizeMB();
01046 Array<Array<Array<CacheCount*>*>*>* cache =c->getAuxClauseData()->cache;
01047 c->getAuxClauseData()->cache = NULL;
01048 Clause* copyClause = new Clause(*c);
01049 copyClause->getAuxClauseData()->cache = cache;
01050 copyClause->compress();
01051 cachedClauses_->append(copyClause);
01052 }
01053 else
01054 {
01055 c->getAuxClauseData()->deleteCache();
01056 c->getAuxClauseData()->cache = NULL;
01057 }
01058 }
01059
01060 cout << "Computing counts took ";
01061 timer_.printTime(cout, timer_.time()-begSec);
01062
01063
01064 cout << endl;
01065 }
01066
01067
01068 void pllRemoveCountsForClause(const Array<Clause*>& remClauses,
01069 const Array<int*>& clauseIdxInMLNs,
01070 Array<UndoInfo*>* const & undoInfos)
01071 {
01072 assert(clauseIdxInMLNs.size() == domains_->size());
01073 double begSec = timer_.time();
01074 for (int i = 0; i < domains_->size(); i++)
01075 pll_->removeCountsForClause(remClauses[i],clauseIdxInMLNs[i],i,undoInfos);
01076 cout << "Removing counts took ";
01077 timer_.printTime(cout, timer_.time()-begSec);
01078 cout << endl;
01079 }
01080
01081
01082
01083 void pllComputeCountsForInitialMLN()
01084 {
01085 for (int i = 0; i < mlns_->size(); i++)
01086 {
01087 cout << "computing counts for clauses in domain " << i << "..." << endl;
01088 MLN* mln = (*mlns_)[i];
01089 for (int j = 0; j < mln->getNumClauses(); j++)
01090 {
01091 Clause* c = (Clause*)mln->getClause(j);
01092 cout << "Clause " << j << ": ";
01093 c->printWithoutWtWithStrVar(cout, (*domains_)[i]); cout << endl;
01094 int* clauseIdxInMLN = mln->getMLNClauseInfoIndexPtr(j);
01095 pll_->computeCountsForNewAppendedClause(c, clauseIdxInMLN, i,
01096 NULL, sampleClauses_, NULL);
01097 }
01098 }
01099 }
01100
01101
01103 private:
01104 void addUnitClausesToMLNs()
01105 {
01106 Array<Predicate*> nonEvidPreds;
01107 for (int i = 0; i < preds_->size(); i++)
01108 if ((*areNonEvidPreds_)[(*preds_)[i]->getId()])
01109 nonEvidPreds.append((*preds_)[i]);
01110
01111 Array<Clause*> unitClauses;
01112 bool allowEqualPreds = true;
01113 ClauseFactory::createUnitClauses(unitClauses,nonEvidPreds,allowEqualPreds);
01114
01115 for (int i = 0; i < unitClauses.size(); i++)
01116 {
01117 if (mln0_->containsClause(unitClauses[i]))
01118 { delete unitClauses[i]; continue; }
01119 ostringstream oss; int idx;
01120 unitClauses[i]->printWithoutWtWithStrVar(oss, (*domains_)[0]);
01121
01122 for (int j = 0; j < mlns_->size(); j++)
01123 {
01124 Clause* c = (j == 0) ? unitClauses[i] : new Clause(*unitClauses[i]);
01125 (*mlns_)[j]->appendClause(oss.str(), false, c, priorMean_, false, idx);
01126 ((MLNClauseInfo*)(*mlns_)[j]->getMLNClauseInfo(idx))->priorMean
01127 = priorMean_;
01128 }
01129 }
01130 }
01131
01132
01133 void appendUnitClausesWithDiffCombOfVar(double& score)
01134 {
01135 bool allowEqualPreds = false;
01136 for (int i = 0; i < preds_->size(); i++)
01137 {
01138 if (!(*areNonEvidPreds_)[(*preds_)[i]->getId()]) continue;
01139
01140 Clause* origClause = ClauseFactory::createUnitClause((*preds_)[i],
01141 allowEqualPreds);
01142 if (origClause == NULL) continue;
01143 assert(origClause->getAuxClauseData() == NULL);
01144 origClause->setAuxClauseData(new AuxClauseData);
01145 origClause->trackConstants();
01146
01147 ClauseOpHashArray newUnitClauses;
01148 clauseFactory_->createUnitClausesWithDiffCombOfVar((*preds_)[i], OP_ADD,
01149 -1, newUnitClauses);
01150 for (int j = 0; j < newUnitClauses.size(); j++)
01151 {
01152 Clause* newClause = newUnitClauses[j];
01153
01154 if (origClause->same(newClause) ||
01155 !clauseFactory_->validClause(newClause) ||
01156 mln0_->containsClause(newClause))
01157 {
01158 newUnitClauses.removeItemFastDisorder(j);
01159 delete newClause;
01160 j--;
01161 continue;
01162 }
01163
01164 if (!appendToAndRemoveFromMLNs(newClause, score)) delete newClause;
01165 }
01166 delete origClause;
01167 }
01168 }
01169
01170
01171 void flipMLNClauses(double& score)
01172 {
01173 Array<Clause*> mlnClauses;
01174 for (int i = 0; i < mln0_->getNumClauses(); i++)
01175 {
01176
01177 if (mln0_->getClause(i)->getNumPredicates()==1 || !isModifiableClause(i))
01178 continue;
01179 mlnClauses.append((Clause*)mln0_->getClause(i));
01180 }
01181
01182 for (int i = 0; i < mlnClauses.size(); i++)
01183 {
01184 Clause* origClause = mlnClauses[i];
01185 int origIdx = mln0_->findClauseIdx(origClause);
01186
01187
01188
01189 bool canonicalizeNewClauses = false;
01190 ClauseOpHashArray newClauses;
01191 clauseFactory_->flipSensesInClause(origClause, OP_REPLACE, origIdx,
01192 newClauses, canonicalizeNewClauses);
01193
01194 Array<double> priorMeans, priorStdDevs;
01195 setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
01196 Array<double> wts;
01197 wts.growToSize(getNumClausesFormulas()+2);
01198 Clause* bestClause = NULL;
01199 double bestScore = score;
01200 bool error;
01201
01202 for (int j = 0; j < newClauses.size(); j++)
01203 {
01204 Clause* newClause = newClauses[j];
01205 if (origClause->same(newClause) || newClause->hasRedundantPredicates()
01206 || mln0_->containsClause(newClause)) { delete newClause; continue; }
01207
01208 double newScore
01209 = countAndMaxScoreEffectCandidate(newClause, &wts, NULL, score, false,
01210 priorMeans, priorStdDevs, error,
01211 NULL, NULL);
01212 if (newScore > bestScore)
01213 {
01214 bestScore = newScore;
01215 if (bestClause) delete bestClause;
01216 bestClause = newClause;
01217 }
01218 else
01219 delete newClause;
01220 }
01221
01222 if (bestClause && !appendToAndRemoveFromMLNs(bestClause, score))
01223 delete bestClause;
01224 }
01225 }
01226
01227
01228
01229
01230 void pruneMLN(double& score)
01231 {
01232 Array<Clause*> mlnClauses;
01233 for (int i = 0; i < mln0_->getNumClauses(); i++)
01234 {
01235
01236 if (mln0_->getClause(i)->getNumPredicates() == 1 ||
01237 !isModifiableClause(i)) continue;
01238 mlnClauses.append((Clause*)mln0_->getClause(i));
01239 }
01240
01241 for (int i = 0; i < mlnClauses.size(); i++)
01242 {
01243 Clause* origClause = mlnClauses[i];
01244 int origIdx = mln0_->findClauseIdx(origClause);
01245 Clause* copy = new Clause(*origClause);
01246 copy->setAuxClauseData(new AuxClauseData(0, OP_REMOVE, origIdx));
01247 copy->trackConstants();
01248 if (!appendToAndRemoveFromMLNs(copy, score, true)) delete copy;
01249 }
01250 }
01251
01252
01253 int appendClauseToMLN(Clause* const c, const int& domainIdx)
01254 {
01255 ostringstream oss; int idx;
01256 c->printWithoutWtWithStrVar(oss, (*domains_)[domainIdx]);
01257 MLN* mln = (*mlns_)[domainIdx];
01258 bool ok = mln->appendClause(oss.str(), false, c, c->getWt(), false, idx);
01259 if (!ok) { cout << "ERROR: failed to insert " << oss.str() <<" into MLN"
01260 << endl; exit(-1); }
01261 ((MLNClauseInfo*)mln->getMLNClauseInfo(idx))->priorMean = priorMean_;
01262 return idx;
01263 }
01264
01265
01266 Clause* removeClauseFromMLN(const int& remClauseIdx, const int& domainIdx)
01267 {
01268 Clause* remClause = (*mlns_)[domainIdx]->removeClause(remClauseIdx);
01269 if (remClause == NULL)
01270 {
01271 cout << "ERROR: failed to remove ";
01272 remClause->printWithoutWtWithStrVar(cout, (*domains_)[0]);
01273 cout << " from MLN" << endl;
01274 exit(-1);
01275 }
01276 return remClause;
01277 }
01278
01279
01281 private:
01282 void addPredicateToClause(Clause* const & beamClause, const int& op,
01283 const int& removeClauseIdx,
01284 ClauseOpHashArray& newClauses)
01285 {
01286
01287 clauseFactory_->addPredicateToClause(*preds_,beamClause,op,removeClauseIdx,
01288 true,newClauses,false);
01289
01290
01291 if (beamClause->getNumPredicates() == 1)
01292 {
01293 beamClause->getPredicate(0)->invertSense();
01294 clauseFactory_->addPredicateToClause(*preds_, beamClause, op,
01295 removeClauseIdx, true, newClauses,
01296 false);
01297 beamClause->getPredicate(0)->invertSense();
01298 }
01299 }
01300
01301
01302
01303 void removePredicateFromClause(Clause* const & beamClause, const int& op,
01304 const int& removeClauseIdx,
01305 ClauseOpHashArray& newClauses)
01306 {
01307 if (beamClause->getNumPredicates() > 2)
01308 clauseFactory_->removePredicateFromClause(beamClause, op, removeClauseIdx,
01309 newClauses);
01310 }
01311
01312
01313 void addNewClauseToCandidates(Clause* const & newClause,
01314 Array<Clause*>& candidates,
01315 Array<Clause*>* const & thrownOut)
01316 {
01317
01318 if (thrownOut && mln0_->containsClause(newClause))
01319 { thrownOut->append(newClause); return;}
01320 newClause->setWt(0);
01321 candidates.append(newClause);
01322 }
01323
01324
// Expands every clause in the beam into new clauses, according to the
// operation stored in the clause's AuxClauseData, and files each new clause
// through addNewClauseToCandidates.
//
//   beam           - clauses to expand.
//   candidates     - receives the new candidate clauses.
//   initMLNClauses - if non-NULL, copies of these clauses are also offered as
//                    candidates.
//
// All generated clauses go into the local hash array newClauses; before each
// expansion step newClausesBegIdx records its current size so that only the
// clauses appended by that step are forwarded. Clauses diverted into
// thrownOutClauses (those already contained in mln0_) are deleted at the end.
01325 void createCandidateClauses(const ClauseOpHashArray* const & beam,
01326 Array<Clause*>& candidates,
01327 const Array<Clause*>* const & initMLNClauses)
01328 {
01329
01330 Array<Clause*> thrownOutClauses;
01331 ClauseOpHashArray newClauses;
01332 for (int i = 0; i < beam->size(); i++)
01333 {
01334 Clause* beamClause = (*beam)[i];
01335 AuxClauseData* beamacd = beamClause->getAuxClauseData();
01336 int op = beamacd->op;
01337 int newClausesBegIdx = newClauses.size();
01338
// OP_ADD: try both adding and removing a predicate, keeping the same op and
// removed-clause index as the beam clause.
01339 if (op == OP_ADD)
01340 {
01341 int remIdx = beamacd->removedClauseIdx;
01342 addPredicateToClause(beamClause, op, remIdx, newClauses);
01343 removePredicateFromClause(beamClause, op, remIdx, newClauses);
01344 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01345 addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
01346 }
01347 else
// OP_REPLACE_ADDPRED: only predicate additions are tried.
01348 if (op == OP_REPLACE_ADDPRED)
01349 {
01350 addPredicateToClause(beamClause, op, beamacd->removedClauseIdx,
01351 newClauses);
01352 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01353 addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
01354 }
01355 else
// OP_REPLACE_REMPRED: only predicate removals are tried.
01356 if (op == OP_REPLACE_REMPRED)
01357 {
01358 removePredicateFromClause(beamClause, op, beamacd->removedClauseIdx,
01359 newClauses);
01360 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01361 addNewClauseToCandidates(newClauses[j], candidates,&thrownOutClauses);
01362 }
01363 else
// OP_NONE: the beam clause has no operation attached yet.
01364 if (op == OP_NONE)
01365 {
// idx >= 0 means the clause is in mln0_. A clause in the MLN with more than
// one predicate may additionally be replaced or removed; a clause in the MLN
// that is not modifiable is not expanded at all.
01366 int idx = mln0_->findClauseIdx(beamClause);
01367 bool beamClauseInMLN = (idx >= 0);
01368 bool removeBeamClause = (beamClauseInMLN &&
01369 beamClause->getNumPredicates() > 1);
01370 bool isModClause = (!beamClauseInMLN || isModifiableClause(idx));
01371
01372 if (isModClause)
01373 {
// New clauses obtained by adding a predicate (plain addition, nothing removed).
01374 addPredicateToClause(beamClause, OP_ADD, -1, newClauses);
01375 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01376 addNewClauseToCandidates(newClauses[j], candidates,
01377 &thrownOutClauses);
01378
// Same expansion, but as a replacement of the MLN clause at idx.
01379 if (removeBeamClause)
01380 {
01381 newClausesBegIdx = newClauses.size();
01382 addPredicateToClause(beamClause, OP_REPLACE_ADDPRED, idx,
01383 newClauses);
01384 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01385 addNewClauseToCandidates(newClauses[j], candidates,
01386 &thrownOutClauses);
01387 }
01388
// New clauses obtained by removing a predicate (plain addition).
01389 newClausesBegIdx = newClauses.size();
01390 removePredicateFromClause(beamClause, OP_ADD, -1, newClauses);
01391 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01392 addNewClauseToCandidates(newClauses[j], candidates,
01393 &thrownOutClauses);
01394
// Same, but as a replacement of the MLN clause at idx.
01395 if (removeBeamClause)
01396 {
01397 newClausesBegIdx = newClauses.size();
01398 removePredicateFromClause(beamClause, OP_REPLACE_REMPRED, idx,
01399 newClauses);
01400 for (int j = newClausesBegIdx; j < newClauses.size(); j++)
01401 addNewClauseToCandidates(newClauses[j], candidates,
01402 &thrownOutClauses);
01403 }
01404
// Candidate that removes the MLN clause outright. It is passed with a NULL
// thrown-out array, so it is always appended to candidates even though mln0_
// contains it.
01405 if (removeBeamClause)
01406 {
01407 Clause* c = new Clause(*beamClause);
01408 c->getAuxClauseData()->op = OP_REMOVE;
01409 c->getAuxClauseData()->removedClauseIdx = idx;
01410 addNewClauseToCandidates(c, candidates, NULL);
01411 }
01412 }
01413 }
01414 else
// Beam clauses with OP_REMOVE or OP_REPLACE are not expanded.
01415 assert(op == OP_REMOVE || op == OP_REPLACE);
01416 }
01417
01418
01419 if (initMLNClauses)
01420 {
01421
01422
// Offer copies of the initial MLN clauses as candidates. A copy for which
// newClauses.append() returns a negative value is deleted
// (NOTE(review): presumably the hash array rejects duplicates - confirm).
01423 int newClausesBegIdx = newClauses.size();
01424 for (int i = 0; i < initMLNClauses->size(); i++)
01425 {
01426 Clause* c = new Clause(*((*initMLNClauses)[i]));
01427 if (newClauses.append(c) < 0) delete c;
01428 }
01429 for (int i = newClausesBegIdx; i < newClauses.size(); i++)
01430 addNewClauseToCandidates(newClauses[i], candidates, &thrownOutClauses);
01431 }
01432
// Free the clauses that were rejected because mln0_ already contains them.
01433 for (int i = 0; i < thrownOutClauses.size(); i++)
01434 delete thrownOutClauses[i];
01435 }
01436
01437
01439
01440 private:
01441 void useTightParams()
01442 {
01443 sampleClauses_ = false;
01444 pll_->setSampleGndPreds(false);
01445 lbfgsb_->setMaxIter(lbMaxIter_);
01446 lbfgsb_->setFtol(lbConvThresh_);
01447 cacheClauses_ = false;
01448 }
01449
01450
01451 void useLooseParams()
01452 {
01453 sampleClauses_ = origSampleClauses_;
01454 pll_->setSampleGndPreds(sampleGndPreds_);
01455 if (looseMaxIter_ >= 0) lbfgsb_->setMaxIter(looseMaxIter_);
01456 if (looseConvThresh_ >= 0) lbfgsb_->setFtol(looseConvThresh_);
01457 cacheClauses_ = origCacheClauses_;
01458 }
01459
01460
01461
01462 bool isModifiableClause(const int& i) const
01463 { return (!mln0_->isExistClause(i) && !mln0_->isExistUniqueClause(i)); }
01464
01465
01466 bool isNonModifiableFormula(const FormulaAndClauses* const & fnc) const
01467 { return (fnc->hasExist || fnc->isExistUnique); }
01468
01469
01470 void printIterAndTimeElapsed(const double& begSec)
01471 {
01472 cout << "Iteration " << iter_ << " took ";
01473 timer_.printTime(cout, timer_.time()-begSec); cout << endl << endl;
01474 cout << "Time elapsed = ";
01475 timer_.printTime(cout, timer_.time()-startSec_); cout << endl << endl;
01476 }
01477
01478
01479 void removeClausesFromMLNs()
01480 {
01481 for (int i = 0; i < mlns_->size(); i++)
01482 (*mlns_)[i]->removeAllClauses(NULL);
01483 }
01484
01485
01486 void reEvaluateCandidates(Array<Clause*>& candidates,
01487 const double& prevScore)
01488 {
01489 int numClausesFormulas = getNumClausesFormulas();
01490 Array<double> wts;
01491 wts.growToSize(numClausesFormulas + 2);
01492
01493 Array<double> priorMeans, priorStdDevs;
01494 setPriorMeansStdDevs(priorMeans, priorStdDevs, 1, NULL);
01495 if (indexTrans_) indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, 1);
01496
01497 bool error;
01498 for (int i = 0; i < candidates.size(); i++)
01499 {
01500 candidates[i]->getAuxClauseData()->gain = 0;
01501 countAndMaxScoreEffectCandidate(candidates[i], &wts, NULL, prevScore,
01502 false, priorMeans, priorStdDevs, error,
01503 NULL, NULL);
01504 }
01505
01506 if (indexTrans_) indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, 1);
01507
01508 Array<Clause*> tmpCand(candidates);
01509 rquicksort(tmpCand);
01510 candidates.clear();
01511 for (int i = 0; i < tmpCand.size(); i++)
01512 {
01513 if (tmpCand[i]->getAuxClauseData()->gain > minGain_ &&
01514 fabs(tmpCand[i]->getWt()) >= minWt_)
01515 candidates.append(tmpCand[i]);
01516 else
01517 delete tmpCand[i];
01518 }
01519
01520 cout << "reevaluated top " << candidates.size()
01521 << " candidates with tight params:" << endl;
01522 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl;
01523 for (int i = 0; i < candidates.size(); i++)
01524 {
01525 cout << i << "\t";
01526 candidates[i]->printWithoutWtWithStrVar(cout,(*domains_)[0]);
01527 cout << endl
01528 << "\tgain = " << candidates[i]->getAuxClauseData()->gain
01529 << ", wt = " << candidates[i]->getWt()
01530 << ", op = "
01531 << Clause::getOpAsString(candidates[i]->getAuxClauseData()->op)
01532 << endl;
01533 }
01534 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl << endl;;
01535 }
01536
01537
01538
01539
01540
// Refills the (empty) beam from the candidates and merges the candidates
// into the running list of best clauses.
//
//   candidates  - clauses to consider; sorted in place. Candidates that do
//                 not end up in bestClauses are deleted here.
//   beam        - must be empty on entry; receives copies of up to beamSize_
//                 candidates whose gain exceeds minGain_ and whose absolute
//                 weight is at least minWt_.
//   bestClauses - on entry the previous best clauses; on exit at most
//                 numEstBestClauses_ clauses passing the same thresholds,
//                 ordered by rquicksort.
//
// Returns false if no candidate qualified for the beam (all candidates are
// then deleted and bestClauses is left untouched). Otherwise returns true
// when there was no previous best clause, or when the new first best clause
// differs from the previous one (by clause or op) and has a strictly larger
// gain; returns false in all other cases.
01541 bool findBestClausesAmongCandidates(Array<Clause*>& candidates,
01542 ClauseOpHashArray* const & beam,
01543 Array<Clause*>& bestClauses)
01544 {
01545 assert(beam->size() == 0);
01546
01547
01548
01549
01550 rquicksort(candidates);
01551
// Copy the leading qualifying candidates into the beam, up to beamSize_.
01552 for (int i = 0; i < candidates.size(); i++)
01553 {
01554 if (beam->size() >= beamSize_) break;
01555 double candGain = candidates[i]->getAuxClauseData()->gain;
01556 double candAbsWt = fabs(candidates[i]->getWt());
01557 if (candGain > minGain_ && candAbsWt >= minWt_)
01558 {
// The reassignment of a after the assert looks like a way to keep the
// variable "used" when asserts are compiled out - TODO confirm.
01559 int a = beam->append(new Clause(*(candidates[i])));
01560 assert(a >= 0); a = 0;
01561 }
01562 }
01563
// Nothing qualified: discard every candidate and report no change.
01564 if (beam->size() == 0)
01565 {
01566 for (int i = 0; i < candidates.size(); i++) delete candidates[i];
01567 return false;
01568 }
01569
01570
// Remember a copy of the current best clause for the comparison at the end;
// the original may be deleted below.
01571 Clause* prevBestClause = (bestClauses.size() > 0) ?
01572 new Clause(*(bestClauses[0])) : NULL;
01573
01574
01575
// Merge the previous best clauses into candidates. A previous best clause
// that is already represented in the candidate set is deleted instead.
01576 ClauseOpSet cset;
01577 ClauseOpSet::iterator it;
01578 for (int i = 0; i < candidates.size(); i++) cset.insert(candidates[i]);
01579
01580 for (int i = 0; i < bestClauses.size(); i++)
01581 {
01582 if ((it=cset.find(bestClauses[i])) == cset.end())
01583 {
01584 candidates.append(bestClauses[i]);
01585 cset.insert(bestClauses[i]);
01586 }
01587 else
01588 {
01589 assert((*it)->getAuxClauseData()->gain ==
01590 bestClauses[i]->getAuxClauseData()->gain);
01591 delete bestClauses[i];
01592 }
01593 }
01594
01595 rquicksort(candidates);
01596
// Rebuild bestClauses from the merged, sorted list. Every clause that is not
// kept (threshold failure, or the list is already full) is deleted.
01597 bestClauses.clear();
01598 for (int i = 0; i < candidates.size(); i++)
01599 {
01600 if (bestClauses.size() < numEstBestClauses_)
01601 {
01602 double candGain = candidates[i]->getAuxClauseData()->gain;
01603 double candAbsWt = fabs(candidates[i]->getWt());
01604 if (candGain > minGain_ && candAbsWt >= minWt_)
01605 bestClauses.append(candidates[i]);
01606 else
01607 {
01608
01609
01610
01611
01612
01613
01614 delete candidates[i];
01615 }
01616 }
01617 else
01618 delete candidates[i];
01619 }
01620
01621
// Decide whether the best clause changed relative to prevBestClause.
01622 bool bestClauseChanged;
01623 if (bestClauses.size() > 0)
01624 {
01625 if (prevBestClause == NULL) bestClauseChanged = true;
01626 else
01627 {
01628
// Same clause and same op: unchanged.
01629 if (bestClauses[0]->same(prevBestClause) &&
01630 bestClauses[0]->getAuxClauseData()->op ==
01631 prevBestClause->getAuxClauseData()->op)
01632 bestClauseChanged = false;
01633 else
01634 {
01635
// A different clause counts as a change only if its gain is strictly larger.
01636 if (bestClauses[0]->getAuxClauseData()->gain >
01637 prevBestClause->getAuxClauseData()->gain)
01638 bestClauseChanged = true;
01639 else
01640 bestClauseChanged = false;
01641 }
01642 }
01643 }
01644 else
01645 bestClauseChanged = false;
01646
01647 if (prevBestClause) delete prevBestClause;
01648 return bestClauseChanged;
01649 }
01650
01651
01652 bool effectBestCandidateOnMLNs(Clause* const & cand, double& score)
01653 {
01654 if (appendToAndRemoveFromMLNs(cand, score))
01655 {
01656 printMLNClausesWithWeightsAndScore(score, iter_);
01657
01658 return true;
01659 }
01660 return false;
01661 }
01662
01663
01664 bool effectBestCandidateOnMLNs(Array<Clause*>& bestCandidates, double& score)
01665 {
01666 cout << "effecting best candidate on MLN..." << endl << endl;
01667 bool ok = false;
01668 int i;
01669 for (i = 0; i < bestCandidates.size(); i++)
01670 {
01671 cout << "effecting best candidate " << i << " on MLN..." << endl << endl;
01672 if (ok=effectBestCandidateOnMLNs(bestCandidates[i], score)) break;
01673 cout << "failed to effect candidate on MLN." << endl;
01674 delete bestCandidates[i];
01675 }
01676 for (int j = i+1; j < bestCandidates.size();j++) delete bestCandidates[j];
01677 return ok;
01678 }
01679
01680
01681 double getPenalty(const Clause* const & cand)
01682 {
01683 int op = cand->getAuxClauseData()->op;
01684
01685 if (op == OP_ADD) return cand->getNumPredicates() * penalty_;
01686
01687 if (op == OP_REPLACE_ADDPRED)
01688 {
01689 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01690 int origlen = mln0_->getClause(remIdx)->getNumPredicates();
01691 int diff = cand->getNumPredicates() - origlen;
01692 assert(diff > 0);
01693 return diff * penalty_;
01694 }
01695
01696 if (op == OP_REPLACE_REMPRED)
01697 {
01698 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01699 int origlen = mln0_->getClause(remIdx)->getNumPredicates();
01700 int diff = origlen - cand->getNumPredicates();
01701 assert(diff > 0);
01702 return diff * penalty_;
01703 }
01704
01705 if (op == OP_REPLACE)
01706 {
01707
01708
01709 int remIdx = cand->getAuxClauseData()->removedClauseIdx;
01710 const Clause* mlnClause = mln0_->getClause(remIdx);
01711 assert(cand->getNumPredicates() == mlnClause->getNumPredicates());
01712 int diff = 0;
01713 for (int i = 0; i < cand->getNumPredicates(); i++)
01714 {
01715 Predicate* cpred = (Predicate*) cand->getPredicate(i);
01716 Predicate* mpred = (Predicate*) mlnClause->getPredicate(i);
01717 assert(cpred->same(mpred));
01718 if (cpred->getSense() != mpred->getSense()) diff++;
01719 }
01720 return diff * penalty_;
01721 }
01722
01723 if (op == OP_REMOVE) return cand->getNumPredicates() * penalty_;
01724
01725 assert(false);
01726 return 88888;
01727 }
01728
01729
01730 void printNewScore(const Array<Clause*>& carray, const Domain* const & domain,
01731 const int& lbfgsbIter, const double& lbfgsbSec,
01732 const double& newScore, const double& gain,
01733 const Array<double>& penalties)
01734 {
01735 cout << "*************************** " << candCnt_++ << ", iter " << iter_
01736 << ", beam search iter " << bsiter_ << endl;
01737 for (int i = 0; i < carray.size(); i++)
01738 {
01739 if (carray[i])
01740 {
01741 cout << "candidate : ";
01742 carray[i]->printWithoutWtWithStrVar(cout,domain); cout << endl;
01743 cout << "op : ";
01744 cout << Clause::getOpAsString(carray[i]->getAuxClauseData()->op) <<endl;
01745 cout << "removed clause: ";
01746 int remIdx = carray[i]->getAuxClauseData()->removedClauseIdx;
01747 if (remIdx < 0) cout << "NULL";
01748 else { mln0_->getClause(remIdx)->printWithoutWtWithStrVar(cout,domain);}
01749 cout << endl;
01750 if (carray[i]->getAuxClauseData()->prevClauseStr.length() > 0)
01751 {
01752 cout << "prevClause : ";
01753 cout << carray[i]->getAuxClauseData()->prevClauseStr << endl;
01754 }
01755 if (carray[i]->getAuxClauseData()->addedPredStr.length() > 0)
01756 {
01757 cout << "addedPred : ";
01758 cout << carray[i]->getAuxClauseData()->addedPredStr << endl;
01759 }
01760 if (carray[i]->getAuxClauseData()->removedPredIdx >= 0)
01761 {
01762 cout << "removedPredIdx: ";
01763 cout << carray[i]->getAuxClauseData()->removedPredIdx << endl;
01764 }
01765
01766 cout << "score : " << newScore << endl;
01767 cout << "gain : " << gain << endl;
01768 cout << "penalty : " << penalties[i] << endl;
01769 cout << "net gain : " << gain - penalties[i] << endl;
01770 if (carray[i]->getAuxClauseData()->op != OP_REMOVE)
01771 cout << "wt : " << carray[i]->getWt() << endl;
01772 }
01773 }
01774 cout << "num LBFGSB iter = " << lbfgsbIter << endl;
01775 cout << "time taken by LBFGSB = "; Timer::printTime(cout, lbfgsbSec);
01776 cout << endl;
01777 cout << "*************************** " << endl << endl;;
01778 }
01779
01780
01781 void printNewScore(const Clause* const & c, const Domain* const & domain,
01782 const int& lbfgsbIter, const double& lbfgsbSec,
01783 const double& newScore, const double& gain,
01784 const double& penalty)
01785 {
01786 Array<Clause*> carray;
01787 Array<double> parray;
01788 if (c) { carray.append((Clause*)c); parray.append(penalty); }
01789 printNewScore(carray, domain, lbfgsbIter, lbfgsbSec, newScore, gain,parray);
01790 }
01791
01792
01793 void printMLNClausesWithWeightsAndScore(const double& score, const int& iter)
01794 {
01795 if (iter >= 0) cout << "MLN in iteration " << iter << ":" << endl;
01796 else cout << "MLN:" << endl;
01797 cout << "------------------------------------" << endl;
01798 if (indexTrans_) indexTrans_->printClauseFormulaWts(cout, true);
01799 else mln0_->printMLNClausesFormulas(cout, (*domains_)[0], true);
01800 cout << "------------------------------------" << endl;
01801 cout << "score = "<< score << endl << endl;
01802 }
01803
01804
01805 void printMLNToFile(const char* const & appendStr, const int& iter)
01806 {
01807 string fname = outMLNFileName_;
01808
01809 if (appendStr) fname.append(".").append(appendStr);
01810
01811 if (iter >= -1)
01812 {
01813 char buf[100];
01814 sprintf(buf, "%d", iter);
01815 fname.append(".iter").append(buf);
01816 }
01817
01818 if (appendStr || iter >= -1) fname.append(".mln");
01819
01820 ofstream out(fname.c_str());
01821 if (!out.good()) { cout << "ERROR: failed to open " <<fname<<endl;exit(-1);}
01822
01823 out << "//predicate declarations" << endl;
01824 (*domains_)[0]->printPredicateTemplates(out);
01825 out << endl;
01826
01827
01828 out << "//function declarations" << endl;
01829 (*domains_)[0]->printFunctionTemplates(out);
01830 out << endl;
01831
01832 if (indexTrans_) indexTrans_->printClauseFormulaWts(out, false);
01833 else mln0_->printMLNClausesFormulas(out, (*domains_)[0], false);
01834
01835 out << endl;
01836 out.close();
01837 }
01838
01839
01840 void printClausesInBeam(const ClauseOpHashArray* const & beam)
01841 {
01842 cout.setf(ios_base::left, ios_base::adjustfield);
01843 cout << "^^^^^^^^^^^^^^^^^^^ beam ^^^^^^^^^^^^^^^^^^^" << endl;
01844 for (int i = 0; i < beam->size(); i++)
01845 {
01846 cout << i << ": ";
01847 cout.width(10); cout << (*beam)[i]->getWt(); cout.width(0); cout << " ";
01848 (*beam)[i]->printWithoutWtWithStrVar(cout, (*domains_)[0]); cout << endl;
01849 }
01850 cout.width(0);
01851 cout << "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" << endl;
01852
01853 }
01854
01855 void setPriorMeansStdDevs(Array<double>& priorMeans,
01856 Array<double>& priorStdDevs, const int& addSlots,
01857 const Array<int>* const& removedSlotIndexes)
01858 {
01859 priorMeans.clear(); priorStdDevs.clear();
01860
01861 if (indexTrans_)
01862 {
01863 indexTrans_->setPriorMeans(priorMeans);
01864 priorStdDevs.growToSize(priorMeans.size());
01865 for (int i = 0; i < priorMeans.size(); i++)
01866 priorStdDevs[i] = priorStdDev_;
01867 }
01868 else
01869 {
01870 for (int i = 0; i < mln0_->getNumClauses(); i++)
01871 {
01872 priorMeans.append(mln0_->getMLNClauseInfo(i)->priorMean);
01873 priorStdDevs.append(priorStdDev_);
01874 }
01875 }
01876
01877 if (removedSlotIndexes)
01878 {
01879 for (int i = 0; i < removedSlotIndexes->size(); i++)
01880 priorMeans[ (*removedSlotIndexes)[i] ] = 0;
01881 }
01882
01883 for (int i = 0; i < addSlots; i++)
01884 {
01885 priorMeans.append(priorMean_);
01886 priorStdDevs.append(priorStdDev_);
01887 }
01888 }
01889
01890
01891 void pllSetPriorMeansStdDevs(Array<double>& priorMeans,
01892 Array<double>& priorStdDevs, const int& addSlots,
01893 const Array<int>* const & removedSlotIndexes)
01894 {
01895 if (hasPrior_)
01896 {
01897 setPriorMeansStdDevs(priorMeans, priorStdDevs,
01898 addSlots, removedSlotIndexes);
01899 pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
01900 priorStdDevs.getItems());
01901 }
01902 else
01903 pll_->setMeansStdDevs(-1, NULL, NULL);
01904 }
01905
01906
01907 void setRemoveAppendedPriorMeans(const int& numClausesFormulas,
01908 Array<double>& priorMeans,
01909 const Array<int>& removedSlotIndexes,
01910 const int& addSlots,
01911 Array<double>& removedValues)
01912 {
01913 for (int i = 0; i < removedSlotIndexes.size(); i++)
01914 {
01915 removedValues.append(priorMeans[ removedSlotIndexes[i] ]);
01916 priorMeans[ removedSlotIndexes[i] ] = 0;
01917 }
01918
01919 for (int i = 0; i < addSlots; i++)
01920 priorMeans[numClausesFormulas+i] = priorMean_;
01921 }
01922
01923
01924 int getNumClausesFormulas()
01925 {
01926 if (indexTrans_) return indexTrans_->getNumClausesAndExistFormulas();
01927 return mln0_->getNumClauses();
01928 }
01929
01930
01931 void updateWts(const Array<double>& wts,
01932 const Array<Clause*>* const & appendedClauses,
01933 const Array<string>* const & appendedFormulas)
01934 {
01935 if (indexTrans_)
01936 {
01937 Array<double> tmpWts;
01938 tmpWts.growToSize(wts.size()-1);
01939 for (int i = 1; i < wts.size(); i++) tmpWts[i-1] = wts[i];
01940 indexTrans_->updateClauseFormulaWtsInMLNs(tmpWts, appendedClauses,
01941 appendedFormulas);
01942 }
01943 else
01944 {
01945 for (int i = 0; i < mlns_->size(); i++)
01946 for (int j = 0; j < (*mlns_)[i]->getNumClauses(); j++)
01947 ((Clause*)(*mlns_)[i]->getClause(j))->setWt(wts[j+1]);
01948 }
01949 }
01950
01951
01953
01954 private:
01955
01956 void rquicksort(Array<Clause*>& clauses, const int& l, const int& r)
01957 {
01958 Clause** items = (Clause**) clauses.getItems();
01959 if (l >= r) return;
01960 Clause* tmp = items[l];
01961 items[l] = items[(l+r)/2];
01962 items[(l+r)/2] = tmp;
01963
01964 int last = l;
01965 for (int i = l+1; i <= r; i++)
01966 if (items[i]->getAuxClauseData()->gain >
01967 items[l]->getAuxClauseData()->gain)
01968 {
01969 ++last;
01970 Clause* tmp = items[last];
01971 items[last] = items[i];
01972 items[i] = tmp;
01973 }
01974
01975 tmp = items[l];
01976 items[l] = items[last];
01977 items[last] = tmp;
01978 rquicksort(clauses, l, last-1);
01979 rquicksort(clauses, last+1, r);
01980 }
01981
01982 void rquicksort(Array<Clause*>& ca) { rquicksort(ca, 0, ca.size()-1); }
01983
01984
01985
01987 private:
01988
01989 void deleteExistFormulas(Array<ExistFormula*>& existFormulas)
01990 { existFormulas.deleteItemsAndClear(); }
01991
01992
01993 void getMLNClauses(Array<Clause*>& initialMLNClauses,
01994 Array<ExistFormula*>& existFormulas)
01995 {
01996 for (int i = 0; i < mln0_->getNumClauses(); i++)
01997 if (isModifiableClause(i))
01998 {
01999 Clause* c = (Clause*) mln0_->getClause(i);
02000 initialMLNClauses.append(new Clause(*c));
02001 }
02002
02003 const FormulaAndClausesArray* fnca = mln0_->getFormulaAndClausesArray();
02004 for (int i = 0; i < fnca->size(); i++)
02005 if (isNonModifiableFormula((*fnca)[i]))
02006 existFormulas.append(new ExistFormula((*fnca)[i]->formula));
02007
02008 for (int i = 0; i < existFormulas.size(); i++)
02009 {
02010 string formula = existFormulas[i]->formula;
02011 FormulaAndClauses tmp(formula, 0, false);
02012 const FormulaAndClausesArray* fnca
02013 = (*mlns_)[0]->getFormulaAndClausesArray();
02014 int a = fnca->find(&tmp);
02015 existFormulas[i]->numPreds = (*fnca)[a]->numPreds;
02016
02017 Array<Array<Clause*> >& cnfClausesForDomains
02018 = existFormulas[i]->cnfClausesForDomains;
02019 cnfClausesForDomains.growToSize(mlns_->size());
02020 for (int j = 0; j < mlns_->size(); j++)
02021 {
02022 Array<Clause*>& cnfClauses = cnfClausesForDomains[j];
02023 fnca = (*mlns_)[j]->getFormulaAndClausesArray();
02024 a = fnca->find(&tmp);
02025 assert(a >= 0);
02026 IndexClauseHashArray* indexClauses = (*fnca)[a]->indexClauses;
02027 for (int k = 0; k < indexClauses->size(); k++)
02028 {
02029 Clause* c = new Clause(*((*indexClauses)[k]->clause));
02030 c->newAuxClauseData();
02031 cnfClauses.append(c);
02032 }
02033 cnfClauses.compress();
02034 }
02035 }
02036 }
02037
02038
02039
// Obtains the counts of one CNF clause of an existential formula for a single
// domain, using the clause cache when caching is enabled.
//
//   cnfClause      - the clause; must not have a cache attached on entry, and
//                    has none attached on exit.
//   domainIdx      - index of the domain to count in.
//   clauseIdxInMln - index pointer forwarded to pll_.
//   undoInfos      - forwarded to pll_.
//
// With caching on, three situations are distinguished:
//   * the clause is cached and has counts for this domain: the cached counts
//     are inserted and the function returns without recomputing;
//   * the cache has room: counts are computed into the cached clause's
//     existing cache, or into a newly created cache for a clause not cached yet;
//   * the cache is full: a notice is printed once and counts are computed
//     without caching.
// cacheSizeMB_ is updated with the size added.
02040 inline void pllCountsForExistFormula(Clause* cnfClause,
02041 const int& domainIdx,
02042 int* clauseIdxInMln,
02043 Array<UndoInfo*>* const & undoInfos)
02044 {
02045 assert(cnfClause->getAuxClauseData()->cache == NULL);
02046 bool inCache = false;
02047 bool hasDomainCounts = false;
02048 double prevCNFClauseSize = 0;
02049
02050 if (cacheClauses_)
02051 {
02052 int i;
// Look the clause up; if found, check whether any entry of this domain's
// cache is non-NULL.
02053 if ((i = cachedClauses_->find(cnfClause)) >= 0)
02054 {
02055 inCache = true;
02056 Array<Array<Array<CacheCount*>*>*>* cache =
02057 (*cachedClauses_)[i]->getAuxClauseData()->cache;
02058 Array<Array<CacheCount*>*>* domainCache = (*cache)[domainIdx];
02059 for (int j = 0; j < domainCache->size(); j++)
02060 if ((*domainCache)[j] != NULL) { hasDomainCounts = true; break; }
02061 }
02062
// Cache hit for this domain: reuse the stored counts and finish.
02063 if (hasDomainCounts)
02064 {
02065 pll_->insertCounts(clauseIdxInMln, undoInfos,
02066 (*cachedClauses_)[i]->getAuxClauseData()->cache,
02067 domainIdx);
02068 return;
02069 }
02070 else
02071 {
02072 if (cacheSizeMB_ < maxCacheSizeMB_)
02073 {
// Clause is cached but lacks this domain: temporarily borrow the cached
// clause's cache pointer and remember the size before counting.
02074 if (inCache)
02075 {
02076 cnfClause->getAuxClauseData()->cache
02077 = (*cachedClauses_)[i]->getAuxClauseData()->cache;
02078 prevCNFClauseSize = cnfClause->sizeMB();
02079 }
02080 else
02081 cnfClause->newCache(domains_->size(),
02082 (*domains_)[0]->getNumPredicates());
02083 }
02084 else
02085 {
// The static flag makes the "cache is full" notice appear only once.
02086 static bool printCacheFull = true;
02087 if (printCacheFull)
02088 {
02089 cout << "Cache is full, approximate size = " << cacheSizeMB_
02090 <<" MB" << endl;
02091 printCacheFull = false;
02092 }
02093 }
02094 }
02095 }
02096
02097
// Compute the counts; the cache argument is NULL when nothing was attached
// above.
02098 pll_->computeCountsForNewAppendedClause(cnfClause,clauseIdxInMln,domainIdx,
02099 undoInfos, sampleClauses_,
02100 cnfClause->getAuxClauseData()->cache);
02101
02102 if (cnfClause->getAuxClauseData()->cache)
02103 {
// Borrowed cache: account for the growth and detach the pointer; the cached
// clause keeps the cache.
02104 if (inCache)
02105 {
02106 cacheSizeMB_ += cnfClause->sizeMB() - prevCNFClauseSize;
02107 cnfClause->getAuxClauseData()->cache = NULL;
02108 }
02109 else
02110 {
// New cache: if there is still room, move it onto a compressed copy of the
// clause stored in cachedClauses_; otherwise delete it.
02111 if (cacheSizeMB_ < maxCacheSizeMB_)
02112 {
02113 cacheSizeMB_ += cnfClause->sizeMB();
02114 Array<Array<Array<CacheCount*>*>*>* cache
02115 = cnfClause->getAuxClauseData()->cache;
02116 cnfClause->getAuxClauseData()->cache = NULL;
02117 Clause* copyClause = new Clause(*cnfClause);
02118 copyClause->getAuxClauseData()->cache = cache;
02119 copyClause->compress();
02120 cachedClauses_->append(copyClause);
02121 }
02122 else
02123 {
02124 cnfClause->getAuxClauseData()->deleteCache();
02125 cnfClause->getAuxClauseData()->cache = NULL;
02126 }
02127 }
02128 }
02129 }
02130
02131
02132 inline void printNewScore(const string existFormStr,const int& lbfgsbIter,
02133 const double& lbfgsbSec, const double& newScore,
02134 const double& gain, const double& penalty,
02135 const double& wt)
02136 {
02137 cout << "*************************** " << candCnt_++ << ", iter " << iter_
02138 << ", beam search iter " << bsiter_ << endl;
02139
02140 cout << "exist formula : " << existFormStr << endl;
02141 cout << "op : OP_ADD" << endl;
02142 cout << "score : " << newScore << endl;
02143 cout << "gain : " << gain << endl;
02144 cout << "penalty : " << penalty << endl;
02145 cout << "net gain : " << gain - penalty << endl;
02146 cout << "wt : " << wt << endl;
02147 cout << "num LBFGSB iter = " << lbfgsbIter << endl;
02148 cout << "time taken by LBFGSB = "; Timer::printTime(cout, lbfgsbSec);
02149 cout << endl;
02150 cout << "*************************** " << endl << endl;;
02151 }
02152
02153
02154 inline void rquicksort(Array<ExistFormula*>& existFormulas, const int& l,
02155 const int& r)
02156 {
02157 ExistFormula** items = (ExistFormula**) existFormulas.getItems();
02158 if (l >= r) return;
02159 ExistFormula* tmp = items[l];
02160 items[l] = items[(l+r)/2];
02161 items[(l+r)/2] = tmp;
02162
02163 int last = l;
02164 for (int i = l+1; i <= r; i++)
02165 if (items[i]->gain > items[l]->gain)
02166 {
02167 ++last;
02168 ExistFormula* tmp = items[last];
02169 items[last] = items[i];
02170 items[i] = tmp;
02171 }
02172
02173 tmp = items[l];
02174 items[l] = items[last];
02175 items[last] = tmp;
02176 rquicksort(existFormulas, l, last-1);
02177 rquicksort(existFormulas, last+1, r);
02178 }
02179
02180
02181 inline void rquicksort(Array<ExistFormula*>& ef)
02182 { rquicksort(ef, 0, ef.size()-1); }
02183
02184
02185
02186
02187
02188 inline void evaluateExistFormula(ExistFormula* const& ef,
02189 const bool& computeCountsOnly,
02190 Array<Array<Array<IndexAndCount*> > >* const & iacArraysPerDomain,
02191 const double& prevScore)
02192 {
02193 bool undo = !computeCountsOnly;
02194 bool evalGainLearnWts = !computeCountsOnly;
02195
02196 Array<int*> idxPtrs;
02197 Array<UndoInfo*>* undoInfos = (undo) ? new Array<UndoInfo*> : NULL;
02198
02199
02200
02201 Array<Array<Clause*> >& cnfClausesForDomains = ef->cnfClausesForDomains;
02202 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02203 {
02204 Array<Clause*>& cnfClauses = cnfClausesForDomains[d];
02205 MLN* mln = (*mlns_)[d];
02206 for (int k = 0; k < cnfClauses.size(); k++)
02207 {
02208 Clause* c = cnfClauses[k];
02209
02210
02211
02212 if (!undo && mln->containsClause(c)) continue;
02213 int* tmpClauseIdxInMLN = new int(mln->getNumClauses() + k);
02214 idxPtrs.append(tmpClauseIdxInMLN);
02215 if (iacArraysPerDomain)
02216 {
02217 Array<Array<IndexAndCount*> >& iacArrays = (*iacArraysPerDomain)[d];
02218 assert(iacArrays.size() == cnfClauses.size());
02219 Array<IndexAndCount*>& iacArray = iacArrays[k];
02220
02221 Array<UndoInfo*> tmpUndoInfos;
02222 pllCountsForExistFormula(c, d, tmpClauseIdxInMLN, &tmpUndoInfos);
02223 for (int i = 0; i < tmpUndoInfos.size(); i++)
02224 iacArray.append(tmpUndoInfos[i]->affectedArr->lastItem());
02225 if (undoInfos) undoInfos->append(tmpUndoInfos);
02226 else tmpUndoInfos.deleteItemsAndClear();
02227 }
02228 else
02229 pllCountsForExistFormula(c, d, tmpClauseIdxInMLN, undoInfos);
02230 }
02231 }
02232
02233
02234
02235 if (evalGainLearnWts)
02236 {
02237 Array<double> priorMeans, priorStdDevs;
02238
02239
02240
02241 int numAdded = (indexTrans_) ? 1 : cnfClausesForDomains[0].size();
02242 setPriorMeansStdDevs(priorMeans, priorStdDevs, numAdded, NULL);
02243
02244 if (hasPrior_)
02245 pll_->setMeansStdDevs(priorMeans.size(), priorMeans.getItems(),
02246 priorStdDevs.getItems());
02247 else
02248 pll_->setMeansStdDevs(-1, NULL, NULL);
02249
02250 if (indexTrans_)
02251 {
02252 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02253 {
02254 int numCNFClauses = cnfClausesForDomains[d].size();
02255 indexTrans_->appendClauseIdxToClauseFormulaIdxs(1, numCNFClauses);
02256 }
02257 }
02258
02259 int iter; bool error; double elapsedSec;
02260 int numClausesFormulas = getNumClausesFormulas();
02261 Array<double>* wts = &(ef->wts);
02262 wts->growToSize(numClausesFormulas + numAdded + 1);
02263
02264 double newScore = maximizeScore(numClausesFormulas, numAdded, wts,
02265 NULL, NULL, iter, error, elapsedSec);
02266
02267 if (indexTrans_)
02268 {
02269 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02270 {
02271 int numCNFClauses = cnfClausesForDomains[d].size();
02272 indexTrans_->removeClauseIdxToClauseFormulaIdxs(1, numCNFClauses);
02273 }
02274 }
02275
02276 double totalWt = 0;
02277 for (int i = 0; i < numAdded; i++)
02278 totalWt += (*wts)[numClausesFormulas+i+1];
02279
02280 double penalty = ef->numPreds * penalty_;
02281
02282 ef->gain = newScore-prevScore-penalty;
02283 ef->wt = totalWt;
02284 ef->newScore = newScore;
02285
02286 if (error) {newScore=prevScore;cout<<"LBFGSB failed to find wts"<<endl;}
02287 printNewScore(ef->formula, iter, elapsedSec, newScore,
02288 newScore-prevScore, penalty, totalWt);
02289 }
02290
02291 if (undo) { pll_->undoAppendRemoveCounts(undoInfos); delete undoInfos; }
02292 idxPtrs.deleteItemsAndClear();
02293 }
02294
02295
02296
02297 double evaluateExistFormulas(Array<ExistFormula*>& existFormulas,
02298 Array<ExistFormula*>& highGainWtFormulas,
02299 const double& prevScore)
02300 {
02301 if (existFormulas.size() == 0) return 0;
02302 for (int i = 0; i < existFormulas.size(); i++)
02303 evaluateExistFormula(existFormulas[i], false, NULL, prevScore);
02304
02305 Array<ExistFormula*> tmp(existFormulas);
02306 rquicksort(tmp);
02307 double minGain = DBL_MAX;
02308 cout << "evaluated existential formulas " << endl;
02309 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl;
02310 for (int i = 0; i < tmp.size(); i++)
02311 {
02312 existFormulas[i] = tmp[i];
02313 double gain = existFormulas[i]->gain;
02314 double wt = existFormulas[i]->wt;
02315 cout << i << "\t" << existFormulas[i]->formula << endl
02316 << "\tgain = " << gain << ", wt = " << wt << ", op = OP_ADD"
02317 << endl;
02318 if (gain > minGain_ && wt >= minWt_)
02319 {
02320 highGainWtFormulas.append(tmp[i]);
02321 if (gain < minGain) minGain = gain;
02322 }
02323 }
02324 cout << "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" << endl << endl;
02325
02326 if (minGain == DBL_MAX) minGain = 0;
02327 return minGain;
02328 }
02329
02330
02331 inline void appendExistFormulaToMLNs(ExistFormula* const & ef)
02332 {
02333 Array<Array<Array<IndexAndCount*> > > iacsPerDomain;
02334 Array<Array<Clause*> >& cnfClausesForDomains = ef->cnfClausesForDomains;
02335 iacsPerDomain.growToSize(cnfClausesForDomains.size());
02336 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02337 {
02338 Array<Array<IndexAndCount*> >& iacArrays = iacsPerDomain[d];
02339 iacArrays.growToSize( cnfClausesForDomains[d].size() );
02340 }
02341 evaluateExistFormula(ef, true, &iacsPerDomain, 0);
02342
02343 Array<double>& wts = ef->wts;
02344 int numClausesFormulas = getNumClausesFormulas();
02345
02346
02347
02348 if (indexTrans_ == NULL) updateWts(wts, NULL, NULL);
02349
02350
02351 for (int d = 0; d < cnfClausesForDomains.size(); d++)
02352 {
02353 MLN* mln = (*mlns_)[d];
02354 Array<Array<IndexAndCount*> >& iacArrays = iacsPerDomain[d];
02355 Array<Clause*>& cnfClauses = cnfClausesForDomains[d];
02356
02357 for (int i = 0; i < cnfClauses.size(); i++)
02358 {
02359 int idx;
02360
02361
02362 double wt = (indexTrans_) ? 0 : wts[numClausesFormulas+i+1];
02363
02364
02365 mln->appendClause(ef->formula, true, new Clause(*cnfClauses[i]),
02366 wt, false, idx);
02367 mln->setFormulaPriorMean(ef->formula, priorMean_);
02368 ((MLNClauseInfo*)mln->getMLNClauseInfo(idx))->priorMean
02369 += priorMean_/cnfClauses.size();
02370
02371 int* idxPtr = mln->getMLNClauseInfoIndexPtr(idx);
02372 Array<IndexAndCount*>& iacs = iacArrays[i];
02373 for (int j = 0; j < iacs.size(); j++) iacs[j]->index = idxPtr;
02374 }
02375 }
02376 assert(pll_->checkNoRepeatedIndex());
02377
02378
02379 if (indexTrans_)
02380 {
02381
02382 Array<string> appendedFormula;
02383 appendedFormula.append(ef->formula);
02384 updateWts(wts, NULL, &appendedFormula);
02385
02386 indexTrans_->createClauseIdxToClauseFormulaIdxsMap();
02387 }
02388
02389 cout << "Modified MLN: Appended formula to MLN: " << ef->formula << endl;
02390 }
02391
02392
02393 inline bool effectExistFormulaOnMLNs(ExistFormula* ef,
02394 Array<ExistFormula*>& existFormulas,
02395 double& score)
02396 {
02397 cout << "effecting existentially quantified formula " << ef->formula
02398 << " on MLN..." << endl;
02399 appendExistFormulaToMLNs(ef);
02400 score = ef->newScore;
02401 printMLNClausesWithWeightsAndScore(score, iter_);
02402 printMLNToFile(NULL, iter_);
02403
02404 int r = existFormulas.find(ef);
02405 assert(r >= 0);
02406 ExistFormula* rf = existFormulas.removeItemFastDisorder(r);
02407 assert(rf == ef);
02408 delete rf;
02409 return true;
02410 }
02411
02412
02413 bool effectBestCandidateOnMLNs(Array<Clause*>& bestCandidates,
02414 Array<ExistFormula*>& existFormulas,
02415 Array<ExistFormula*>& highGainWtFormulas,
02416 double& score)
02417 {
02418 cout << "effecting best candidate among existential formulas and "
02419 << "best candidates on MLN..." << endl << endl;
02420
02421 int a = 0, b = 0;
02422 bool ok = false;
02423 int numCands = bestCandidates.size() + highGainWtFormulas.size();
02424 for (int i = 0; i < numCands; i++)
02425 {
02426 if (a >= bestCandidates.size())
02427 {
02428 if (ok=effectExistFormulaOnMLNs(highGainWtFormulas[b++],
02429 existFormulas, score)) break;
02430 }
02431 else
02432 if (b >= highGainWtFormulas.size())
02433 {
02434 cout << "effecting best candidate " << a << " on MLN..." << endl;
02435 if (ok=effectBestCandidateOnMLNs(bestCandidates[a++], score)) break;
02436 cout << "failed to effect candidate on MLN." << endl;
02437 delete bestCandidates[a-1];
02438 }
02439 else
02440 if (highGainWtFormulas[b]->gain >
02441 bestCandidates[a]->getAuxClauseData()->gain)
02442 {
02443 if (ok=effectExistFormulaOnMLNs(highGainWtFormulas[b++],
02444 existFormulas, score)) break;
02445 }
02446 else
02447 {
02448 cout << "effecting best candidate " << a << " on MLN..." << endl;
02449 if (ok=effectBestCandidateOnMLNs(bestCandidates[a++], score)) break;
02450 cout << "failed to effect candidate on MLN." << endl;
02451 delete bestCandidates[a-1];
02452 }
02453 }
02454
02455 for (int i = a; i < bestCandidates.size(); i++) delete bestCandidates[i];
02456 return ok;
02457 }
02458
02459
02460
02461
02462
02463
02464
02465
02466
02467
02468
02469
02470
02471
02472
02473
02474
02475
02476
02477
02478
02479
02480
02481
02482
02483
02484
02485
02486
02487
02488
02489
02490
02491
02492
02493
02494
02495
02496
02497
02498
02499
02500
02501
02502
02504 private:
02505 // ---- MLNs, domains and clause construction ----
02506
02507 MLN* mln0_;
02508 Array<MLN*>* mlns_; // one MLN per domain; indexed with the same d as a formula's per-domain clauses
02509 bool startFromEmptyMLN_;
02510 string outMLNFileName_;
02511 Array<Domain*>* domains_; // its size and the first domain's predicate count are used to size new clause caches
02512 Array<Predicate*>* preds_;
02513 Array<bool>* areNonEvidPreds_;
02514 ClauseFactory* clauseFactory_;
02515 // ---- clause count cache ----
02516 bool cacheClauses_;
02517 bool origCacheClauses_;
02518 ClauseHashArray* cachedClauses_; // compressed clause copies that keep the count cache for later reuse
02519 double cacheSizeMB_; // approximate running size of the cache, in MB
02520 double maxCacheSizeMB_; // no new cache is created once cacheSizeMB_ reaches this value
02521
02522 bool tryAllFlips_;
02523 bool sampleClauses_; // forwarded to pll_ when counts are computed for a newly appended clause
02524 bool origSampleClauses_;
02525 // ---- pseudo-log-likelihood and prior ----
02526 PseudoLogLikelihood* pll_; // holds the clause counts; also receives the prior means/std devs
02527 bool hasPrior_; // if false, pll_->setMeansStdDevs(-1, NULL, NULL) is used instead of the prior arrays
02528 double priorMean_; // prior mean given to a formula when it is appended to the MLNs
02529 double priorStdDev_;
02530 bool wtPredsEqually_;
02531 // ---- optimizer settings (presumably for L-BFGS-B, judging by the names -- TODO confirm) ----
02532 LBFGSB* lbfgsb_;
02533 int lbMaxIter_;
02534 double lbConvThresh_;
02535 int looseMaxIter_;
02536 double looseConvThresh_;
02537 // ---- search parameters ----
02538 int beamSize_;
02539 int bestGainUnchangedLimit_;
02540 int numEstBestClauses_;
02541 double minGain_; // an existential formula is selected only if its gain is strictly greater than this
02542 double minWt_; // ... and its weight is at least this
02543 double penalty_; // multiplied by a formula's numPreds and subtracted from its score gain
02544
02545 bool sampleGndPreds_;
02546 double fraction_;
02547 int minGndPredSamples_;
02548 int maxGndPredSamples_;
02549
02550 bool reEvalBestCandsWithTightParams_;
02551 // ---- progress reporting ----
02552 Timer timer_;
02553 int candCnt_; // running candidate number; printed and incremented by printNewScore()
02554
02555 int iter_; // printed as "iter" in the log and passed to the MLN printing routines
02556 int bsiter_; // printed as "beam search iter" in the log
02557 double startSec_;
02558
02559
02560
02561
02562 // When indexTrans_ is non-NULL an added existential formula is counted as a single
02563 IndexTranslator* indexTrans_; // entry for weight learning and the translator's clause/formula index map is rebuilt after appending
02564
02565 bool structGradDescent_;
02566 bool withEM_;
02567 };
02568
02569
02570 #endif