00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef LOGPSEUDOLIKELIHOOD_H_AUG_18_2005
00067 #define LOGPSEUDOLIKELIHOOD_H_AUG_18_2005
00068
00069 #include <cmath>
00070 #include "array.h"
00071 #include "random.h"
00072 #include "domain.h"
00073 #include "clause.h"
00074 #include "mln.h"
00075 #include "indextranslator.h"
00076
00077
00078
00080
00081
00082
// Value passed as the "has unknown predicates" argument of
// Clause::countDiffNumTrueGroundings (see computeAndSetCounts).
static const bool DB_HAS_UNKNOWN_PREDS(false);
00084
// One entry of a clause-index/count list: a pointer to a clause's index in the
// MLN paired with the count recorded for that clause.
struct IndexAndCount
{
  // Empty entry: no clause, zero count.
  IndexAndCount()
  {
    index = NULL;
    count = 0;
  }

  // Entry for the clause whose MLN index is pointed to by idx.
  IndexAndCount(int* const & idx, const double& cnt)
  {
    index = idx;
    count = cnt;
  }

  // Not owned. removeCounts() matches entries by this pointer's address,
  // while the insertion code compares the pointed-to index values.
  int* index;
  double count;
};
00092
00093
00094 struct UndoInfo
00095 {
00096 UndoInfo(Array<IndexAndCount*>* const & affArr,
00097 IndexAndCount* const & iac, const int& remIdx, const int& domIdx)
00098 : affectedArr(affArr), remIac(iac), remIacIdx(remIdx), domainIdx(domIdx) {}
00099 ~UndoInfo() {if (remIac) delete remIac; }
00100 Array<IndexAndCount*>* affectedArr;
00101 IndexAndCount* remIac;
00102 int remIacIdx;
00103 int domainIdx;
00104 };
00105
00106
00107 struct SampledGndings
00108 {
00109 IntHashArray trueGndings;
00110 IntHashArray falseGndings;
00111 int totalTrue;
00112 int totalFalse;
00113 };
00114
00115
00117
00118
00119 class PseudoLogLikelihood
00120 {
00121 public:
00122 PseudoLogLikelihood(const Array<bool>* const & areNonEvidPreds,
00123 const Array<Domain*>* const & domains,
00124 const bool& wtFOPred, const bool& sampleGndPreds,
00125 const double& fraction, const int& minGndPredSamples,
00126 const int& maxGndPredSamples)
00127 : domains_(new Array<Domain*>) , numMeans_(-1),
00128 priorMeans_(NULL), priorStdDevs_(NULL), wtFOPred_(wtFOPred),
00129 sampleGndPreds_(sampleGndPreds), idxTrans_(NULL)
00130 {
00131
00132 if (areNonEvidPreds)
00133 {
00134 areNonEvidPreds_ = new Array<bool>(*areNonEvidPreds);
00135 assert(areNonEvidPreds_->size() == (*domains)[0]->getNumPredicates());
00136 }
00137 else
00138 {
00139 areNonEvidPreds_ = new Array<bool>;
00140 areNonEvidPreds_->growToSize((*domains)[0]->getNumPredicates(), true);
00141 }
00142
00143 int numDomains = domains->size();
00144 assert(numDomains > 0);
00145 domains_->growToSize(numDomains,NULL);
00146 for (int i = 0; i < numDomains; i++)
00147 (*domains_)[i] = (*domains)[i];
00148
00149 gndPredClauseIndexesAndCountsArr_
00150 = new Array<Array<Array<Array<Array<IndexAndCount*>*>*>*>*>;
00151 gndPredClauseIndexesAndCountsArr_->growToSize(numDomains,NULL);
00152 for (int i = 0; i < numDomains; i++)
00153 {
00154 (*gndPredClauseIndexesAndCountsArr_)[i]
00155 = new Array<Array<Array<Array<IndexAndCount*>*>*>*>;
00156 (*gndPredClauseIndexesAndCountsArr_)[i]->growToSize(
00157 (*domains_)[i]->getNumPredicates(),NULL);
00158 }
00159
00160 createNumGndings();
00161
00162 if (sampleGndPreds_)
00163 {
00164 sampledGndingsMaps_ = new Array<Array<SampledGndings*>*>;
00165 sampledGndingsMaps_->growToSize(numDomains, NULL);
00166 for (int i = 0; i < numDomains; i++)
00167 {
00168 (*sampledGndingsMaps_)[i] = new Array<SampledGndings*>;
00169 (*sampledGndingsMaps_)[i]->growToSize((*domains_)[i]->getNumPredicates()
00170 , NULL);
00171 }
00172 random_ = new Random;
00173 random_->init(-3);
00174 samplePredGroundings(fraction, minGndPredSamples, maxGndPredSamples);
00175 }
00176 else
00177 {
00178 sampledGndingsMaps_ = NULL;
00179 random_ = NULL;
00180 }
00181 }
00182
00183
00184 ~PseudoLogLikelihood()
00185 {
00186 delete areNonEvidPreds_;
00187 delete domains_;
00188
00189 for (int i = 0; i < gndPredClauseIndexesAndCountsArr_->size(); i++)
00190 {
00191 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00192 gndPredClauseIndexesAndCounts
00193 = (*gndPredClauseIndexesAndCountsArr_)[i];
00194
00195 int numPreds = gndPredClauseIndexesAndCounts->size();
00196 for (int p = 0; p < numPreds; p++)
00197 {
00198 if ( (*gndPredClauseIndexesAndCounts)[p] )
00199 {
00200 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00201 = (*gndPredClauseIndexesAndCounts)[p];
00202 int numGnds = gndingsToClauseIndexesAndCounts->size();
00203 for (int g = 0; g < numGnds; g++)
00204 {
00205 for (int h = 0; h < (*gndingsToClauseIndexesAndCounts)[g]->size();
00206 h++)
00207 {
00208 for (int j = 0;
00209 j < (*(*gndingsToClauseIndexesAndCounts)[g])[h]->size();
00210 j++)
00211 {
00212 delete (*(*(*gndingsToClauseIndexesAndCounts)[g])[h])[j];
00213 }
00214 delete (*(*gndingsToClauseIndexesAndCounts)[g])[h];
00215 }
00216 delete (*gndingsToClauseIndexesAndCounts)[g];
00217 }
00218 delete gndingsToClauseIndexesAndCounts;
00219 }
00220 }
00221 delete gndPredClauseIndexesAndCounts;
00222 }
00223 delete gndPredClauseIndexesAndCountsArr_;
00224
00225 numGndings_->deleteItemsAndClear();
00226 delete numGndings_;
00227
00228 if (sampledGndingsMaps_)
00229 {
00230 for (int i = 0; i < sampledGndingsMaps_->size(); i++)
00231 {
00232 (*sampledGndingsMaps_)[i]->deleteItemsAndClear();
00233 delete (*sampledGndingsMaps_)[i];
00234 }
00235 delete sampledGndingsMaps_;
00236 }
00237
00238 if (random_) delete random_;
00239 }
00240
00241
00242 void compress()
00243 {
00244 for (int i = 0; i < gndPredClauseIndexesAndCountsArr_->size(); i++)
00245 {
00246 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00247 gndPredClauseIndexesAndCounts
00248 = (*gndPredClauseIndexesAndCountsArr_)[i];
00249
00250 int numPreds = gndPredClauseIndexesAndCounts->size();
00251 for (int p = 0; p < numPreds; p++)
00252 {
00253
00254 if ((*gndPredClauseIndexesAndCounts)[p])
00255 {
00256 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00257 = (*gndPredClauseIndexesAndCounts)[p];
00258 int numGnds = gndingsToClauseIndexesAndCounts->size();
00259 for (int g = 0; g < numGnds; g++)
00260 {
00261 Array<Array<IndexAndCount*>*>* combosToClauseIndexesAndCounts
00262 = (*gndingsToClauseIndexesAndCounts)[g];
00263 int numCombos = combosToClauseIndexesAndCounts->size();
00264 for (int c = 0; c < numCombos; c++)
00265 (*combosToClauseIndexesAndCounts)[c]->compress();
00266 }
00267 }
00268 }
00269 }
00270
00271
00272 }
00273
00274
00275
00276 void insertCounts(int* const & clauseIdxInMLN,
00277 Array<UndoInfo*>* const & undoInfos,
00278 Array<Array<Array<CacheCount*>*>*>* const & cache,
00279 const int& d)
00280 {
00281 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00282 gndPredClauseIndexesAndCounts;
00283
00284 Array<IndexAndCount*>* gArr;
00285 CacheCount* cc;
00286 assert(cache->size() == domains_->size());
00287
00288 gndPredClauseIndexesAndCounts = (*gndPredClauseIndexesAndCountsArr_)[d];
00289 for (int p = 0; p < (*cache)[d]->size(); p++)
00290 {
00291 Array<CacheCount*>* ccArr = (*(*cache)[d])[p];
00292 if (ccArr == NULL) continue;
00293 for (int i = 0; i < ccArr->size(); i++)
00294 {
00295 assert((*gndPredClauseIndexesAndCounts)[p] != NULL);
00296 cc = (*ccArr)[i];
00297 gArr = (*(*(*gndPredClauseIndexesAndCounts)[p])[cc->g])[cc->c];
00298
00299 assert(gArr->size()==0 ||*(gArr->lastItem()->index)!=*clauseIdxInMLN);
00300 assert(cc->cnt != 0);
00301 gArr->append(new IndexAndCount(clauseIdxInMLN, cc->cnt));
00302 if (undoInfos) undoInfos->append(new UndoInfo(gArr, NULL, -1, d));
00303 }
00304 }
00305 }
00306
00307
00308
00309 void insertCounts(const Array<int*>& clauseIdxInMLNs,
00310 Array<UndoInfo*>* const & undoInfos,
00311 Array<Array<Array<CacheCount*>*>*>* const & cache)
00312 {
00313 assert(cache->size() == domains_->size());
00314 assert(clauseIdxInMLNs.size() == domains_->size());
00315
00316 for (int d = 0; d < cache->size(); d++)
00317 insertCounts(clauseIdxInMLNs[d], undoInfos, cache, d);
00318 }
00319
00320
00321
00322
00323
00324 void computeCountsForNewAppendedClause(const Clause* const & c,
00325 int* const & clauseIdxInMLN,
00326 const int& domainIdx,
00327 Array<UndoInfo*>* const & undoInfos,
00328 const bool& sampleClauses,
00329 Array<Array<Array<CacheCount*>*>*>* const & cache)
00330 {
00331 computeCountsRemoveCountsHelper(true, c, clauseIdxInMLN, domainIdx,
00332 undoInfos, sampleClauses, cache);
00333 }
00334
00335
00336 void removeCountsForClause(const Clause* const & c,
00337 int* const & clauseIdxInMLN, const int& domainIdx,
00338 Array<UndoInfo*>* const & undoInfos)
00339 {
00340 computeCountsRemoveCountsHelper(false, c, clauseIdxInMLN, domainIdx,
00341 undoInfos, false, NULL);
00342 }
00343
00344
00345
00346 void undoAppendRemoveCounts(const Array<UndoInfo*>* const & undoInfos)
00347 {
00348 for (int i = undoInfos->size() - 1; i >= 0; i--)
00349 {
00350 if ((*undoInfos)[i]->remIacIdx >= 0)
00351 {
00352 Array<IndexAndCount*>* affectedArr = (*undoInfos)[i]->affectedArr;
00353 IndexAndCount* remIac = (*undoInfos)[i]->remIac;
00354 (*undoInfos)[i]->remIac = NULL;
00355 int remIacIdx = (*undoInfos)[i]->remIacIdx;
00356
00357 if (affectedArr->size() == remIacIdx)
00358 affectedArr->append(remIac);
00359 else
00360 {
00361 assert(remIacIdx < affectedArr->size());
00362 IndexAndCount* tmpRemIac = (*affectedArr)[remIacIdx];
00363 (*affectedArr)[remIacIdx] = remIac;
00364 affectedArr->append(tmpRemIac);
00365 }
00366 }
00367 else
00368 {
00369 IndexAndCount* iac = (*undoInfos)[i]->affectedArr->removeLastItem();
00370 delete iac;
00371 }
00372
00373 assert(noRepeatedIndex((*undoInfos)[i]->affectedArr));
00374 delete (*undoInfos)[i];
00375 }
00376 }
00377
00378
00379 double getValueAndGradient(double* const & gradient, const double* const & wt,
00380 const int& arrSize)
00381 {
00382 double wpll = 0;
00383 memset(gradient, 0, arrSize*sizeof(double));
00384
00385
00386 if (idxTrans_ == NULL)
00387 {
00388 for (int i = 0; i < domains_->size(); i++)
00389 wpll += getValueAndGradientForDomain(gradient, wt, arrSize, i);
00390 }
00391 else
00392 {
00393 Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00394 Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00395 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00396 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00397
00398 for (int i = 0; i < domains_->size(); i++)
00399 {
00400 Array<double>& wts = (*wtsPerDomain)[i];
00401 Array<double>& grads = (*gradsPerDomain)[i];
00402 assert(grads.size() == wts.size());
00403 memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00404 memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00405
00406
00407
00408
00409
00410
00411
00412 for (int j = 0; j < wts.size(); j++)
00413 {
00414 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00415 for (int k = 0; k < idxDivs->size(); k++)
00416 if ((*idxDivs)[k].idx < arrSize)
00417 wts[j] += wt[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00418 }
00419
00420 wpll += getValueAndGradientForDomain((double*)grads.getItems(),
00421 (double*)wts.getItems(),
00422 wts.size(), i);
00423
00424
00425 for (int j = 0; j < grads.size(); j++)
00426 {
00427 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00428 for (int k = 0; k < idxDivs->size(); k++)
00429 if ((*idxDivs)[k].idx < arrSize)
00430 gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00431 }
00432 }
00433 }
00434
00435
00436
00437
00438 if (numMeans_ > 0)
00439 {
00440
00441
00442
00443
00444
00445 for (int i = 0; i < arrSize; i++)
00446 {
00447
00448
00449 wpll += (wt[i]-priorMeans_[i])*(wt[i]-priorMeans_[i])/
00450 (2*priorStdDevs_[i]*priorStdDevs_[i]);
00451
00452 gradient[i] += (wt[i]-priorMeans_[i])/
00453 (priorStdDevs_[i]*priorStdDevs_[i]);
00454 }
00455 }
00456
00457
00458
00459 return wpll;
00460 }
00461
00462
00463
00464
00465 void setMeansStdDevs(const int& arrSize, const double* const & priorMeans,
00466 const double* const & priorStdDevs)
00467 {
00468 numMeans_ = arrSize;
00469 priorMeans_ = priorMeans;
00470 priorStdDevs_ = priorStdDevs;
00471 }
00472
00473
00474 void setSampleGndPreds(const bool& sgp)
00475 {
00476 if (sgp) { assert(sampledGndingsMaps_); assert(random_); }
00477 sampleGndPreds_ = sgp;
00478 }
00479
00480
00481 bool checkNoRepeatedIndex(const MLN* const & mln=NULL)
00482 {
00483 bool ret = true;
00484 for (int d = 0; d < domains_->size(); d++)
00485 {
00486 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00487 gndPredClauseIndexesAndCounts
00488 = (*gndPredClauseIndexesAndCountsArr_)[d];
00489
00490 int numPreds = gndPredClauseIndexesAndCounts->size();
00491 for (int p = 0; p < numPreds; p++)
00492 {
00493 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00494 = (*gndPredClauseIndexesAndCounts)[p];
00495
00496 if (gndingsToClauseIndexesAndCounts == NULL) continue;
00497
00498 int numGnds = gndingsToClauseIndexesAndCounts->size();
00499 for (int g = 0; g < numGnds; g++)
00500 {
00501 Array<Array<IndexAndCount*>*>* gndings
00502 = (*gndingsToClauseIndexesAndCounts)[g];
00503
00504 for (int c = 0; c < gndings->size(); c++)
00505 {
00506 bool ok = noRepeatedIndex((*gndings)[c], mln);
00507 if (!ok)
00508 {
00509 cout << "ERROR: repeated index in domain " << d << " for pred "
00510 << (*domains_)[0]->getPredicateName(p) << " gnding " << g
00511 << " combination " << c << endl;
00512 ret = false;
00513 }
00514 }
00515 }
00516 }
00517 }
00518 return ret;
00519 }
00520
00521
00522 void printGndPredClauseIndexesAndCounts(const MLN* const & mln=NULL)
00523 {
00524 for (int d = 0; d < domains_->size(); d++)
00525 {
00526 cout << "domainIdx: " << d << endl;
00527 cout << "gndPredClauseIndexesAndCounts[predIdx][gndingIdx][combIdx][i]"
00528 << endl;
00529
00530 Array<Array<Array<Array<IndexAndCount*>*>*>*>*
00531 gndPredClauseIndexesAndCounts
00532 = (*gndPredClauseIndexesAndCountsArr_)[d];
00533
00534 int numPreds = gndPredClauseIndexesAndCounts->size();
00535 for (int p = 0; p < numPreds; p++)
00536 {
00537 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00538 = (*gndPredClauseIndexesAndCounts)[p];
00539
00540 if (gndingsToClauseIndexesAndCounts == NULL)
00541 {
00542 cout << "gndPredClauseIndexesAndCounts[" << p << "] = NULL" << endl;
00543 continue;
00544 }
00545
00546 int numGnds = gndingsToClauseIndexesAndCounts->size();
00547 for (int g = 0; g < numGnds; g++)
00548 {
00549 Array<Array<IndexAndCount*>*>* gndings
00550 = (*gndingsToClauseIndexesAndCounts)[g];
00551
00552 for (int c = 0; c < gndings->size(); c++)
00553 {
00554 Array<IndexAndCount*>* combos
00555 = (*gndings)[c];
00556 int numClauseIdx = combos->size();
00557
00558 if (numClauseIdx == 0)
00559 {
00560 cout << "gndPredClauseIndexesAndCounts[" << p << "][" << g
00561 << "][" << c << "] = empty" << endl;
00562 continue;
00563 }
00564
00565 for (int i = 0; i < numClauseIdx; i++)
00566 {
00567 cout << "gndPredClauseIndexesAndCounts[" << p << "][" << g << "]["
00568 << c << "][" << i << "] (clauseIndex,count) = "
00569 << *((*combos)[i]->index)
00570 << ", " << (*combos)[i]->count
00571 << ", " << (*combos)[i] << endl;
00572
00573 if (mln)
00574 {
00575 cout << " \t";
00576 mln->getClause(*((*combos)[i]->index)) ->print(cout,
00577 (*domains_)[0]);
00578 cout << endl;
00579 }
00580 }
00581 }
00582
00583 }
00584 }
00585 }
00586 }
00587
00588
00589 void setIndexTranslator(IndexTranslator* const & it) { idxTrans_ = it; }
00590
00591
00592 IndexTranslator* getIndexTranslator() const { return idxTrans_; }
00593
00594
00595 private:
00596 void createNumGndings()
00597 {
00598 numGndings_ = new Array<Array<double>*>;
00599 numGndings_->growToSize(domains_->size());
00600 for (int i = 0; i < domains_->size(); i++)
00601 {
00602 (*numGndings_)[i] = new Array<double>;
00603 (*numGndings_)[i]->growToSize((*domains_)[i]->getNumPredicates(), -1);
00604 }
00605
00606 for (int i = 0; i < domains_->size(); i++)
00607 {
00608 const Domain* domain = (*domains_)[i];
00609 for (int j = 0; j < domain->getNumPredicates(); j++)
00610 {
00611 if (!(*areNonEvidPreds_)[j]) continue;
00612 const PredicateTemplate* pt = domain->getPredicateTemplate(j);
00613
00614 double numGndings = 1;
00615 for (int t = 0; t < pt->getNumTerms(); t++)
00616 {
00617 int typeId = pt->getTermTypeAsInt(t);
00618 numGndings *= domain->getNumConstantsByType(typeId);
00619 }
00620 (*((*numGndings_)[i]))[j] = numGndings;
00621 }
00622 }
00623 }
00624
00625
00626 void createAllPossiblePredsGroundings(const Predicate* const & pred,
00627 const Domain* const & domain,
00628 ArraysAccessor<int>& acc)
00629 {
00630 for (int i = 0; i < pred->getNumTerms(); i++)
00631 {
00632 int typeId = pred->getTermTypeAsInt(i);
00633 const Array<int>* constArr = domain->getConstantsByType(typeId);
00634 acc.appendArray(constArr);
00635 }
00636 }
00637
00638
00639 bool isSampledGndPred(const int& g, const SampledGndings* const & sg)
00640 {
00641 if (sg->falseGndings.contains(g)) return true;
00642 if (sg->trueGndings.contains(g)) return true;
00643 return false;
00644 }
00645
00646
// Shared worker behind computeCountsForNewAppendedClause (computeCounts ==
// true) and removeCountsForClause (computeCounts == false).
// For clause c in domain domainIdx, it visits the groundings of every
// non-evidence predicate occurring in c. For each visited grounding index g it
// calls either:
//  - computeAndSetCounts (store c's counts for that grounding), or
//  - removeCounts (drop c's entries for that grounding).
// When sampleGndPreds_ is set, only groundings accepted by isSampledGndPred
// are visited.
//   clauseIdxInMLN: pointer to c's index in the MLN; stored in or matched
//                   against the IndexAndCount entries.
//   undoInfos, sampleClauses, cache: forwarded to the two workers
//                   (removeCountsForClause passes NULL for cache).
void computeCountsRemoveCountsHelper(bool computeCounts,
                                     const Clause* const & c,
                                     int* const & clauseIdxInMLN,
                                     const int& domainIdx,
                                     Array<UndoInfo*>* const & undoInfos,
                                     const bool& sampleClauses,
                                     Array<Array<Array<CacheCount*>*>*>* const & cache)
{

  const Domain* domain = (*domains_)[domainIdx];
  Database* db = domain->getDB();

    // predIdxWithAllTermsDiffVars[predId]: position in c of the first
    // occurrence of predicate predId whose terms are all distinct variables,
    // or -1 when there is none. memset with -1 sets every byte to 0xFF,
    // i.e. every int to -1.
  Array<int> predIdxWithAllTermsDiffVars(domain->getNumPredicates());
  predIdxWithAllTermsDiffVars.growToSize(domain->getNumPredicates());
  int* parr = (int*) predIdxWithAllTermsDiffVars.getItems();
  memset(parr, -1, domain->getNumPredicates()*sizeof(int));

    // predAllTermsAreDiffVars[p]: whether the p-th predicate of c has only
    // distinct variables as terms.
  Array<bool> predAllTermsAreDiffVars(c->getNumPredicates());
  createPredAllTermsAreDiffVars(c, predAllTermsAreDiffVars,
                                predIdxWithAllTermsDiffVars);

    // Canonicalized copies of the predicates already handled, so identical
    // occurrences in c are processed once. The copies are owned here and
    // freed at the end.
  PredicateHashArray seenPreds;


  for (int p = 0; p < c->getNumPredicates(); p++)
  {
    Predicate* pred = c->getPredicate(p);
    int predId = pred->getId();
    if (!(*areNonEvidPreds_)[predId]) continue;

      // Work on a canonicalized copy. Its terms are overwritten with
      // constants below to form each grounding.
    Predicate gndPred(*pred);
    gndPred.canonicalize();
    bool predIsInitiallyGnded = gndPred.isGrounded();

    SampledGndings* sg = NULL;
    if (sampleGndPreds_) sg = (*(*sampledGndingsMaps_)[domainIdx])[predId];

    Predicate* seenPred = new Predicate(gndPred);
      // A negative return from append() is treated as "already seen".
    if (seenPreds.append(seenPred) < 0) { delete seenPred; continue; }

    if (predAllTermsAreDiffVars[p])
    {
        // All terms are distinct variables, so every combination of constants
        // is a grounding. g counts the combinations in enumeration order.
        // NOTE(review): this is assumed to equal the multiplier-based index
        // computed in the else-branch; confirm ArraysAccessor varies the last
        // array fastest.
      ArraysAccessor<int> acc;
      createAllPossiblePredsGroundings(&gndPred, domain, acc);


      int g = -1;
      while (acc.hasNextCombination())
      {
        ++g;

          // bind gndPred's terms to this combination's constants
        int t = 0; int constId;
        while (acc.nextItemInCombination(constId))
          ((Term*) gndPred.getTerm(t++))->setId(constId);

        if (sampleGndPreds_ && !isSampledGndPred(g,sg)) continue;

        if (computeCounts)
        {
          computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
                              domainIdx, undoInfos, sampleClauses, cache);
        }
        else
          removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
      }
    }
    else
    {
        // Another occurrence of this predicate with all-distinct-variable
        // terms already enumerates all of its groundings.
      if (predIdxWithAllTermsDiffVars[predId] >= 0) continue;

        // multipliers[j]: factor applied to term j's constant index when
        // forming the grounding index.
      Array<int> multipliers(gndPred.getNumTerms());
      createMultipliers(multipliers, gndPred, domain);

        // part of the grounding index contributed by terms that are
        // already constants
      int offsetDueToConstants = 0;

        // varIdToMults[-varId]: (multiplier, term) pairs for every position a
        // variable occupies. negVarIdsArr: those keys, in the order in which
        // the variables' constant arrays were appended to groundings.
      Array<Array<pair<int,Term*> >* > varIdToMults;
      Array<int> negVarIdsArr;
      ArraysAccessor<int> groundings;
      createMappingOfVarIdToMultipliersAndVarGroundingsAndOffset(
        gndPred, domain, multipliers, offsetDueToConstants, varIdToMults,
        negVarIdsArr, groundings);


      if (!predIsInitiallyGnded)
      {
          // For each combination of constants for the variables, bind every
          // occurrence of each variable and compute
          // g = offset + sum(constIdx * multiplier).
        int constId, constIdx;
        while (groundings.hasNextCombination())
        {
          int g = offsetDueToConstants;
          int j = -1;
          while (groundings.nextItemInCombination(constId, constIdx))
          {
            ++j;
            int negVarId = negVarIdsArr[j];
            Array<pair<int,Term*> >* multsAndTerms = varIdToMults[negVarId];
            for (int m = 0; m < multsAndTerms->size(); m++)
            {
              g += constIdx * (*multsAndTerms)[m].first;
              (*multsAndTerms)[m].second->setId(constId);
            }
          }

          if (sampleGndPreds_ && !isSampledGndPred(g,sg)) continue;

          if (computeCounts)
            computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
                                domainIdx, undoInfos, sampleClauses, cache);
          else
            removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
        }
      }
      else
      {
          // already fully ground: the single grounding sits at the constant offset
        int g = offsetDueToConstants;

        bool ok = true;
        if (sampleGndPreds_) ok = isSampledGndPred(g,sg);

        if (ok)
        {
          if (computeCounts)
            computeAndSetCounts(c, clauseIdxInMLN, predId, gndPred, g, db,
                                domainIdx, undoInfos, sampleClauses, cache);
          else
            removeCounts(clauseIdxInMLN, predId, g, domainIdx, undoInfos);
        }
      }

      for (int j = 0; j < varIdToMults.size(); j++) delete varIdToMults[j];
    }
  }

  for (int i = 0; i < seenPreds.size(); i++) delete seenPreds[i];


}
00802
00803
00804 void computePerPredPllAndGrad(const Array<Array<Array<IndexAndCount*>*>*>*
00805 const& gndingsToClauseIndexesAndCounts,
00806 const int& g, const double* const & wt,
00807 long double& perPredPll,
00808 long double * const & perPredGrad)
00809 {
00810
00811
00812
00813
00814
00815
00816
00817
00818
00819
00820
00821
00822
00823
00824
00825
00826
00827
00828
00829
00830
00831
00832 long double wdotn = 0;
00833
00834 long double pmb = 1;
00835
00836 Array<Array<IndexAndCount*>*>* gndings =
00837 (*gndingsToClauseIndexesAndCounts)[g];
00838
00839
00840 for (int c = 0; c < gndings->size(); c++)
00841 {
00842 Array<IndexAndCount*>* clauseIndexesAndCounts
00843 = (*gndings)[c];
00844 assert(noRepeatedIndex(clauseIndexesAndCounts));
00845
00846 int numClausesUnifyWith = clauseIndexesAndCounts->size();
00847
00848
00849 for (int i = 0; i < numClausesUnifyWith; i++)
00850 {
00851
00852
00853
00854 wdotn += wt[ *( (*clauseIndexesAndCounts)[i]->index ) ] *
00855 (*clauseIndexesAndCounts)[i]->count;
00856 }
00857
00858 pmb += expl(wdotn);
00859 }
00860
00861 perPredPll -= logl(pmb);
00862
00863
00864 for (int c = 0; c < gndings->size(); c++)
00865 {
00866 Array<IndexAndCount*>* clauseIndexesAndCounts
00867 = (*gndings)[c];
00868
00869 for (int i = 0; i < clauseIndexesAndCounts->size(); i++)
00870 perPredGrad[ *( (*clauseIndexesAndCounts)[i]->index ) ]
00871 += ( (1.0/pmb-1) * (*clauseIndexesAndCounts)[i]->count );
00872 }
00873 }
00874
00875 void computeSampledPerPredPllAndGrad(IntHashArray& gndings,
00876 const int& totalGndings,
00877 long double& tmpPerPredPll,
00878 long double* const & tmpPerPredGrad,
00879 long double& perPredPll,
00880 long double* const & perPredGrad,
00881 const int& arrSize,
00882 const double* const & wt,
00883 const Array<Array<Array<IndexAndCount*>*>*>*
00884 const& gndingsToClauseIndexesAndCounts)
00885 {
00886 tmpPerPredPll = 0;
00887 memset(tmpPerPredGrad, 0, arrSize*sizeof(long double));
00888 for (int i = 0; i < gndings.size(); i++)
00889 computePerPredPllAndGrad(gndingsToClauseIndexesAndCounts,
00890 gndings[i], wt, tmpPerPredPll,
00891 tmpPerPredGrad);
00892 if (gndings.size() > 0)
00893 {
00894 perPredPll += totalGndings * tmpPerPredPll/gndings.size();
00895
00896 for (int i = 0; i < arrSize; i++)
00897 perPredGrad[i] += totalGndings * tmpPerPredGrad[i]/gndings.size();
00898 }
00899
00900 }
00901
00902
00903
00904 double getValueAndGradientForDomain(double* const & gradient,
00905 const double* const & wt,
00906 const int& arrSize, const int& domainIdx)
00907 {
00908 long double wpll = 0;
00909 long double* perPredGrad = new long double[arrSize];
00910
00911
00912 long double tmpPerPredPll;
00913 long double* tmpPerPredGrad = NULL;
00914 if (sampleGndPreds_) tmpPerPredGrad = new long double[arrSize];
00915
00916 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
00917 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
00918
00919 int numPreds = gndPredClauseIndexesAndCounts->size();
00920 for (int p = 0; p < numPreds; p++)
00921 {
00922 if (!(*areNonEvidPreds_)[p]) continue;
00923
00924
00925
00926
00927
00928 long double perPredPll = 0;
00929 memset(perPredGrad, 0, arrSize*sizeof(long double));
00930
00931
00932 if ((*gndPredClauseIndexesAndCounts)[p] != NULL)
00933 {
00934 Array<Array<Array<IndexAndCount*>*>*>* gndingsToClauseIndexesAndCounts
00935 = (*gndPredClauseIndexesAndCounts)[p];
00936
00937 if (sampleGndPreds_)
00938 {
00939 SampledGndings* sg = (*(*sampledGndingsMaps_)[domainIdx])[p];
00940 computeSampledPerPredPllAndGrad(sg->trueGndings, sg->totalTrue,
00941 tmpPerPredPll, tmpPerPredGrad,
00942 perPredPll, perPredGrad, arrSize,
00943 wt, gndingsToClauseIndexesAndCounts);
00944 computeSampledPerPredPllAndGrad(sg->falseGndings, sg->totalFalse,
00945 tmpPerPredPll, tmpPerPredGrad,
00946 perPredPll, perPredGrad, arrSize,
00947 wt, gndingsToClauseIndexesAndCounts);
00948 }
00949 else
00950 {
00951 int numGnds = gndingsToClauseIndexesAndCounts->size();
00952 assert(numGnds == (*((*numGndings_)[domainIdx]))[p]);
00953 for (int g = 0; g < numGnds; g++)
00954 computePerPredPllAndGrad(gndingsToClauseIndexesAndCounts, g, wt,
00955 perPredPll, perPredGrad);
00956 }
00957 }
00958 else
00959 {
00960 perPredPll = (*((*numGndings_)[domainIdx]))[p] * -log(2);
00961
00962 }
00963
00964 if (wtFOPred_)
00965 {
00966
00967 wpll -= perPredPll / (*((*numGndings_)[domainIdx]))[p];
00968 for (int i = 0; i < arrSize; i++)
00969 gradient[i] -= perPredGrad[i]/(*((*numGndings_)[domainIdx]))[p];
00970 }
00971 else
00972 {
00973
00974 wpll -= perPredPll;
00975 for (int i = 0; i < arrSize; i++)
00976 gradient[i] -= perPredGrad[i];
00977 }
00978 }
00979
00980 delete [] perPredGrad;
00981 if (sampleGndPreds_) delete [] tmpPerPredGrad;
00982
00983 return wpll;
00984 }
00985
00986 void createClauseIndexesAndCountsArrays(const int& predId,
00987 const int& domainIdx)
00988 {
00989 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
00990 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
00991 if ((*gndPredClauseIndexesAndCounts)[predId] != NULL) return;
00992
00993 Array<Array<Array<IndexAndCount*>*>*>* arr =
00994 new Array<Array<Array<IndexAndCount*>*>*>;
00995 double numGndings = (*((*numGndings_)[domainIdx]))[predId];
00996
00997
00998
00999 for (int g = 0; g < numGndings; g++)
01000 arr->append(new Array<Array<IndexAndCount*>*>);
01001 arr->compress();
01002 (*gndPredClauseIndexesAndCounts)[predId] = arr;
01003 }
01004
01005 void createComboClauseIndexesAndCountsArrays(const int& predId,
01006 const int& domainIdx,
01007 Predicate* const & gndPred,
01008 const int& g)
01009 {
01010 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01011 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01012
01013 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01014 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01015 if (comboClauseIndexesAndCounts->size() > 0) return;
01016
01017
01018 int numCombInBlock = 1;
01019
01020 int blockIdx = (*domains_)[domainIdx]->getBlock(gndPred);
01021 if (blockIdx >= 0)
01022 {
01023 const Array<Predicate*>* block =
01024 (*domains_)[domainIdx]->getPredBlock(blockIdx);
01025 numCombInBlock = block->size() - 1;
01026 }
01027
01028 comboClauseIndexesAndCounts->growToSize(numCombInBlock, NULL);
01029 for (int c = 0; c < numCombInBlock; c++)
01030 {
01031 (*comboClauseIndexesAndCounts)[c] = new Array<IndexAndCount*>;
01032 }
01033 comboClauseIndexesAndCounts->compress();
01034 }
01035
01036
01037
01038
01039
01040
01041
01042
01043
01044 bool computeAndSetCounts(const Clause* const clause,
01045 int* const & clauseIdxInMLN,
01046 const int& predId, Predicate& gndPred,
01047 const int& g, Database* const & db,
01048 const int& domainIdx,
01049 Array<UndoInfo*>* const & undoInfos,
01050 const bool& sampleClauses,
01051 Array<Array<Array<CacheCount*>*>*>* const & cache)
01052 {
01053 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01054 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01055 const Domain* domain = (*domains_)[domainIdx];
01056
01057 if ((*gndPredClauseIndexesAndCounts)[predId] == NULL)
01058 createClauseIndexesAndCountsArrays(predId, domainIdx);
01059
01060 createComboClauseIndexesAndCountsArrays(predId, domainIdx, &gndPred, g);
01061
01062 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01063 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01064
01065
01066
01067
01068
01069
01070
01071
01072
01073
01074
01075
01076 for (int c = 0; c < comboClauseIndexesAndCounts->size(); c++)
01077 {
01078
01079
01080 Array<IndexAndCount*>* gArr = (*comboClauseIndexesAndCounts)[c];
01081 if (gArr->size() > 0 && *( gArr->lastItem()->index ) == *clauseIdxInMLN)
01082 {
01083
01084
01085 continue;
01086 }
01087
01088 double cnt =
01089 ((Clause*)clause)->countDiffNumTrueGroundings(&gndPred, domain, db,
01090 DB_HAS_UNKNOWN_PREDS,
01091 sampleClauses, c);
01092
01093
01094 if (cnt != 0)
01095 {
01096
01097 gArr->append(new IndexAndCount(clauseIdxInMLN, cnt));
01098 if (undoInfos)
01099 undoInfos->append(new UndoInfo(gArr, NULL, -1, domainIdx));
01100
01101 if (cache)
01102 {
01103 Array<CacheCount*>*& ccArr = (*(*cache)[domainIdx])[predId];
01104 if (ccArr == NULL) ccArr = new Array<CacheCount*>;
01105 ccArr->append(new CacheCount(g, c, cnt));
01106 }
01107 }
01108
01109 assert(noRepeatedIndex(gArr));
01110 }
01111
01112 return true;
01113 }
01114
01115
01116
01117 bool removeCounts(int* const & clauseIdxInMLN, const int& predId,
01118 const int& g, const int& domainIdx,
01119 Array<UndoInfo*>* const & undoInfos)
01120 {
01121 bool removed = false;
01122 Array<Array<Array<Array<IndexAndCount*>*>*>*>* gndPredClauseIndexesAndCounts
01123 = (*gndPredClauseIndexesAndCountsArr_)[domainIdx];
01124
01125 if ((*gndPredClauseIndexesAndCounts)[predId] == NULL) return false;
01126
01127 Array<Array<IndexAndCount*>*>* comboClauseIndexesAndCounts
01128 = (*(*gndPredClauseIndexesAndCounts)[predId])[g];
01129
01130 for (int c = 0; c < comboClauseIndexesAndCounts->size(); c++)
01131 {
01132 Array<IndexAndCount*>* gArr =(*comboClauseIndexesAndCounts)[c];
01133 for (int i = 0; i < gArr->size(); i++)
01134 {
01135 if ((*gArr)[i]->index == clauseIdxInMLN)
01136 {
01137 IndexAndCount* ic = gArr->removeItemFastDisorder(i);
01138
01139 if (undoInfos)
01140 {
01141 undoInfos->append(new UndoInfo(gArr, ic, i, domainIdx));
01142 assert(noRepeatedIndex(gArr));
01143
01144 removed = true;
01145 }
01146 else
01147 {
01148 delete ic;
01149 assert(noRepeatedIndex(gArr));
01150
01151 removed = true;
01152 }
01153 }
01154 }
01155 assert(noRepeatedIndex(gArr));
01156 }
01157
01158
01159 return removed;
01160 }
01161
01162
01163 void createPredAllTermsAreDiffVars(const Clause* const & c,
01164 Array<bool>& predAllTermsAreDiffVars,
01165 Array<int>& predIdxWithAllTermsDiffVars)
01166 {
01167 for (int p = 0; p < c->getNumPredicates(); p++)
01168 {
01169 Predicate* pred = c->getPredicate(p);
01170 bool allDiffVars;
01171
01172 if (c->isDirty())
01173 {
01174
01175 allDiffVars = pred->checkAllTermsAreDiffVars();
01176 }
01177 else
01178 {
01179 assert(!pred->isDirty());
01180 allDiffVars = pred->allTermsAreDiffVars();
01181 }
01182
01183 predAllTermsAreDiffVars.append(allDiffVars);
01184 if (allDiffVars)
01185 {
01186 int predId = pred->getId();
01187 if (predIdxWithAllTermsDiffVars[predId] < 0)
01188 predIdxWithAllTermsDiffVars[predId] = p;
01189 }
01190 }
01191 predAllTermsAreDiffVars.compress();
01192 }
01193
01194
01195 void createMultipliers(Array<int>& multipliers,
01196 const Predicate& gndPred,
01197 const Domain* const & domain)
01198 {
01199 int mult = 1;
01200 int numTerms = gndPred.getNumTerms();
01201 multipliers.growToSize(numTerms);
01202 for (int j = numTerms-1; j >= 0; j--)
01203 {
01204 multipliers[j] = mult;
01205 int typeId = gndPred.getTermTypeAsInt(j);
01206 mult *= domain->getNumConstantsByType(typeId);
01207 }
01208 }
01209
01210
01211 void createMappingOfVarIdToMultipliersAndVarGroundingsAndOffset(
01212 const Predicate& gndPred,
01213 const Domain* const & domain,
01214 Array<int>& multipliers,
01215 int& offsetDueToConstants,
01216 Array<Array<pair<int,Term*> >* >& varIdToMults,
01217 Array<int>& negVarIdsArr,
01218 ArraysAccessor<int>& groundings)
01219 {
01220 for (int j = 0; j < gndPred.getNumTerms(); j++)
01221 {
01222 const Term* t = gndPred.getTerm(j);
01223 if (t->getType() == Term::VARIABLE)
01224 {
01225 assert(t->getId()<0);
01226 int id = -(t->getId());
01227 if (id >= varIdToMults.size()) varIdToMults.growToSize(id+1,NULL);
01228 if (varIdToMults[id] == NULL)
01229 {
01230 negVarIdsArr.append(id);
01231 varIdToMults[id] = new Array<pair<int,Term*> >;
01232 int typeId = gndPred.getTermTypeAsInt(j);
01233 const Array<int>* constants = domain->getConstantsByType(typeId);
01234 groundings.appendArray(constants);
01235 }
01236 varIdToMults[id]->append(pair<int,Term*>(multipliers[j], (Term*)t));
01237
01238 }
01239 else
01240 if (t->getType() == Term::CONSTANT)
01241 {
01242 int id = t->getId();
01243 assert(id >= 0);
01244 int typeId = gndPred.getTermTypeAsInt(j);
01245 const Array<int>* constants = domain->getConstantsByType(typeId);
01246 assert(constants->size() > 0);
01247
01248
01249 int firstConstId = (*constants)[0];
01250 assert(id >= firstConstId);
01251 offsetDueToConstants += (id-firstConstId)*multipliers[j];
01252 }
01253 else
01254 {
01255 assert(false);
01256 }
01257 }
01258 }
01259
01260
01261 void randomlySelect(IntHashArray& gndings, const double& fraction,
01262 const int& min, const int& max)
01263 {
01264 int size = int(fraction * gndings.size() + 0.5);
01265 if (min >= 0 && size < min) size = min;
01266 else if (max >= 0 && size > max) size = max;
01267 while (gndings.size() > size)
01268 gndings.removeItemFastDisorder(random_->randomOneOf(gndings.size()));
01269 }
01270
01271
01272 void samplePredGroundingsForDomain(const Predicate* const& foPred,
01273 const Domain* const & domain,
01274 SampledGndings* sampledGndings,
01275 const double& fraction,
01276 const int& min, const int& max)
01277 {
01278 cout << "sampling predicate "; foPred->printWithStrVar(cout, domain);
01279 cout << endl;
01280
01281 assert(((Predicate*)foPred)->allTermsAreDiffVars());
01282 ArraysAccessor<int> acc;
01283 createAllPossiblePredsGroundings(foPred, domain, acc);
01284 Predicate* gndPred = (Predicate*) foPred;
01285 const Database* db = domain->getDB();
01286
01287 IntHashArray& trueGndings = sampledGndings->trueGndings;
01288 IntHashArray& falseGndings = sampledGndings->falseGndings;
01289
01290 int g = -1;
01291 while (acc.hasNextCombination())
01292 {
01293 ++g;
01294 int t = 0;
01295 int constId;
01296 while (acc.nextItemInCombination(constId))
01297 ((Term*) gndPred->getTerm(t++))->setId(constId);
01298
01299 TruthValue tv = db->getValue(gndPred);
01300 if (tv == TRUE) trueGndings.append(g);
01301 else if (tv == FALSE) falseGndings.append(g);
01302 }
01303
01304 sampledGndings->totalTrue = trueGndings.size();
01305 sampledGndings->totalFalse = falseGndings.size();
01306 randomlySelect(trueGndings, fraction, min, max);
01307 randomlySelect(falseGndings, fraction, min, max);
01308 trueGndings.compress();
01309 falseGndings.compress();
01310
01311 cout << "\tsampled/total (true ground atoms) = "
01312 << trueGndings.size() << "/" << sampledGndings->totalTrue << endl;
01313 cout << "\tsampled/total (false ground atoms) = "
01314 << falseGndings.size() << "/" << sampledGndings->totalFalse << endl;
01315 }
01316
01317
01318 void samplePredGroundings(const double& fraction,
01319 const int& min, const int& max)
01320 {
01321 for (int d = 0; d < domains_->size(); d++)
01322 {
01323 cout << "domain " << d << endl;
01324 const Domain* domain = (*domains_)[d];
01325 Array<SampledGndings*>* sgm = (*sampledGndingsMaps_)[d];
01326 for (int p = 0; p < domain->getNumPredicates(); p++)
01327 {
01328 if (!(*areNonEvidPreds_)[p]) continue;
01329 SampledGndings* sg = new SampledGndings;
01330 assert((*sgm)[p] == NULL);
01331 (*sgm)[p] = sg;
01332 Predicate* foPred = domain->createPredicate(p, true);
01333 assert(foPred);
01334 samplePredGroundingsForDomain(foPred, domain, sg, fraction, min, max);
01335 delete foPred;
01336 }
01337 }
01338 }
01339
01340
01341 bool noRepeatedIndex(const Array<IndexAndCount*>* const & gArr,
01342 const MLN* const & mln0=NULL)
01343 {
01344 hash_set<int> set;
01345 for (int i = 0; i < gArr->size(); i++)
01346 {
01347 int ii = *((*gArr)[i]->index);
01348 if (set.find(ii) != set.end())
01349 {
01350 cout << "ERROR: in PseudoLogLikelihood::noRepeatedIndex. "
01351 << "Repeated index " << ii << " found. ";
01352 if (mln0)
01353 mln0->getClause(ii)->printWithoutWtWithStrVar(cout, (*domains_)[0]);
01354 cout << endl;
01355 return false;
01356 }
01357 set.insert(ii);
01358 }
01359 return true;
01360 }
01361
01362
// Prints diagnostics to cout with 10 significant digits:
//  - a "wts = " header line, then the first arrSize weights on one line as
//    " i:value" pairs;
//  - the same for the gradients under "grads = ";
//  - "wpll = <wpll>" followed by a blank line.
// wpll is presumably the weighted pseudo-log-likelihood (TODO confirm).
//   wts, grads - arrays holding at least arrSize values each
// The stream's previous precision is restored on exit.  The original forced
// it to 6, which clobbered any precision the caller had set.
void printWtsGradsWPLL(const double* const & wts, const double* const & grads,
                       const int& arrSize, const double& wpll)
{
  const std::streamsize savedPrecision = std::cout.precision(10);

  std::cout << "wts = " << std::endl;
  for (int i = 0; i < arrSize; i++) std::cout << " " << i << ":" << wts[i];
  std::cout << std::endl;

  std::cout << "grads = " << std::endl;
  for (int i = 0; i < arrSize; i++) std::cout << " " << i << ":" << grads[i];
  std::cout << std::endl;

  std::cout << "wpll = " << wpll << std::endl;
  std::cout << std::endl;

  std::cout.precision(savedPrecision);
}
01377
01378
private:
// Indexed by predicate id: samplePredGroundings samples predicate p only when
// (*areNonEvidPreds_)[p] is true.
Array<bool>* areNonEvidPreds_;

// The domains this object operates on, one entry per domain index.
// Entry 0 is also used to print clauses in noRepeatedIndex.
Array<Domain*>* domains_;

// One entry per domain index.  Within a domain it is indexed as
// [predId][g][c], yielding an Array<IndexAndCount*>*.  The per-predicate
// entry may be NULL.  (g and c are presumably a grounding index and a
// clause/combination index — TODO confirm where this structure is built.)
Array<Array<Array<Array<Array<IndexAndCount*>*>*>*>*>*
gndPredClauseIndexesAndCountsArr_;

// Prior-related state; the constructor initializes these to -1 / NULL / NULL.
// Presumably per-weight prior means and standard deviations, with numMeans_
// their length — TODO confirm.
int numMeans_;
const double* priorMeans_;
const double* priorStdDevs_;

// wtFOPred_ is set from the constructor's wtFOPred argument.  NOTE(review):
// presumably it enables per-first-order-predicate weighting, with
// numGndings_ holding the counts that weighting uses — confirm.
bool wtFOPred_;
Array<Array<double>*>* numGndings_;

// Presumably mirrors the constructor's sampleGndPreds argument (whether
// ground predicates are sampled) — TODO confirm.
bool sampleGndPreds_;

// Indexed as [domainIdx][predId].  Filled by samplePredGroundings with one
// SampledGndings per predicate flagged in areNonEvidPreds_.
Array<Array<SampledGndings*>*>* sampledGndingsMaps_;
// Random source used by randomlySelect (via randomOneOf).
Random* random_;

// Usage not visible in this chunk — TODO document.
IndexTranslator* idxTrans_;
01410 };
01411
01412
01413 #endif