00001 #ifndef GFEATURE_H_APR_06
00002 #define GFEATURE_H_APR_06
00003
00004 #include "feature.h"
00005
00006 class GroundFeature
00007 {
00008 public:
00009 GroundFeature()
00010 : dirtyDeriv_(true), dirtyValue_(true), dirtyLogValue_(true)
00011 { }
00012
00013 virtual ~GroundFeature() { }
00014
00015 double getValue() {
00016
00017 if (dirtyValue_) {
00018 cachedValue_ = computeValue();
00019 dirtyValue_ = false;
00020 }
00021
00022 return cachedValue_;
00023 }
00024
00025 double getLogValue() {
00026
00027 if (dirtyLogValue_) {
00028 cachedLogValue_ = computeLogValue();
00029 dirtyLogValue_ = false;
00030 }
00031
00032 return cachedLogValue_;
00033 }
00034
00035 double getCounts(int featureIndex, int weightIndex) {
00036 return computePartialDeriv(featureIndex, weightIndex);
00037 }
00038
00039 virtual double computeValue() { return 0.0; }
00040 virtual double computeLogValue() { return 0.0; }
00041 virtual double computePartialDeriv(int featureIndex, int weightIndex)
00042 { return 0.0; }
00043
00044 inline double getDeriv();
00045
00046
00047
00048 virtual double getChildDeriv(int i) {
00049 assert(false);
00050 return 0.0;
00051 }
00052
00053
00054 inline void setDirty();
00055
00056 inline void updateParents(double oldValue, double newValue);
00057 inline void updateParentCounts(double oldValue, double newValue);
00058
00059 void addParent(RecursiveGroundFeature* parent, int parentIndex) {
00060 parents_.append(parent);
00061 parentIndices_.append(parentIndex);
00062 }
00063
00064
00065 protected:
00066 bool dirtyDeriv_;
00067 bool dirtyValue_;
00068 bool dirtyLogValue_;
00069 double cachedDeriv_;
00070 double cachedValue_;
00071 double cachedLogValue_;
00072 Array<RecursiveGroundFeature*> parents_;
00073 Array<int> parentIndices_;
00074 };
00075
00076
00077 class RecursiveGroundFeature : public GroundFeature
00078 {
00079 public:
00080 RecursiveGroundFeature(RecursiveFeature* featureTemplate,
00081 bool doDerivsOfLog=false)
00082 : featureTemplate_(featureTemplate), doDerivsOfLog_(doDerivsOfLog)
00083 { }
00084
00085 virtual ~RecursiveGroundFeature() { }
00086
00087 virtual double computeValue()
00088 {
00089 #if SIGMOID
00090 computeLogValue();
00091 return sigmoid(cachedSum_);
00092 #else
00093 return exp(computeLogValue());
00094 #endif
00095 }
00096
00097 virtual double computeLogValue()
00098 {
00099 cachedSum_ = 0.0;
00100 subSums_.growToSize(children_.size());
00101
00102 for (int i = 0; i < children_.size(); i++) {
00103 double currFeatureTotal = 0.0;
00104 for (int j = 0; j < children_[i].size(); j++) {
00105
00106
00107 currFeatureTotal += children_[i][j]->getValue();
00108 }
00109 #if NORM
00110 cachedSum_ += featureTemplate_->getWeight(i) * currFeatureTotal
00111 / children_[i].size();
00112 #else
00113 cachedSum_ += featureTemplate_->getWeight(i) * currFeatureTotal;
00114 #endif
00115 subSums_[i] = currFeatureTotal;
00116 }
00117
00118 RecursiveFeature* recFeature = (RecursiveFeature*)featureTemplate_;
00119
00120
00121
00122 return cachedSum_ - recFeature->getLogZ();
00123 }
00124
00125 virtual void update(int idx, double oldValue, double newValue)
00126 {
00127 dirtyDeriv_ = true;
00128
00129 if (!dirtyValue_ || !dirtyLogValue_) {
00130 double oldCachedValue = cachedValue_;
00131 #if NORM
00132 cachedSum_ += featureTemplate_->getWeight(idx)
00133 * (newValue - oldValue)
00134 /children_[idx].size();
00135 #else
00136 cachedSum_ += featureTemplate_->getWeight(idx)
00137 * (newValue - oldValue);
00138 #endif
00139 subSums_[idx] += newValue - oldValue;
00140
00141 RecursiveFeature* recFeature = (RecursiveFeature*)featureTemplate_;
00142 cachedLogValue_ = cachedSum_ - recFeature->getLogZ();
00143 #if SIGMOID
00144 cachedValue_ = sigmoid(cachedSum_);
00145 #else
00146 cachedValue_ = exp(cachedLogValue_);
00147 #endif
00148 updateParents(oldCachedValue, cachedValue_);
00149 }
00150 }
00151
00152 virtual void updateCounts(int idx, double oldValue, double newValue)
00153 {
00154 if ((dirtyValue_ && dirtyLogValue_) || dirtyDeriv_) {
00155 cout << "ERROR: cache must be up-to-date when calling "
00156 << "updateCounts!\n";
00157 return;
00158 }
00159
00160 double counts = featureTemplate_->getCount(idx);
00161
00162
00163
00164
00165 double totalPartial = 0.0;
00166 {
00167 #if 0
00168
00169 double totalValue = 0.0;
00170 for (int w = 0; w < children_.size(); w++) {
00171 double sum = 0.0;
00172 for (int j = 0; j < children_[w].size(); j++) {
00173 sum += children_[w][j]->getValue();
00174 }
00175 totalValue += sum * featureTemplate_->getWeight(w);
00176 }
00177 cout << "Old: " << cachedSum_ << endl;
00178 cout << "New: " << totalValue << endl;
00179 totalValue += (oldValue - newValue)
00180 * featureTemplate_->getWeight(idx);
00181 cout << "Computed: " << totalValue << endl;
00182
00183 #endif
00184
00185 RecursiveFeature* recTemplate
00186 = (RecursiveFeature*)featureTemplate_;
00187 totalPartial = subSums_[idx] - recTemplate->getNorm(idx);
00188
00189 if (!doDerivsOfLog_) {
00190 #if SIGMOID
00191 totalPartial *= getValue() * (1.0 - getValue());
00192 #else
00193 totalPartial *= getValue();
00194 #endif
00195 }
00196
00197 }
00198
00199
00200
00201
00202
00203
00204
00205 counts -= cachedDeriv_ * totalPartial;
00206
00207
00208
00209
00210
00211 double oldCachedValue = cachedValue_;
00212 cachedSum_ += featureTemplate_->getWeight(idx)
00213 * (newValue - oldValue);
00214 subSums_[idx] += newValue - oldValue;
00215
00216 RecursiveFeature* recFeature = (RecursiveFeature*)featureTemplate_;
00217 cachedLogValue_ = cachedSum_ - recFeature->getLogZ();
00218 #if SIGMOID
00219 cachedValue_ = sigmoid(cachedSum_);
00220 #else
00221 cachedValue_ = exp(cachedLogValue_);
00222 #endif
00223
00224
00225 updateParentCounts(oldCachedValue, cachedValue_);
00226
00227
00228 cachedDeriv_ = 0.0;
00229 for (int i = 0; i < parents_.size(); i++) {
00230 cachedDeriv_ += parents_[i]->getDeriv()
00231 * parents_[i]->getChildDeriv(parentIndices_[i]);
00232 }
00233
00234 if (parents_.size() == 0.0) {
00235 cachedDeriv_ = 1.0;
00236 }
00237
00238
00239 {
00240 RecursiveFeature* recTemplate
00241 = (RecursiveFeature*)featureTemplate_;
00242 totalPartial = subSums_[idx] - recTemplate->getNorm(idx);
00243
00244 if (!doDerivsOfLog_) {
00245 #if SIGMOID
00246 totalPartial *= getValue() * (1.0 - getValue());
00247 #else
00248 totalPartial *= getValue();
00249 #endif
00250 }
00251 }
00252
00253 counts += cachedDeriv_ * totalPartial;
00254 featureTemplate_->setCount(idx,counts);
00255 }
00256
00257
00258 virtual double getChildDeriv(int i)
00259 {
00260 double totalPartial = featureTemplate_->getWeight(i);
00261 #if NORM
00262 totalPartial /= children_[i].size();
00263 #endif
00264
00265 if (doDerivsOfLog_) {
00266 return totalPartial;
00267 } else {
00268 return getValue() * totalPartial;
00269 }
00270 }
00271
00272 virtual void addChild(int childIndex, GroundFeature* child)
00273 {
00274 while (children_.size() <= childIndex) {
00275 children_.append(Array<GroundFeature*>());
00276 }
00277
00278 children_[childIndex].append(child);
00279 child->addParent(this, childIndex);
00280 }
00281
00282 virtual double computePartialDeriv(int fi, int wi)
00283 {
00284 double totalPartial = 0.0;
00285
00286 if (fi == featureTemplate_->getId()) {
00287
00288 if (dirtyLogValue_ && dirtyValue_) {
00289 computeLogValue();
00290 }
00291
00292 RecursiveFeature* recTemplate
00293 = (RecursiveFeature*)featureTemplate_;
00294 totalPartial = subSums_[wi] - recTemplate->getNorm(wi);
00295 #if NORM
00296 totalPartial /= children_[wi].size();
00297 #endif
00298 } else {
00299
00300
00301 assert(false);
00302
00303
00304
00305 for (int i = 0; i < children_.size(); i++)
00306 {
00307 double featureTotal = 0.0;
00308 for (int j = 0; j < children_[i].size(); j++) {
00309 featureTotal += children_[i][j]->getCounts(fi, wi);
00310 }
00311 #if NORM
00312 totalPartial += featureTotal * featureTemplate_->getWeight(i)
00313 / children_[i].size();
00314 #else
00315 totalPartial += featureTotal * featureTemplate_->getWeight(i);
00316 #endif
00317 }
00318 }
00319
00320 if (doDerivsOfLog_) {
00321 return totalPartial;
00322 } else {
00323 #if SIGMOID
00324 return getValue() * (1.0 - getValue()) * totalPartial;
00325 #else
00326 return getValue() * totalPartial;
00327 #endif
00328 }
00329 }
00330
00331 protected:
00332 Array<Array<GroundFeature*> > children_;
00333 RecursiveFeature* featureTemplate_;
00334 bool doDerivsOfLog_;
00335 double cachedSum_;
00336 Array<double> subSums_;
00337 };
00338
00339
00340 class ClausalGroundFeature : public RecursiveGroundFeature
00341 {
00342 public:
00343 ClausalGroundFeature(ClausalFeature* featureTemplate)
00344 : RecursiveGroundFeature(featureTemplate)
00345 { }
00346
00347 virtual ~ClausalGroundFeature() { }
00348
00349 #if 0
00350 virtual double computeLogValue() {
00351
00352 assert(false);
00353 }
00354 #endif
00355
00356
00357
00358 #if 0 // TODO: add subSums_ to this, or something.
00359 virtual double computeValue() {
00360
00361 cachedSum_ = 0.0;
00362 for (int i = 0; i < children_.size(); i++) {
00363 cachedSum_ += featureTemplate_->getWeight(i)
00364 * children_[i][0]->getValue();
00365 }
00366 #if SIGMOID
00367 return sigmoid(cachedSum_);
00368 #else
00369 ClausalFeature* clausalTemplate = (ClausalFeature*)featureTemplate_;
00370 return exp(cachedSum_ - clausalTemplate->getLogZ());
00371 #endif
00372 }
00373 #endif
00374
00375 #if 0
00376 virtual double getChildDeriv(int i)
00377 {
00378
00379
00380
00381
00382 assert(false);
00383 double weight_i = featureTemplate_->getWeight(i);
00384 double ret = getValue() * (weight_i -
00385 - 1.0/(1.0 + exp(-children_[i][0]->getValue())));
00386 return ret;
00387 }
00388
00389 virtual double computePartialDeriv(int fi, int wi)
00390 {
00391 ClausalFeature* clausalTemplate = (ClausalFeature*)featureTemplate_;
00392
00393 if (fi != clausalTemplate->getId()) {
00394 return 0.0;
00395 }
00396
00397 assert(children_[wi].size() == 1);
00398 #if SIGMOID
00399 return getValue() * (1.0 - getValue())
00400 * children_[wi][0]->getValue();
00401 #else
00402 return getValue() * (children_[wi][0]->getValue() -
00403 clausalTemplate->getNorm(wi));
00404 #endif
00405 }
00406
00407 virtual void update(int idx, double oldValue, double newValue)
00408 {
00409 dirtyDeriv_ = true;
00410 if (!dirtyValue_) {
00411 double oldCachedValue = cachedValue_;
00412 #if 0
00413 double oldCachedSum = cachedSum_;
00414 #endif
00415 cachedSum_ += featureTemplate_->getWeight(idx)
00416 * (newValue - oldValue);
00417 #if SIGMOID
00418 cachedValue_ = sigmoid(cachedSum_);
00419 #else
00420 ClausalFeature* clausalTemplate =
00421 (ClausalFeature*) featureTemplate_;
00422 cachedValue_ = exp(cachedSum_ - clausalTemplate->getLogZ());
00423 #endif
00424 #if 0
00425 double ourCachedSum_ = cachedSum_;
00426 computeValue();
00427 if (fabs(ourCachedSum_ - cachedSum_) > 0.00001) {
00428 cout << "Sums are different!\n";
00429 cout << ourCachedSum_ << endl;
00430 cout << cachedSum_ << endl;
00431 cout << "Old: " << oldCachedSum << endl;
00432 cout << oldValue << " -> " << newValue << endl;
00433 cout << featureTemplate_->getWeight(idx) << endl;
00434 }
00435 #endif
00436 updateParents(oldCachedValue, cachedValue_);
00437 }
00438 }
00439 #endif
00440 };
00441
00442
00443 class PredicateGroundFeature : public GroundFeature
00444 {
00445 public:
00446 PredicateGroundFeature(const Predicate& pred)
00447 : groundPred_(pred), predValue_(false)
00448 { }
00449
00450 virtual ~PredicateGroundFeature() { }
00451
00452 Predicate* getPredicate() { return &groundPred_; }
00453 virtual double computeValue() { return predValue_ ? 1.0 : 0.0; }
00454 void setValue(bool val)
00455 {
00456 if (predValue_ != val) {
00457 predValue_ = val;
00458 setDirty();
00459 }
00460 }
00461
00462 void setValueAndUpdate(bool val)
00463 {
00464 if (predValue_ != val) {
00465 predValue_ = val;
00466
00467 if (val) {
00468 cachedValue_ = 1.0;
00469 updateParents(0.0, 1.0);
00470 } else {
00471 cachedValue_ = 0.0;
00472 updateParents(1.0, 0.0);
00473 }
00474 }
00475 }
00476
00477 void setValueAndUpdateCounts(bool val)
00478 {
00479 if (predValue_ != val) {
00480 predValue_ = val;
00481
00482 if (val) {
00483 cachedValue_ = 1.0;
00484 updateParentCounts(0.0, 1.0);
00485 } else {
00486 cachedValue_ = 0.0;
00487 updateParentCounts(1.0, 0.0);
00488 }
00489 }
00490 }
00491
00492 virtual double computePartialDeriv(int featureIndex, int weightIndex)
00493 { return 0; }
00494
00495 private:
00496 Predicate groundPred_;
00497 bool predValue_;
00498 };
00499
00500
00501 inline GroundFeature* PredicateFeature::constructGroundFeature(
00502 GroundRRF* rrf, const Array<int>& grounding, Database* db)
00503 {
00504 PredicateGroundFeature* ret = new PredicateGroundFeature(pred_);
00505 ret->setValue(computeValue(grounding, db));
00506 return ret;
00507 }
00508
00509 class ConstantGroundFeature : public GroundFeature
00510 {
00511 public:
00512 ConstantGroundFeature(double value)
00513 : val_(value)
00514 { }
00515
00516 virtual ~ConstantGroundFeature() { }
00517
00518 virtual double computeValue() { return val_; }
00519 virtual double computePartialDeriv(int featureIndex, int weightIndex)
00520 { return 0; }
00521
00522 protected:
00523 double val_;
00524 };
00525
00526 inline GroundFeature* ConstantFeature::constructGroundFeature(
00527 GroundRRF* rrf, const Array<int>& grounding, Database* db)
00528 {
00529 return new ConstantGroundFeature(value_);
00530 }
00531
00532 class GroundRRF
00533 {
00534 public:
00535 GroundRRF(RRF* rrf, Database* db);
00536
00537 double getValue() { return root_->getValue(); }
00538
00539 double getLogValue() { return root_->getLogValue(); }
00540
00541
00542
00543 double getLogPseudoLikelihood(const Array<int>& queryPreds);
00544 double getLogPseudoLikelihood(const Array<Predicate*>& queryPreds);
00545
00546 void getCounts(Array<double>& counts);
00547
00548 void getPseudoCounts(Array<double>& counts, const Array<int>& queryPreds,
00549 double samplingFrac);
00550 void getPseudoCountsFast(Array<double>& counts,
00551 const Array<int>& queryPreds, double samplingFrac);
00552
00553 int getNumPredicateGroundings(int predIdx) {
00554
00555
00556 return allFeatures_[predIdx-1].size();
00557 }
00558
00559 bool getPredicateValue(int predIdx, int groundIdx) {
00560
00561
00562 if (allFeatures_[predIdx-1][groundIdx] != NULL) {
00563 return (((PredicateGroundFeature*)allFeatures_[predIdx-1][groundIdx])
00564 ->getValue() != 0.0);
00565 } else {
00566 return false;
00567 }
00568 }
00569
00570 void setPredicateValue(int predIdx, int groundIdx, bool value) {
00571
00572
00573
00574 if (allFeatures_[predIdx-1][groundIdx] != NULL) {
00575 ((PredicateGroundFeature*)allFeatures_[predIdx-1][groundIdx])
00576 ->setValue(value);
00577 }
00578 }
00579
00580 void setPredicateAndUpdate(int predIdx, int groundIdx, bool value) {
00581
00582
00583 if (allFeatures_[predIdx-1][groundIdx] != NULL) {
00584 ((PredicateGroundFeature*)allFeatures_[predIdx-1][groundIdx])
00585 ->setValueAndUpdate(value);
00586
00587 }
00588 }
00589
00590 void setPredicateAndUpdateCounts(int predIdx, int groundIdx, bool value) {
00591
00592
00593 if (allFeatures_[predIdx-1][groundIdx] != NULL) {
00594 ((PredicateGroundFeature*)allFeatures_[predIdx-1][groundIdx])
00595 ->setValueAndUpdateCounts(value);
00596 } else {
00597
00598 cout << "Feature was NULL! Possible bug.\n";
00599 }
00600 }
00601
00602 GroundFeature* getGroundFeature(Feature* feature, Array<int>& grounding)
00603 {
00604 int featureIdx = feature->getId();
00605 int groundIdx = feature->getGroundingIndex(grounding, db_);
00606
00607 while (allFeatures_[featureIdx].size() <= groundIdx) {
00608 allFeatures_[featureIdx].append((GroundFeature*)NULL);
00609 }
00610
00611 if (allFeatures_[featureIdx][groundIdx] == NULL) {
00612 allFeatures_[featureIdx][groundIdx]
00613 = feature->constructGroundFeature(this, grounding, db_);
00614 }
00615
00616 return allFeatures_[featureIdx][groundIdx];
00617 }
00618
00619 int getNumGroundings(int featureId) const {
00620 return allFeatures_[featureId].size();
00621 }
00622
00623 void dirtyAll() {
00624 for (int i = 0; i < allFeatures_.size(); i++) {
00625 for (int j = 0; j < allFeatures_[i].size(); j++) {
00626 if (allFeatures_[i][j] != NULL) {
00627 allFeatures_[i][j]->setDirty();
00628 }
00629 }
00630 }
00631 }
00632
00633 private:
00634 GroundFeature* root_;
00635 RRF* rrf_;
00636 Database* db_;
00637 int numCounts_;
00638
00639 Array<Array<GroundFeature*> > allFeatures_;
00640 };
00641
00642
00643 inline GroundFeature* RecursiveFeature::constructGroundFeature(
00644 GroundRRF* rrf, const Array<int>& grounding, Database* db)
00645 {
00646 RecursiveGroundFeature* ret = new RecursiveGroundFeature(
00647 this, doDerivsOfLog_);
00648
00649 if (numGroundings_.size() == 0) {
00650 cacheNumGroundings(db);
00651 }
00652
00653
00654 for (int i = 0; i < getNumChildren(); i++) {
00655
00656 Array<int> childGrounding;
00657 ArraysAccessor<int>* groundingIter =
00658 getChildGroundingIter(i, grounding, db);
00659
00660 do {
00661 groundingIter->getNextCombination(childGrounding);
00662 ret->addChild(i, rrf->getGroundFeature(children_[i],
00663 childGrounding));
00664 } while (groundingIter->hasNextCombination());
00665
00666 releaseChildGroundingIter(i, groundingIter);
00667 }
00668
00669 return ret;
00670 }
00671
00672 inline GroundFeature* ClausalFeature::constructGroundFeature(
00673 GroundRRF* rrf, const Array<int>& grounding, Database* db)
00674 {
00675 ClausalGroundFeature* ret = new ClausalGroundFeature(this);
00676
00677
00678 for (int i = 0; i < getNumChildren(); i++) {
00679
00680 Array<int> childGrounding;
00681 ArraysAccessor<int>* groundingIter =
00682 getChildGroundingIter(i, grounding, db);
00683
00684 do {
00685 groundingIter->getNextCombination(childGrounding);
00686 ret->addChild(i, rrf->getGroundFeature(children_[i],
00687 childGrounding));
00688 } while (groundingIter->hasNextCombination());
00689
00690 releaseChildGroundingIter(i, groundingIter);
00691 }
00692
00693 return ret;
00694 }
00695
00696
00697 double GroundFeature::getDeriv()
00698 {
00699 if (dirtyDeriv_) {
00700 if (parents_.size() == 0) {
00701 cachedDeriv_ = 1.0;
00702 } else {
00703 cachedDeriv_ = 0.0;
00704 for (int i = 0; i < parents_.size(); i++) {
00705 cachedDeriv_ += parents_[i]->getDeriv()
00706 * parents_[i]->getChildDeriv(parentIndices_[i]);
00707 }
00708 }
00709 dirtyDeriv_ = false;
00710 }
00711
00712 return cachedDeriv_;
00713 }
00714
00715 void GroundFeature::setDirty()
00716 {
00717 if (!dirtyDeriv_ || !dirtyValue_ || !dirtyLogValue_) {
00718 dirtyDeriv_ = true;
00719 dirtyValue_ = true;
00720 dirtyLogValue_ = true;
00721 for (int i = 0; i < parents_.size(); i++) {
00722 parents_[i]->setDirty();
00723 }
00724 }
00725 }
00726
00727 void GroundFeature::updateParents(double oldValue, double newValue)
00728 {
00729 for (int i = 0; i < parents_.size(); i++) {
00730 parents_[i]->update(parentIndices_[i], oldValue, newValue);
00731 }
00732 }
00733
00734 void GroundFeature::updateParentCounts(double oldValue, double newValue)
00735 {
00736 for (int i = 0; i < parents_.size(); i++) {
00737 parents_[i]->updateCounts(parentIndices_[i], oldValue, newValue);
00738 }
00739 }
00740
00741 #endif // ndef GFEATURE_H_APR_06