00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef LOGPSEUDOLIKELIHOOD_H_AUG_18_2005
00067 #define LOGPSEUDOLIKELIHOOD_H_AUG_18_2005
00068
00069 #include <cmath>
00070 #include "array.h"
00071 #include "random.h"
00072 #include "domain.h"
00073 #include "clause.h"
00074 #include "mln.h"
00075 #include "indextranslator.h"
00076
00077
00078
00079
00080 #ifndef expl
00081 # define expl exp
00082 # define logl log
00083 #endif
00084
00086
00087
00088
// Flag handed to Clause::countDiffNumTrueGroundings() by computeAndSetCounts().
// NOTE(review): presumably tells the counting code whether the database may
// contain predicates with unknown truth values -- confirm against clause.h.
00089 const bool DB_HAS_UNKNOWN_PREDS = false;
00090
// Associates a pointer to a clause index with a count for that clause.
// The struct never frees 'index'; the pointed-to int is owned elsewhere.
struct IndexAndCount
{
  int* index;     // pointer to the clause's index (dereferenced by users)
  double count;   // count stored alongside that clause index

  // Empty entry: no index, zero count.
  IndexAndCount()
    : index(NULL), count(0)
  {}

  // Entry referring to index pointer 'i' with count 'c'.
  IndexAndCount(int* const & i, const double& c)
    : index(i), count(c)
  {}
};
00098
00099
00100 struct UndoInfo
00101 {
00102 UndoInfo(Array<IndexAndCount*>* const & affArr,
00103 IndexAndCount* const & iac, const int& remIdx, const int& domIdx)
00104 : affectedArr(affArr), remIac(iac), remIacIdx(remIdx), domainIdx(domIdx) {}
00105 ~UndoInfo() {if (remIac) delete remIac; }
00106 Array<IndexAndCount*>* affectedArr;
00107 IndexAndCount* remIac;
00108 int remIacIdx;
00109 int domainIdx;
00110 };
00111
00112
00113 struct SampledGndings
00114 {
00115 IntHashArray trueGndings;
00116 IntHashArray falseGndings;
00117 int totalTrue;
00118 int totalFalse;
00119 };
00120
00121
00123
00124
00125 class PseudoLogLikelihood
00126 {
00127 public:
00128 PseudoLogLikelihood(const Array<bool>* const & areNonEvidPreds,
00129 const Array<Domain*>* const & domains,
00130 const bool& wtFOPred, const bool& sampleGndPreds,
00131 const double& fraction, const int& minGndPredSamples,
00132 const int& maxGndPredSamples)
00133 : domains_(new Array<Domain*>) , numMeans_(-1),
00134 priorMeans_(NULL), priorStdDevs_(NULL), wtFOPred_(wtFOPred),
00135 sampleGndPreds_(sampleGndPreds), idxTrans_(NULL)
00136 {
00137
00138 if (areNonEvidPreds)
00139 {
00140 areNonEvidPreds_ = new Array<bool>(*areNonEvidPreds);
00141 assert(areNonEvidPreds_->size() == (*domains)[0]->getNumPredicates());
00142 }
00143 else
00144 {
00145 areNonEvidPreds_ = new Array<bool>;
00146 areNonEvidPreds_->growToSize((*domains)[0]->getNumPredicates(), true);
00147 }
00148
00149 int numDomains = domains->size();
00150 assert(numDomains > 0);
00151 domains_->growToSize(numDomains,NULL);
00152 for (int i = 0; i < numDomains; i++)
00153 (*domains_)[i] = (*domains)[i];
00154
00155 gndPredClauseIndexesAndCountsArr_
00156 = new Array<Array<Array<Array<Array<IndexAndCount*>*>*>*>*>;
00157 gndPredClauseIndexesAndCountsArr_->growToSize(numDomains,NULL);
00158 for (int i = 0; i < numDomains; i++)
00159 {
00160 (*gndPredClauseIndexesAndCountsArr_)[i]
00161 = new Array<Array<Array<Array<IndexAndCount*>*>*>*>;
00162 (*gndPredClauseIndexesAndCountsArr_)[i]->growToSize(
00163 (*domains_)[i]->getNumPredicates(),NULL);
00164 }
00165
00166 createNumGndings();
00167
00168 if (sampleGndPreds_)
00169 {
00170 sampledGndingsMaps_ = new Array<Array<SampledGndings*>*>;
00171 sampledGndingsMaps_->growToSize(numDomains, NULL);
00172 for (int i = 0; i < numDomains; i++)
00173 {
00174 (*sampledGndingsMaps_)[i] = new Array<SampledGndings*>;
00175 (*sampledGndingsMaps_)[i]->growToSize((*domains_)[i]->getNumPredicates()
00176 , NULL);
00177 }
00178 random_ = new Random;
00179 random_->init(-3);
00180 samplePredGroundings(fraction, minGndPredSamples, maxGndPredSamples);
00181 }
00182 else
00183 {
00184 sampledGndingsMaps_ = NULL;
00185 random_ = NULL;
00186 }
00187 }
00188
00189
00190 ~PseudoLogLikelihood()
00191 {
00192 delete areNonEvidPreds_;
00193 delete domains_;
00194
00195 for (int i = 0; i < gndPredClauseIndexesAndCountsArr_->size(); i++)
00196 {
00197 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00198 gndPredClauseIndexesAndCounts
00199 = (*gndPredClauseIndexesAndCountsArr_)[i];
00200
00201 int numPreds = gndPredClauseIndexesAndCounts->size();
00202 for (int p = 0; p < numPreds; p++)
00203 {
00204 if ( (*gndPredClauseIndexesAndCounts)[p] )
00205 {
00206 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00207 = (*gndPredClauseIndexesAndCounts)[p];
00208 int numGnds = gndingsToClauseIndexesAndCounts->size();
00209 for (int g = 0; g < numGnds; g++)
00210 {
00211 for (int h = 0; h < (*gndingsToClauseIndexesAndCounts)[g]->size();
00212 h++)
00213 {
00214 for (int j = 0;
00215 j < (*(*gndingsToClauseIndexesAndCounts)[g])[h]->size();
00216 j++)
00217 {
00218 delete (*(*(*gndingsToClauseIndexesAndCounts)[g])[h])[j];
00219 }
00220 delete (*(*gndingsToClauseIndexesAndCounts)[g])[h];
00221 }
00222 delete (*gndingsToClauseIndexesAndCounts)[g];
00223 }
00224 delete gndingsToClauseIndexesAndCounts;
00225 }
00226 }
00227 delete gndPredClauseIndexesAndCounts;
00228 }
00229 delete gndPredClauseIndexesAndCountsArr_;
00230
00231 numGndings_->deleteItemsAndClear();
00232 delete numGndings_;
00233
00234 if (sampledGndingsMaps_)
00235 {
00236 for (int i = 0; i < sampledGndingsMaps_->size(); i++)
00237 {
00238 (*sampledGndingsMaps_)[i]->deleteItemsAndClear();
00239 delete (*sampledGndingsMaps_)[i];
00240 }
00241 delete sampledGndingsMaps_;
00242 }
00243
00244 if (random_) delete random_;
00245 }
00246
00247
00248 void compress()
00249 {
00250 for (int i = 0; i < gndPredClauseIndexesAndCountsArr_->size(); i++)
00251 {
00252 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00253 gndPredClauseIndexesAndCounts
00254 = (*gndPredClauseIndexesAndCountsArr_)[i];
00255
00256 int numPreds = gndPredClauseIndexesAndCounts->size();
00257 for (int p = 0; p < numPreds; p++)
00258 {
00259
00260 if ((*gndPredClauseIndexesAndCounts)[p])
00261 {
00262 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00263 = (*gndPredClauseIndexesAndCounts)[p];
00264 int numGnds = gndingsToClauseIndexesAndCounts->size();
00265 for (int g = 0; g < numGnds; g++)
00266 {
00267 Array<Array<IndexAndCount*>*>* combosToClauseIndexesAndCounts
00268 = (*gndingsToClauseIndexesAndCounts)[g];
00269 int numCombos = combosToClauseIndexesAndCounts->size();
00270 for (int c = 0; c < numCombos; c++)
00271 (*combosToClauseIndexesAndCounts)[c]->compress();
00272 }
00273 }
00274 }
00275 }
00276
00277
00278 }
00279
00280
00281
00282 void insertCounts(int* const & clauseIdxInMLN,
00283 Array<UndoInfo*>* const & undoInfos,
00284 Array<Array<Array<CacheCount*>*>*>* const & cache,
00285 const int& d)
00286 {
00287 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00288 gndPredClauseIndexesAndCounts;
00289
00290 Array<IndexAndCount*>* gArr;
00291 CacheCount* cc;
00292 assert(cache->size() == domains_->size());
00293
00294 gndPredClauseIndexesAndCounts = (*gndPredClauseIndexesAndCountsArr_)[d];
00295 for (int p = 0; p < (*cache)[d]->size(); p++)
00296 {
00297 Array<CacheCount*>* ccArr = (*(*cache)[d])[p];
00298 if (ccArr == NULL) continue;
00299 for (int i = 0; i < ccArr->size(); i++)
00300 {
00301 assert((*gndPredClauseIndexesAndCounts)[p] != NULL);
00302 cc = (*ccArr)[i];
00303 gArr = (*(*(*gndPredClauseIndexesAndCounts)[p])[cc->g])[cc->c];
00304
00305 assert(gArr->size()==0 ||*(gArr->lastItem()->index)!=*clauseIdxInMLN);
00306 assert(cc->cnt != 0);
00307 gArr->append(new IndexAndCount(clauseIdxInMLN, cc->cnt));
00308 if (undoInfos) undoInfos->append(new UndoInfo(gArr, NULL, -1, d));
00309 }
00310 }
00311 }
00312
00313
00314
00315 void insertCounts(const Array<int*>& clauseIdxInMLNs,
00316 Array<UndoInfo*>* const & undoInfos,
00317 Array<Array<Array<CacheCount*>*>*>* const & cache)
00318 {
00319 assert(cache->size() == domains_->size());
00320 assert(clauseIdxInMLNs.size() == domains_->size());
00321
00322 for (int d = 0; d < cache->size(); d++)
00323 insertCounts(clauseIdxInMLNs[d], undoInfos, cache, d);
00324 }
00325
00326
00327
00328
00329
00330 void computeCountsForNewAppendedClause(const Clause* const & c,
00331 int* const & clauseIdxInMLN,
00332 const int& domainIdx,
00333 Array<UndoInfo*>* const & undoInfos,
00334 const bool& sampleClauses,
00335 Array<Array<Array<CacheCount*>*>*>* const & cache)
00336 {
00337 computeCountsRemoveCountsHelper(true, c, clauseIdxInMLN, domainIdx,
00338 undoInfos, sampleClauses, cache);
00339 }
00340
00341
00342 void removeCountsForClause(const Clause* const & c,
00343 int* const & clauseIdxInMLN, const int& domainIdx,
00344 Array<UndoInfo*>* const & undoInfos)
00345 {
00346 computeCountsRemoveCountsHelper(false, c, clauseIdxInMLN, domainIdx,
00347 undoInfos, false, NULL);
00348 }
00349
00350
00351
00352 void undoAppendRemoveCounts(const Array<UndoInfo*>* const & undoInfos)
00353 {
00354 for (int i = undoInfos->size() - 1; i >= 0; i--)
00355 {
00356 if ((*undoInfos)[i]->remIacIdx >= 0)
00357 {
00358 Array<IndexAndCount*>* affectedArr = (*undoInfos)[i]->affectedArr;
00359 IndexAndCount* remIac = (*undoInfos)[i]->remIac;
00360 (*undoInfos)[i]->remIac = NULL;
00361 int remIacIdx = (*undoInfos)[i]->remIacIdx;
00362
00363 if (affectedArr->size() == remIacIdx)
00364 affectedArr->append(remIac);
00365 else
00366 {
00367 assert(remIacIdx < affectedArr->size());
00368 IndexAndCount* tmpRemIac = (*affectedArr)[remIacIdx];
00369 (*affectedArr)[remIacIdx] = remIac;
00370 affectedArr->append(tmpRemIac);
00371 }
00372 }
00373 else
00374 {
00375 IndexAndCount* iac = (*undoInfos)[i]->affectedArr->removeLastItem();
00376 delete iac;
00377 }
00378
00379 assert(noRepeatedIndex((*undoInfos)[i]->affectedArr));
00380 delete (*undoInfos)[i];
00381 }
00382 }
00383
00384
00385 double getValueAndGradient(double* const & gradient, const double* const & wt,
00386 const int& arrSize)
00387 {
00388 double wpll = 0;
00389 memset(gradient, 0, arrSize*sizeof(double));
00390
00391
00392 if (idxTrans_ == NULL)
00393 {
00394 for (int i = 0; i < domains_->size(); i++)
00395 wpll += getValueAndGradientForDomain(gradient, wt, arrSize, i);
00396 }
00397 else
00398 {
00399 Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00400 Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00401 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00402 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00403
00404 for (int i = 0; i < domains_->size(); i++)
00405 {
00406 Array<double>& wts = (*wtsPerDomain)[i];
00407 Array<double>& grads = (*gradsPerDomain)[i];
00408 assert(grads.size() == wts.size());
00409 memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00410 memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00411
00412
00413
00414
00415
00416
00417
00418 for (int j = 0; j < wts.size(); j++)
00419 {
00420 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00421 for (int k = 0; k < idxDivs->size(); k++)
00422 if ((*idxDivs)[k].idx < arrSize)
00423 wts[j] += wt[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00424 }
00425
00426 wpll += getValueAndGradientForDomain((double*)grads.getItems(),
00427 (double*)wts.getItems(),
00428 wts.size(), i);
00429
00430
00431 for (int j = 0; j < grads.size(); j++)
00432 {
00433 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00434 for (int k = 0; k < idxDivs->size(); k++)
00435 if ((*idxDivs)[k].idx < arrSize)
00436 gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00437 }
00438 }
00439 }
00440
00441
00442
00443
00444 if (numMeans_ > 0)
00445 {
00446
00447
00448
00449
00450
00451 for (int i = 0; i < arrSize; i++)
00452 {
00453
00454
00455 wpll += (wt[i]-priorMeans_[i])*(wt[i]-priorMeans_[i])/
00456 (2*priorStdDevs_[i]*priorStdDevs_[i]);
00457
00458 gradient[i] += (wt[i]-priorMeans_[i])/
00459 (priorStdDevs_[i]*priorStdDevs_[i]);
00460 }
00461 }
00462
00463
00464
00465 return wpll;
00466 }
00467
00468
00469
00470
00471 void setMeansStdDevs(const int& arrSize, const double* const & priorMeans,
00472 const double* const & priorStdDevs)
00473 {
00474 numMeans_ = arrSize;
00475 priorMeans_ = priorMeans;
00476 priorStdDevs_ = priorStdDevs;
00477 }
00478
00479
00480 void setSampleGndPreds(const bool& sgp)
00481 {
00482 if (sgp) { assert(sampledGndingsMaps_); assert(random_); }
00483 sampleGndPreds_ = sgp;
00484 }
00485
00486
00487 bool checkNoRepeatedIndex(const MLN* const & mln=NULL)
00488 {
00489 bool ret = true;
00490 for (int d = 0; d < domains_->size(); d++)
00491 {
00492 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00493 gndPredClauseIndexesAndCounts
00494 = (*gndPredClauseIndexesAndCountsArr_)[d];
00495
00496 int numPreds = gndPredClauseIndexesAndCounts->size();
00497 for (int p = 0; p < numPreds; p++)
00498 {
00499 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00500 = (*gndPredClauseIndexesAndCounts)[p];
00501
00502 if (gndingsToClauseIndexesAndCounts == NULL) continue;
00503
00504 int numGnds = gndingsToClauseIndexesAndCounts->size();
00505 for (int g = 0; g < numGnds; g++)
00506 {
00507 Array<Array<IndexAndCount*>*>* gndings
00508 = (*gndingsToClauseIndexesAndCounts)[g];
00509
00510 for (int c = 0; c < gndings->size(); c++)
00511 {
00512 bool ok = noRepeatedIndex((*gndings)[c], mln);
00513 if (!ok)
00514 {
00515 cout << "ERROR: repeated index in domain " << d << " for pred "
00516 << (*domains_)[0]->getPredicateName(p) << " gnding " << g
00517 << " combination " << c << endl;
00518 ret = false;
00519 }
00520 }
00521 }
00522 }
00523 }
00524 return ret;
00525 }
00526
00527
00528 void printGndPredClauseIndexesAndCounts(const MLN* const & mln=NULL)
00529 {
00530 for (int d = 0; d < domains_->size(); d++)
00531 {
00532 cout << "domainIdx: " << d << endl;
00533 cout << "gndPredClauseIndexesAndCounts[predIdx][gndingIdx][combIdx][i]"
00534 << endl;
00535
00536 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00537 gndPredClauseIndexesAndCounts
00538 = (*gndPredClauseIndexesAndCountsArr_)[d];
00539
00540 int numPreds = gndPredClauseIndexesAndCounts->size();
00541 for (int p = 0; p < numPreds; p++)
00542 {
00543 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00544 = (*gndPredClauseIndexesAndCounts)[p];
00545
00546 if (gndingsToClauseIndexesAndCounts == NULL)
00547 {
00548 cout << "gndPredClauseIndexesAndCounts[" << p << "] = NULL" << endl;
00549 continue;
00550 }
00551
00552 int numGnds = gndingsToClauseIndexesAndCounts->size();
00553 for (int g = 0; g < numGnds; g++)
00554 {
00555 Array<Array<IndexAndCount*>*>* gndings
00556 = (*gndingsToClauseIndexesAndCounts)[g];
00557
00558 for (int c = 0; c < gndings->size(); c++)
00559 {
00560 Array<IndexAndCount*>* combos
00561 = (*gndings)[c];
00562 int numClauseIdx = combos->size();
00563
00564 if (numClauseIdx == 0)
00565 {
00566 cout << "gndPredClauseIndexesAndCounts[" << p << "][" << g
00567 << "][" << c << "] = empty" << endl;
00568 continue;
00569 }
00570
00571 for (int i = 0; i < numClauseIdx; i++)
00572 {
00573 cout << "gndPredClauseIndexesAndCounts[" << p << "][" << g << "]["
00574 << c << "][" << i << "] (clauseIndex,count) = "
00575 << *((*combos)[i]->index)
00576 << ", " << (*combos)[i]->count
00577 << ", " << (*combos)[i] << endl;
00578
00579 if (mln)
00580 {
00581 cout << " \t";
00582 mln->getClause(*((*combos)[i]->index)) ->print(cout,
00583 (*domains_)[0]);
00584 cout << endl;
00585 }
00586 }
00587 }
00588
00589 }
00590 }
00591 }
00592 }
00593
00594
00595 void setIndexTranslator(IndexTranslator* const & it) { idxTrans_ = it; }
00596
00597
00598 IndexTranslator* getIndexTranslator() const { return idxTrans_; }
00599
00600
00601 private:
00602 void createNumGndings()
00603 {
00604 numGndings_ = new Array<Array<double>*>;
00605 numGndings_->growToSize(domains_->size());
00606 for (int i = 0; i < domains_->size(); i++)
00607 {
00608 (*numGndings_)[i] = new Array<double>;
00609 (*numGndings_)[i]->growToSize((*domains_)[i]->getNumPredicates(), -1);
00610 }
00611
00612 for (int i = 0; i < domains_->size(); i++)
00613 {
00614 const Domain* domain = (*domains_)[i];
00615 for (int j = 0; j < domain->getNumPredicates(); j++)
00616 {
00617 if (!(*areNonEvidPreds_)[j]) continue;
00618 const PredicateTemplate* pt = domain->getPredicateTemplate(j);
00619
00620 double numGndings = 1;
00621 for (int t = 0; t < pt->getNumTerms(); t++)
00622 {
00623 int typeId = pt->getTermTypeAsInt(t);
00624 numGndings *= domain->getNumConstantsByType(typeId);
00625 }
00626 (*((*numGndings_)[i]))[j] = numGndings;
00627 }
00628 }
00629 }
00630
00631
00632 void createAllPossiblePredsGroundings(const Predicate* const & pred,
00633 const Domain* const & domain,
00634 ArraysAccessor<int>& acc)
00635 {
00636 for (int i = 0; i < pred->getNumTerms(); i++)
00637 {
00638 int typeId = pred->getTermTypeAsInt(i);
00639 const Array<int>* constArr = domain->getConstantsByType(typeId);
00640 acc.appendArray(constArr);
00641 }
00642 }
00643
00644
00645 bool isSampledGndPred(const int& g, const SampledGndings* const & sg)
00646 {
00647 if (sg->falseGndings.contains(g)) return true;
00648 if (sg->trueGndings.contains(g)) return true;
00649 return false;
00650 }
00651
00652
00653 void computeCountsRemoveCountsHelper(bool computeCounts,
00654 const Clause* const & c,
00655 int* const & clauseIdxInMLN,
00656 const int& domainIdx,
00657 Array<UndoInfo*>* const & undoInfos,
00658 const bool& sampleClauses,
00659 Array<Array<Array<CacheCount*>*>*>* const & cache)
00660 {
00661
00662 const Domain* domain = (*domains_)[domainIdx];
00663 Database* db = domain->getDB();
00664
00665
00666 Array<int> predIdxWithAllTermsDiffVars(domain->getNumPredicates());
00667 predIdxWithAllTermsDiffVars.growToSize(domain->getNumPredicates());
00668 int* parr = (int*) predIdxWithAllTermsDiffVars.getItems();
00669 memset(parr, -1, domain->getNumPredicates()*sizeof(int));
00670
00671
00672 Array<bool> predAllTermsAreDiffVars(c->getNumPredicates());
00673 createPredAllTermsAreDiffVars(c, predAllTermsAreDiffVars,
00674 predIdxWithAllTermsDiffVars);
00675
00676
00677 PredicateHashArray seenPreds;
00678
00679
00680 for (int p = 0; p < c->getNumPredicates(); p++)
00681 {
00682 Predicate* pred = c->getPredicate(p);
00683 int predId = pred->getId();
00684 if (!(*areNonEvidPreds_)[predId]) continue;
00685
00686 Predicate gndPred(*pred);
00687 gndPred.canonicalize();
00688 bool predIsInitiallyGnded = gndPred.isGrounded();
00689
00690 SampledGndings* sg = NULL;
00691 if (sampleGndPreds_) sg = (*(*sampledGndingsMaps_)[domainIdx])[predId];
00692
00693 Predicate* seenPred = new Predicate(gndPred);
00694
00695 if (seenPreds.append(seenPred) < 0) { delete seenPred; continue; }
00696
00697 if (predAllTermsAreDiffVars[p])
00698 {
00699
00700
00701
00702 ArraysAccessor<int> acc;
00703 createAllPossiblePredsGroundings(&gndPred, domain, acc);
00704
00705
00706 int g = -1;
00707 while (acc.hasNextCombination())
00708 {
00709 ++g;
00710
00711 int t = 0; int constId;
00712 while (acc.nextItemInCombination(constId))
00713 ((Term*) gndPred.getTerm(t++))->setId(constId);
00714
00715 if (sampleGndPreds_ && !isSampledGndPred(g,sg)) continue;
00716
00717 if (computeCounts)
00718 {
00719 computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
00720 domainIdx, undoInfos, sampleClauses, cache);
00721 }
00722 else
00723 removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
00724 }
00725 }
00726 else
00727 {
00728
00729
00730 if (predIdxWithAllTermsDiffVars[predId] >= 0) continue;
00731
00732
00733
00734
00735 Array<int> multipliers(gndPred.getNumTerms());
00736 createMultipliers(multipliers, gndPred, domain);
00737
00738
00739
00740 int offsetDueToConstants = 0;
00741
00742
00743
00744
00745
00746 Array<Array<pair<int,Term*> >* > varIdToMults;
00747 Array<int> negVarIdsArr;
00748 ArraysAccessor<int> groundings;
00749 createMappingOfVarIdToMultipliersAndVarGroundingsAndOffset(
00750 gndPred, domain, multipliers, offsetDueToConstants, varIdToMults,
00751 negVarIdsArr, groundings);
00752
00753
00754 if (!predIsInitiallyGnded)
00755 {
00756
00757 int constId, constIdx;
00758 while (groundings.hasNextCombination())
00759 {
00760 int g = offsetDueToConstants;
00761 int j = -1;
00762 while (groundings.nextItemInCombination(constId, constIdx))
00763 {
00764 ++j;
00765 int negVarId = negVarIdsArr[j];
00766 Array<pair<int,Term*> >* multsAndTerms = varIdToMults[negVarId];
00767 for (int m = 0; m < multsAndTerms->size(); m++)
00768 {
00769 g += constIdx * (*multsAndTerms)[m].first;
00770 (*multsAndTerms)[m].second->setId(constId);
00771 }
00772 }
00773
00774 if (sampleGndPreds_ && !isSampledGndPred(g,sg)) continue;
00775
00776 if (computeCounts)
00777 computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
00778 domainIdx, undoInfos, sampleClauses, cache);
00779 else
00780 removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
00781 }
00782 }
00783 else
00784 {
00785 int g = offsetDueToConstants;
00786
00787 bool ok = true;
00788 if (sampleGndPreds_) ok = isSampledGndPred(g,sg);
00789
00790 if (ok)
00791 {
00792 if (computeCounts)
00793 computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
00794 domainIdx, undoInfos, sampleClauses, cache);
00795 else
00796 removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
00797 }
00798 }
00799
00800 for (int j = 0; j < varIdToMults.size(); j++) delete varIdToMults[j];
00801 }
00802 }
00803
00804 for (int i = 0; i < seenPreds.size(); i++) delete seenPreds[i];
00805
00806
00807 }
00808
00809
00810 void computePerPredPllAndGrad(const Array<Array<Array<IndexAndCount*>*>*>*
00811 const& gndingsToClauseIndexesAndCounts,
00812 const int& g, const double* const & wt,
00813 long double& perPredPll,
00814 long double * const & perPredGrad)
00815 {
00816
00817
00818
00819
00820
00821
00822
00823
00824
00825
00826
00827
00828
00829
00830
00831
00832
00833
00834
00835
00836
00837
00838 long double wdotn = 0;
00839
00840 long double pmb = 1;
00841
00842 Array<Array<IndexAndCount*>*>* gndings =
00843 (*gndingsToClauseIndexesAndCounts)[g];
00844
00845
00846 for (int c = 0; c < gndings->size(); c++)
00847 {
00848 Array<IndexAndCount*>* clauseIndexesAndCounts = (*gndings)[c];
00849 assert(noRepeatedIndex(clauseIndexesAndCounts));
00850
00851 int numClausesUnifyWith = clauseIndexesAndCounts->size();
00852
00853
00854 for (int i = 0; i < numClausesUnifyWith; i++)
00855 {
00856
00857
00858
00859 wdotn += wt[ *( (*clauseIndexesAndCounts)[i]->index ) ] *
00860 (*clauseIndexesAndCounts)[i]->count;
00861 }
00862
00863 pmb += expl(wdotn);
00864 }
00865
00866 perPredPll -= logl(pmb);
00867
00868
00869 for (int c = 0; c < gndings->size(); c++)
00870 {
00871 Array<IndexAndCount*>* clauseIndexesAndCounts
00872 = (*gndings)[c];
00873
00874 for (int i = 0; i < clauseIndexesAndCounts->size(); i++)
00875 perPredGrad[ *( (*clauseIndexesAndCounts)[i]->index ) ]
00876 += ( (1.0/pmb-1) * (*clauseIndexesAndCounts)[i]->count );
00877 }
00878 }
00879
00880 void computeSampledPerPredPllAndGrad(IntHashArray& gndings,
00881 const int& totalGndings,
00882 long double& tmpPerPredPll,
00883 long double* const & tmpPerPredGrad,
00884 long double& perPredPll,
00885 long double* const & perPredGrad,
00886 const int& arrSize,
00887 const double* const & wt,
00888 const Array<Array<Array<IndexAndCount*>*>*>*
00889 const& gndingsToClauseIndexesAndCounts)
00890 {
00891 tmpPerPredPll = 0;
00892 memset(tmpPerPredGrad, 0, arrSize*sizeof(long double));
00893 for (int i = 0; i < gndings.size(); i++)
00894 computePerPredPllAndGrad(gndingsToClauseIndexesAndCounts,
00895 gndings[i], wt, tmpPerPredPll,
00896 tmpPerPredGrad);
00897 if (gndings.size() > 0)
00898 {
00899 perPredPll += totalGndings * tmpPerPredPll/gndings.size();
00900
00901 for (int i = 0; i < arrSize; i++)
00902 perPredGrad[i] += totalGndings * tmpPerPredGrad[i]/gndings.size();
00903 }
00904
00905 }
00906
00907
00908
00909 double getValueAndGradientForDomain(double* const & gradient,
00910 const double* const & wt,
00911 const int& arrSize, const int& domainIdx)
00912 {
00913 long double wpll = 0;
00914 long double* perPredGrad = new long double[arrSize];
00915
00916
00917 long double tmpPerPredPll;
00918 long double* tmpPerPredGrad = NULL;
00919 if (sampleGndPreds_) tmpPerPredGrad = new long double[arrSize];
00920
00921 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
00922 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
00923
00924 int numPreds = gndPredClauseIndexesAndCounts->size();
00925 for (int p = 0; p < numPreds; p++)
00926 {
00927 if (!(*areNonEvidPreds_)[p]) continue;
00928
00929
00930
00931
00932
00933 long double perPredPll = 0;
00934 memset(perPredGrad, 0, arrSize*sizeof(long double));
00935
00936
00937 if ((*gndPredClauseIndexesAndCounts)[p] != NULL)
00938 {
00939 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00940 = (*gndPredClauseIndexesAndCounts)[p];
00941
00942 if (sampleGndPreds_)
00943 {
00944 SampledGndings* sg = (*(*sampledGndingsMaps_)[domainIdx])[p];
00945 computeSampledPerPredPllAndGrad(sg->trueGndings, sg->totalTrue,
00946 tmpPerPredPll, tmpPerPredGrad,
00947 perPredPll, perPredGrad, arrSize,
00948 wt, gndingsToClauseIndexesAndCounts);
00949 computeSampledPerPredPllAndGrad(sg->falseGndings, sg->totalFalse,
00950 tmpPerPredPll, tmpPerPredGrad,
00951 perPredPll, perPredGrad, arrSize,
00952 wt, gndingsToClauseIndexesAndCounts);
00953 }
00954 else
00955 {
00956 int numGnds = gndingsToClauseIndexesAndCounts->size();
00957 assert(numGnds == (*((*numGndings_)[domainIdx]))[p]);
00958 for (int g = 0; g < numGnds; g++)
00959 computePerPredPllAndGrad(gndingsToClauseIndexesAndCounts, g, wt,
00960 perPredPll, perPredGrad);
00961 }
00962 }
00963 else
00964 {
00965 perPredPll = (*((*numGndings_)[domainIdx]))[p] * -log(2);
00966
00967 }
00968
00969 if (wtFOPred_)
00970 {
00971
00972 wpll -= perPredPll / (*((*numGndings_)[domainIdx]))[p];
00973 for (int i = 0; i < arrSize; i++)
00974 gradient[i] -= perPredGrad[i]/(*((*numGndings_)[domainIdx]))[p];
00975 }
00976 else
00977 {
00978
00979 wpll -= perPredPll;
00980 for (int i = 0; i < arrSize; i++)
00981 gradient[i] -= perPredGrad[i];
00982 }
00983 }
00984
00985 delete [] perPredGrad;
00986 if (sampleGndPreds_) delete [] tmpPerPredGrad;
00987
00988 return wpll;
00989 }
00990
00991 void createClauseIndexesAndCountsArrays(const int& predId,
00992 const int& domainIdx)
00993 {
00994 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
00995 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
00996 if ((*gndPredClauseIndexesAndCounts)[predId] != NULL) return;
00997
00998 Array<Array<Array<IndexAndCount*>*>*>* arr =
00999 new Array<Array<Array<IndexAndCount*>*>*>;
01000 double numGndings = (*((*numGndings_)[domainIdx]))[predId];
01001
01002
01003
01004 for (int g = 0; g < numGndings; g++)
01005 arr->append(new Array<Array<IndexAndCount*>*>);
01006 arr->compress();
01007 (*gndPredClauseIndexesAndCounts)[predId] = arr;
01008 }
01009
01010 void createComboClauseIndexesAndCountsArrays(const int& predId,
01011 const int& domainIdx,
01012 Predicate* const & gndPred,
01013 const int& g)
01014 {
01015 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01016 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01017
01018 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01019 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01020 if (comboClauseIndexesAndCounts->size() > 0) return;
01021
01022
01023 int numCombInBlock = 1;
01024
01025 int blockIdx = (*domains_)[domainIdx]->getBlock(gndPred);
01026 if (blockIdx >= 0)
01027 numCombInBlock = (*domains_)[domainIdx]->getBlockSize(blockIdx) - 1;
01028
01029 comboClauseIndexesAndCounts->growToSize(numCombInBlock, NULL);
01030 for (int c = 0; c < numCombInBlock; c++)
01031 {
01032 (*comboClauseIndexesAndCounts)[c] = new Array<IndexAndCount*>;
01033 }
01034 comboClauseIndexesAndCounts->compress();
01035 }
01036
01037
01038
01039
01040
01041
01042
01043
01044
01045 bool computeAndSetCounts(const Clause* const clause,
01046 int* const & clauseIdxInMLN,
01047 const int& predId, Predicate& gndPred,
01048 const int& g, Database* const & db,
01049 const int& domainIdx,
01050 Array<UndoInfo*>* const & undoInfos,
01051 const bool& sampleClauses,
01052 Array<Array<Array<CacheCount*>*>*>* const & cache)
01053 {
01054 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01055 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01056 const Domain* domain = (*domains_)[domainIdx];
01057
01058 if ((*gndPredClauseIndexesAndCounts)[predId] == NULL)
01059 createClauseIndexesAndCountsArrays(predId, domainIdx);
01060
01061 createComboClauseIndexesAndCountsArrays(predId, domainIdx, &gndPred, g);
01062
01063 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01064 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01065
01066
01067
01068
01069
01070
01071
01072
01073
01074
01075
01076
01077 for (int c = 0; c < comboClauseIndexesAndCounts->size(); c++)
01078 {
01079
01080
01081 Array<IndexAndCount*>* gArr = (*comboClauseIndexesAndCounts)[c];
01082 if (gArr->size() > 0 && *( gArr->lastItem()->index ) == *clauseIdxInMLN)
01083 {
01084
01085
01086 continue;
01087 }
01088
01089 double cnt =
01090 ((Clause*)clause)->countDiffNumTrueGroundings(&gndPred, domain, db,
01091 DB_HAS_UNKNOWN_PREDS,
01092 sampleClauses, c);
01093
01094
01095 if (cnt != 0)
01096 {
01097
01098 gArr->append(new IndexAndCount(clauseIdxInMLN, cnt));
01099 if (undoInfos)
01100 undoInfos->append(new UndoInfo(gArr, NULL, -1, domainIdx));
01101
01102 if (cache)
01103 {
01104 Array<CacheCount*>*& ccArr = (*(*cache)[domainIdx])[predId];
01105 if (ccArr == NULL) ccArr = new Array<CacheCount*>;
01106 ccArr->append(new CacheCount(g, c, cnt));
01107 }
01108 }
01109
01110 assert(noRepeatedIndex(gArr));
01111 }
01112
01113 return true;
01114 }
01115
01116
01117
01118 bool removeCounts(int* const & clauseIdxInMLN, const int& predId,
01119 const int& g, const int& domainIdx,
01120 Array<UndoInfo*>* const & undoInfos)
01121 {
01122 bool removed = false;
01123 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01124 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01125
01126 if ((*gndPredClauseIndexesAndCounts)[predId] == NULL) return false;
01127
01128 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01129 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01130
01131 for (int c = 0; c < comboClauseIndexesAndCounts->size(); c++)
01132 {
01133 Array<IndexAndCount*>* gArr =(*comboClauseIndexesAndCounts)[c];
01134 for (int i = 0; i < gArr->size(); i++)
01135 {
01136 if ((*gArr)[i]->index == clauseIdxInMLN)
01137 {
01138 IndexAndCount* ic = gArr->removeItemFastDisorder(i);
01139
01140 if (undoInfos)
01141 {
01142 undoInfos->append(new UndoInfo(gArr, ic, i, domainIdx));
01143 assert(noRepeatedIndex(gArr));
01144
01145 removed = true;
01146 }
01147 else
01148 {
01149 delete ic;
01150 assert(noRepeatedIndex(gArr));
01151
01152 removed = true;
01153 }
01154 }
01155 }
01156 assert(noRepeatedIndex(gArr));
01157 }
01158
01159
01160 return removed;
01161 }
01162
01163
01164 void createPredAllTermsAreDiffVars(const Clause* const & c,
01165 Array<bool>& predAllTermsAreDiffVars,
01166 Array<int>& predIdxWithAllTermsDiffVars)
01167 {
01168 for (int p = 0; p < c->getNumPredicates(); p++)
01169 {
01170 Predicate* pred = c->getPredicate(p);
01171 bool allDiffVars;
01172
01173 if (c->isDirty())
01174 {
01175
01176 allDiffVars = pred->checkAllTermsAreDiffVars();
01177 }
01178 else
01179 {
01180 assert(!pred->isDirty());
01181 allDiffVars = pred->allTermsAreDiffVars();
01182 }
01183
01184 predAllTermsAreDiffVars.append(allDiffVars);
01185 if (allDiffVars)
01186 {
01187 int predId = pred->getId();
01188 if (predIdxWithAllTermsDiffVars[predId] < 0)
01189 predIdxWithAllTermsDiffVars[predId] = p;
01190 }
01191 }
01192 predAllTermsAreDiffVars.compress();
01193 }
01194
01195
01196 void createMultipliers(Array<int>& multipliers,
01197 const Predicate& gndPred,
01198 const Domain* const & domain)
01199 {
01200 int mult = 1;
01201 int numTerms = gndPred.getNumTerms();
01202 multipliers.growToSize(numTerms);
01203 for (int j = numTerms-1; j >= 0; j--)
01204 {
01205 multipliers[j] = mult;
01206 int typeId = gndPred.getTermTypeAsInt(j);
01207 mult *= domain->getNumConstantsByType(typeId);
01208 }
01209 }
01210
01211
01212 void createMappingOfVarIdToMultipliersAndVarGroundingsAndOffset(
01213 const Predicate& gndPred,
01214 const Domain* const & domain,
01215 Array<int>& multipliers,
01216 int& offsetDueToConstants,
01217 Array<Array<pair<int,Term*> >* >& varIdToMults,
01218 Array<int>& negVarIdsArr,
01219 ArraysAccessor<int>& groundings)
01220 {
01221 for (int j = 0; j < gndPred.getNumTerms(); j++)
01222 {
01223 const Term* t = gndPred.getTerm(j);
01224 if (t->getType() == Term::VARIABLE)
01225 {
01226 assert(t->getId()<0);
01227 int id = -(t->getId());
01228 if (id >= varIdToMults.size()) varIdToMults.growToSize(id+1,NULL);
01229 if (varIdToMults[id] == NULL)
01230 {
01231 negVarIdsArr.append(id);
01232 varIdToMults[id] = new Array<pair<int,Term*> >;
01233 int typeId = gndPred.getTermTypeAsInt(j);
01234 const Array<int>* constants = domain->getConstantsByType(typeId);
01235 groundings.appendArray(constants);
01236 }
01237 varIdToMults[id]->append(pair<int,Term*>(multipliers[j], (Term*)t));
01238 }
01239 else
01240 if (t->getType() == Term::CONSTANT)
01241 {
01242 int id = t->getId();
01243 assert(id >= 0);
01244 int typeId = gndPred.getTermTypeAsInt(j);
01245 const Array<int>* constants = domain->getConstantsByType(typeId);
01246 assert(constants->size() > 0);
01247
01248
01249 for (int i = 0; i < constants->size(); i++)
01250 {
01251 if ((*constants)[i] == id)
01252 {
01253 offsetDueToConstants += i*multipliers[j];
01254 break;
01255 }
01256 }
01257
01258 }
01259 else
01260 {
01261 assert(false);
01262 }
01263 }
01264 }
01265
01266
01267 void randomlySelect(IntHashArray& gndings, const double& fraction,
01268 const int& min, const int& max)
01269 {
01270 int size = int(fraction * gndings.size() + 0.5);
01271 if (min >= 0 && size < min) size = min;
01272 else if (max >= 0 && size > max) size = max;
01273 while (gndings.size() > size)
01274 gndings.removeItemFastDisorder(random_->randomOneOf(gndings.size()));
01275 }
01276
01277
01278 void samplePredGroundingsForDomain(const Predicate* const& foPred,
01279 const Domain* const & domain,
01280 SampledGndings* sampledGndings,
01281 const double& fraction,
01282 const int& min, const int& max)
01283 {
01284 cout << "sampling predicate "; foPred->printWithStrVar(cout, domain);
01285 cout << endl;
01286
01287 assert(((Predicate*)foPred)->allTermsAreDiffVars());
01288 ArraysAccessor<int> acc;
01289 createAllPossiblePredsGroundings(foPred, domain, acc);
01290 Predicate* gndPred = (Predicate*) foPred;
01291 const Database* db = domain->getDB();
01292
01293 IntHashArray& trueGndings = sampledGndings->trueGndings;
01294 IntHashArray& falseGndings = sampledGndings->falseGndings;
01295
01296 int g = -1;
01297 while (acc.hasNextCombination())
01298 {
01299 ++g;
01300 int t = 0;
01301 int constId;
01302 while (acc.nextItemInCombination(constId))
01303 ((Term*) gndPred->getTerm(t++))->setId(constId);
01304
01305 TruthValue tv = db->getValue(gndPred);
01306 if (tv == TRUE) trueGndings.append(g);
01307 else if (tv == FALSE) falseGndings.append(g);
01308 }
01309
01310
01311 sampledGndings->totalTrue = trueGndings.size();
01312 sampledGndings->totalFalse = falseGndings.size();
01313 randomlySelect(trueGndings, fraction, min, max);
01314 randomlySelect(falseGndings, fraction, min, max);
01315 trueGndings.compress();
01316 falseGndings.compress();
01317
01318 cout << "\tsampled/total (true ground atoms) = "
01319 << trueGndings.size() << "/" << sampledGndings->totalTrue << endl;
01320 cout << "\tsampled/total (false ground atoms) = "
01321 << falseGndings.size() << "/" << sampledGndings->totalFalse << endl;
01322 }
01323
01324
01325 void samplePredGroundings(const double& fraction,
01326 const int& min, const int& max)
01327 {
01328 for (int d = 0; d < domains_->size(); d++)
01329 {
01330 cout << "domain " << d << endl;
01331 const Domain* domain = (*domains_)[d];
01332 Array<SampledGndings*>* sgm = (*sampledGndingsMaps_)[d];
01333 for (int p = 0; p < domain->getNumPredicates(); p++)
01334 {
01335 if (!(*areNonEvidPreds_)[p]) continue;
01336 SampledGndings* sg = new SampledGndings;
01337 assert((*sgm)[p] == NULL);
01338 (*sgm)[p] = sg;
01339 Predicate* foPred = domain->createPredicate(p, true);
01340 assert(foPred);
01341 samplePredGroundingsForDomain(foPred, domain, sg, fraction, min, max);
01342 delete foPred;
01343 }
01344 }
01345 }
01346
01347
01348 bool noRepeatedIndex(const Array<IndexAndCount*>* const & gArr,
01349 const MLN* const & mln0=NULL)
01350 {
01351 hash_set<int> set;
01352 for (int i = 0; i < gArr->size(); i++)
01353 {
01354 int ii = *((*gArr)[i]->index);
01355 if (set.find(ii) != set.end())
01356 {
01357 cout << "ERROR: in PseudoLogLikelihood::noRepeatedIndex. "
01358 << "Repeated index " << ii << " found. ";
01359 if (mln0)
01360 mln0->getClause(ii)->printWithoutWtWithStrVar(cout, (*domains_)[0]);
01361 cout << endl;
01362 return false;
01363 }
01364 set.insert(ii);
01365 }
01366 return true;
01367 }
01368
01369
// Dumps the weights, gradients and weighted pseudo-log-likelihood value to
// cout with 10 significant digits.
//
//   wts, grads : arrays holding at least arrSize elements each
//   arrSize    : number of elements to print from each array
//   wpll       : value printed on the "wpll = " line
//
// Output layout: a "wts = " line, one line of " <i>:<value>" items, a
// "grads = " line, one line of " <i>:<value>" items, a "wpll = <value>"
// line, then a blank line.
//
// The precision cout had on entry is restored on exit, so a caller that
// configured its own precision does not have it silently replaced.
void printWtsGradsWPLL(const double* const & wts, const double* const & grads,
                       const int& arrSize, const double& wpll)
{
  // precision(n) returns the previous setting; keep it for the restore below.
  const std::streamsize savedPrecision = std::cout.precision(10);

  std::cout << "wts = " << std::endl;
  for (int i = 0; i < arrSize; i++) std::cout << " " << i << ":" << wts[i];
  std::cout << std::endl;

  std::cout << "grads = " << std::endl;
  for (int i = 0; i < arrSize; i++) std::cout << " " << i << ":" << grads[i];
  std::cout << std::endl;

  std::cout << "wpll = " << wpll << std::endl;
  std::cout << std::endl;

  std::cout.precision(savedPrecision);
}
01384
01385
01386 private:
01387 Array<bool>* areNonEvidPreds_; // indexed by predicate id; samplePredGroundings() skips ids whose entry is false
01388
01389 // One Domain per entry; its position here is the domain index used by the other per-domain arrays.
01390 Array<Domain*>* domains_;
01391
01392 // Nested lookup table, addressed outermost to innermost as
01393 //   [domainIdx][predId][g][c][i] -> IndexAndCount*
01394 // The [predId] level may hold NULL; the removal routine treats that as
01395 // "nothing stored for this predicate". Each IndexAndCount carries a pointer
01396 // to a clause index (compared against the clause's index in the MLN) and a
01397 // count.
01398 // NOTE(review): g looks like a ground-atom ordinal and c a "combo" ordinal
01399 // (names taken from the removal routine) -- confirm against the builder code.
01400 Array<Array<Array<Array<Array<IndexAndCount*>*>*>*>*>*
01401 gndPredClauseIndexesAndCountsArr_;
01402 // numMeans_ starts at -1 in the constructor. NOTE(review): presumably the length of the two prior arrays below -- confirm.
01403 int numMeans_;
01404 const double* priorMeans_;
01405 const double* priorStdDevs_;
01406 // NOTE(review): presumably mirrors the constructor's wtFOPred argument; numGndings_ looks like per-domain arrays of grounding counts -- confirm.
01407 bool wtFOPred_;
01408 Array<Array<double>*>* numGndings_;
01409 // NOTE(review): presumably mirrors the constructor's sampleGndPreds argument -- confirm.
01410 bool sampleGndPreds_;
01411 // Addressed as [domain index][predicate id]; samplePredGroundings() fills the
01412 // entries of predicates flagged in areNonEvidPreds_ (they must be NULL beforehand).
01413 Array<Array<SampledGndings*>*>* sampledGndingsMaps_;
01414 Random* random_; // source of the random indices used by randomlySelect()
01415 // NOTE(review): not referenced in this part of the class; purpose to be confirmed from IndexTranslator.
01416 IndexTranslator* idxTrans_;
01417 };
01418
01419
01420 #endif