00001 #ifndef FEATURE_H_NOV_08
00002 #define FEATURE_H_NOV_08
00003
// Compile-time configuration switches for the feature classes below.

// Sharpness constant for the SMOOTH_MAX2 softplus approximation
// log(1 + exp(KVAL * w)) / KVAL used when computing log Z.
#define KVAL 100
// When non-zero, RecursiveFeature divides each child's summed contribution
// by the number of child groundings (an average instead of a sum).
#define NORM 0
// NOTE(review): not referenced anywhere in this header; purpose unknown.
#define REC_HACK 0

// When non-zero, normalization is disabled: RecursiveFeature forces
// normalize_ to false, and ClausalFeature's getLogZ()/getNorm() return 0.
#define NO_NORMALIZE 1
// When non-zero, recursive/clausal feature values are sigmoid(weighted sum
// of children) instead of exp(weighted sum - log Z).
#define SIGMOID 0
// Smooth replacements for max(0, w) in log Z and its derivative:
// SMOOTH_MAX uses w * sigmoid(w); SMOOTH_MAX2 uses a softplus with sharpness
// KVAL. With both at 0, the hard max(0, w) is used.
#define SMOOTH_MAX 0
#define SMOOTH_MAX2 0
00012
00013 #include <stdlib.h>
// Returns a pseudo-random double in [0, 1] (both ends inclusive), drawn
// from the C library rand() generator.
inline double frand() {
    const double draw = static_cast<double>(rand());
    return draw / RAND_MAX;
}
00017
00018 #include "predicate.h"
00019 #include "predicatetemplate.h"
00020 #include "database.h"
00021 #include <math.h>
00022 #include "parentiter2.h"
00023 using namespace __gnu_cxx;
00024
00025
00026 class RRF;
00027 class GroundRRF;
00028 class GroundFeature;
00029 class PredicateGroundFeature;
00030 class ClausalGroundFeature;
00031 class ConstantGroundFeature;
00032 class RecursiveGroundFeature;
00033
// Logistic function 1 / (1 + e^-x). Arguments beyond +/-100 are clamped to
// exactly 1 or 0 so exp() is never evaluated at extreme magnitudes.
inline double sigmoid(double x)
{
    if (x > 100.0) {
        return 1;
    }
    if (x < -100.0) {
        return 0;
    }
    const double expNeg = exp(-x);
    return 1.0 / (1.0 + expNeg);
}
00044
00045
00046
00047
00048
00049
00050 class Feature
00051 {
00052 public:
00053 Feature(const char* name = NULL) : name_(NULL)
00054 {
00055 if (name != NULL) {
00056 setName(name);
00057 }
00058 }
00059
00060 virtual ~Feature() {
00061 delete [] name_;
00062 }
00063
00064 const char* getName() const {
00065 return name_;
00066 }
00067
00068 void setName(const char* name) {
00069 delete [] name_;
00070 name_ = new char[strlen(name)+1];
00071 strcpy(name_, name);
00072 }
00073
00074 void setId(int id) { id_ = id; }
00075 int getId() { return id_; }
00076
00077
00078
00079
00080 virtual double getValue(const Array<int>& grounding, Database* db) {
00081 #if 1
00082 return getCachedValue(grounding, db);
00083 #else
00084 return computeValue(grounding, db);
00085 #endif
00086 }
00087
00088 virtual double getLogValue(const Array<int>& grounding, Database* db) {
00089 #if 1
00090 return getCachedLogValue(grounding, db);
00091 #else
00092 return computeLogValue(grounding, db);
00093 #endif
00094 }
00095
00096 double getCachedValue(const Array<int>& grounding, Database* db) {
00097 int index = getGroundingIndex(grounding, db);
00098
00099
00100 while (cacheValid_.size() <= index) {
00101 cacheValid_.append(false);
00102 cacheValue_.append(1.0);
00103 }
00104
00105 if (!cacheValid_[index]) {
00106 cacheValue_[index] = computeValue(grounding, db);
00107 cacheValid_[index] = true;
00108 }
00109
00110 return cacheValue_[index];
00111 }
00112
00113
00114 double getCachedLogValue(const Array<int>& grounding, Database* db) {
00115 int index = getGroundingIndex(grounding, db);
00116
00117
00118 while (cacheLogValid_.size() <= index) {
00119 cacheLogValid_.append(false);
00120 cacheLogValue_.append(1.0);
00121 }
00122
00123 if (!cacheLogValid_[index]) {
00124 cacheLogValue_[index] = computeLogValue(grounding, db);
00125 cacheLogValid_[index] = true;
00126 }
00127
00128 return cacheLogValue_[index];
00129 }
00130
00131
00132
00133 double getCount(int w) {
00134 if (w < counts_.size()) {
00135 return counts_[w];
00136 } else {
00137 return 0.0;
00138 }
00139 }
00140
00141 void setCount(int w, double val) {
00142 while (w >= counts_.size()) {
00143 counts_.append(0.0);
00144 }
00145 counts_[w] = val;
00146 }
00147
00148
00149 void invalidateAll() {
00150 for (int i = 0; i < cacheValid_.size(); i++) {
00151 cacheValid_[i] = false;
00152 }
00153 for (int i = 0; i < cacheLogValid_.size(); i++) {
00154 cacheLogValid_[i] = false;
00155 }
00156 }
00157
00158 virtual void invalidate(const Array<int>& fgrounding, Database* db) {
00159
00160
00161
00162 int index = getGroundingIndex(fgrounding, db);
00163
00164 bool wasValid = (cacheValid_.size() > index && cacheValid_[index])
00165 || (cacheLogValid_.size() > index && cacheLogValid_[index]);
00166
00167 if (cacheValid_.size() > index) {
00168 cacheValid_[index] = false;
00169 }
00170 if (cacheLogValid_.size() > index) {
00171 cacheLogValid_[index] = false;
00172 }
00173
00174 if (wasValid) {
00175 for (int i = 0; i < parents_.size(); i++) {
00176 parents_[i]->invalidateChild(id_, fgrounding, db);
00177 }
00178 }
00179 }
00180
00181 virtual void invalidateChild(int feature, const Array<int>& grounding,
00182 Database* db) {
00183
00184
00185 assert(false);
00186 return;
00187 }
00188
00189 virtual double computeValue(const Array<int>& grounding, Database* db) = 0;
00190 virtual double computeLogValue(const Array<int>& grounding, Database* db) = 0;
00191
00192
00193
00194
00195
00196
00197 virtual double getPartialDeriv(int featureIndex, int weightIndex,
00198 const Array<int>& grounding, Database* db) {
00199
00200 return computePartialDeriv(featureIndex, weightIndex, grounding, db);
00201 }
00202
00203
00204 int getNumTerms() const { return termTypes_.size(); }
00205 void addTermType(int type) { termTypes_.append(type); }
00206 void setTermType(int idx, int type) { termTypes_[idx] = type; }
00207 int getTermType(int idx) const { return termTypes_[idx]; }
00208 Array<int> getTermTypes() const { return termTypes_; }
00209
00210
00211 virtual int getNumWeights() const { return 0; }
00212 virtual double getWeight(int idx) { assert(false); return 0.0; }
00213 virtual void setWeight(int idx, double weight) { assert(false); }
00214
00215
00216
00217
00218
00219
00220
00221 int getGroundingIndex(const Array<int>& grounding, const Database* db)
00222 const {
00223
00224 const Domain* domain = db->getDomain();
00225
00226 int index = 0;
00227 for (int i = 0; i < getNumTerms(); i++) {
00228 index *= domain->getNumConstantsByType(getTermType(i));
00229 index += grounding[i];
00230 }
00231
00232 return index;
00233 }
00234
00235
00236 virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
00237 const Array<int>& grounding, Database* db)
00238 { return NULL; }
00239
00240 virtual void print(ostream& out) const
00241 {
00242
00243 if (name_ == NULL) {
00244 out << "f" << id_ << "(";
00245 } else {
00246 out << name_ << "(";
00247 }
00248
00249
00250 for (int i = 0; i < getNumTerms(); i++) {
00251 if (i > 0) {
00252 out << ",";
00253 }
00254 out << (char)('a' + i);
00255 out << ':' << getTermType(i);
00256 }
00257 out << ")";
00258 }
00259
00260 void addParent(Feature* parent) {
00261 parents_.append(parent);
00262 }
00263
00264
00265 protected:
00266 virtual double computePartialDeriv(int featureIndex, int weightIndex,
00267 const Array<int>& grounding, Database* db)
00268 { return 0; }
00269
00270
00271 protected:
00272
00273 Array<int> termTypes_;
00274
00275
00276 int id_;
00277
00278
00279 char* name_;
00280
00281 Array<Feature*> parents_;
00282 Array<bool> cacheValid_;
00283 Array<double> cacheValue_;
00284 Array<bool> cacheLogValid_;
00285 Array<double> cacheLogValue_;
00286
00287
00288
00289 Array<double> counts_;
00290 };
00291
00292
00293
00294
00295
00296
00297 class PredicateFeature : public Feature
00298 {
00299 public:
00300 PredicateFeature(const PredicateTemplate* predTemplate)
00301 : pred_(predTemplate)
00302 {
00303
00304 for (int i = 0; i < predTemplate->getNumTerms(); i++) {
00305 pred_.appendTerm(new Term(-1));
00306 addTermType(predTemplate->getTermTypeAsInt(i));
00307 }
00308
00309 setName(pred_.getName());
00310 }
00311
00312 virtual double computeValue(const Array<int>& grounding, Database* db)
00313 {
00314 assert(grounding.size() == pred_.getNumTerms());
00315
00316
00317
00318 for (int i = 0; i < grounding.size(); i++) {
00319 pred_.setTermToConstant(i, grounding[i]);
00320 }
00321
00322 return db->getValue(&pred_);
00323 }
00324
00325 virtual double computeLogValue(const Array<int>& grounding, Database* db)
00326 {
00327
00328 assert(false);
00329 return 0.0;
00330 }
00331
00332 virtual double getPartialDeriv(int fi, int wi,
00333 const Array<int>& grounding, Database* db)
00334 { return 0; }
00335
00336 virtual void print(ostream& out) const {
00337 #if 0
00338 Feature::print(out);
00339 out << " = " << pred_.getName() << "(";
00340
00341 for (int i = 0; i < getNumTerms(); i++) {
00342 if (i > 0) {
00343 out << ",";
00344 }
00345 out << (char)('a' + i);
00346 }
00347 out << ")";
00348 #endif
00349 }
00350
00351
00352
00353 Predicate* getPredicate() const { return &pred_; }
00354
00355 inline virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
00356 const Array<int>& grounding, Database* db);
00357
00358 private:
00359
00360 mutable Predicate pred_;
00361 };
00362
00363
00364
00365
00366
00367 class ConstantFeature : public Feature
00368 {
00369 public:
00370 ConstantFeature(double value=1)
00371 : value_(value)
00372 { setName(""); }
00373
00374
00375 virtual double getValue(const Array<int>& grounding, Database* db)
00376 { return value_; }
00377
00378 virtual double computeValue(const Array<int>& grounding, Database* db)
00379 { return value_; }
00380
00381 virtual double computeLogValue(const Array<int>& grounding, Database* db)
00382 { return log(value_); }
00383
00384
00385 virtual double getPartialDeriv(int fi, int wi,
00386 const Array<int>& grounding, Database* db)
00387 { return 0; }
00388
00389 virtual inline GroundFeature* constructGroundFeature(GroundRRF* rrf,
00390 const Array<int>& grounding, Database* db);
00391
00392 virtual void print(ostream& out) const {
00393
00394
00395
00396
00397
00398 }
00399
00400 private:
00401 double value_;
00402 };
00403
00404
00405
00406
00407
00408 class RecursiveFeature : public Feature
00409 {
00410 public:
00411 RecursiveFeature(const char* name, bool logDerivs = false,
00412 bool normalize = true)
00413 : doDerivsOfLog_(logDerivs),
00414 #if NO_NORMALIZE
00415 normalize_(false),
00416 #else
00417 normalize_(normalize),
00418 #endif
00419 cachedZinvalid_(true), cachedValuesInvalid_(true)
00420 {
00421 setName(name);
00422 }
00423
00424
00425
00426
00427 virtual int getNumWeights() const { return getNumChildren(); }
00428 virtual double getWeight(int idx) {
00429 assert(idx >= 0 && idx < getNumWeights());
00430 return weights_[idx];
00431 }
00432
00433 virtual void setWeight(int idx, double weight) {
00434 assert(idx >= 0 && idx < getNumWeights());
00435 cachedZinvalid_ = true;
00436 cachedValuesInvalid_ = true;
00437 weights_[idx] = weight;
00438 }
00439
00440 void addChild(Feature* feature, double weight, const Array<int>& map) {
00441 children_.append(feature);
00442 weights_.append(weight);
00443 termMap_.append(map);
00444 feature->addParent(this);
00445 }
00446
00447 Feature* getChild(int i) {
00448 return children_[i];
00449 }
00450
00451 int getNumChildren() const { return children_.size(); }
00452
00453
00454 virtual void invalidateChild(int feature, const Array<int>& fgrounding,
00455 Database* db) {
00456
00457
00458 if (getNumTerms() == 0) {
00459 Array<int> nullGrounding;
00460 invalidate(nullGrounding, db);
00461 return;
00462 }
00463
00464 const Domain* domain = db->getDomain();
00465
00466
00467
00468 Array<Array<int> > uniqVals;
00469 for (int i = 0; i < fgrounding.size(); i++) {
00470 int type = getTermType(i);
00471 while (uniqVals.size() <= type) {
00472 uniqVals.append(Array<int>());
00473 }
00474 if (!uniqVals[type].contains(fgrounding[i])) {
00475 uniqVals[type].append(fgrounding[i]);
00476 }
00477 }
00478
00479
00480 ParentIter2 iter(uniqVals, domain->getConstantsByType(),
00481 getTermTypes());
00482
00483 Array<int> grounding;
00484
00485
00486 while (iter.hasNextGrounding()) {
00487 iter.getNextGrounding(grounding);
00488 invalidate(grounding, db);
00489 }
00490 }
00491
00492
00493 virtual double computePartialDeriv(int fi, int wi,
00494 const Array<int>& grounding, Database* db)
00495 {
00496
00497 double totalPartial = 0.0;
00498
00499 if (numGroundings_.size() == 0) {
00500 cacheNumGroundings(db);
00501 }
00502
00503
00504 if (fi == id_) {
00505
00506
00507
00508 totalPartial = childGroundSum(wi, grounding, db);
00509
00510 if (normalize_) {
00511 totalPartial -= getNorm(wi);
00512 }
00513
00514
00515 } else {
00516
00517
00518
00519 for (int i = 0; i < getNumChildren(); i++)
00520 {
00521
00522 ArraysAccessor<int>* groundingIter =
00523 getChildGroundingIter(i, grounding, db);
00524
00525 Array<int> childGrounding;
00526 int j = 0;
00527 do {
00528 groundingIter->getNextCombination(childGrounding);
00529 double childValue = children_[i]->getPartialDeriv(fi, wi,
00530 childGrounding, db);
00531 #if 0
00532
00533 cout << "RRF Child " << i << "," << j << "," << fi
00534 << "," << wi << ": " << childValue << endl;
00535 #endif
00536 totalPartial += weights_[i] * childValue;
00537 j++;
00538 } while (groundingIter->hasNextCombination());
00539
00540 #if NORM
00541
00542 totalPartial /= j;
00543 #endif
00544
00545 releaseChildGroundingIter(i, groundingIter);
00546 }
00547 }
00548
00549
00550
00551 if (doDerivsOfLog_) {
00552 return totalPartial;
00553 } else {
00554 #if SIGMOID
00555 double val = getValue(grounding, db);
00556 return val * (1.0 - val) * totalPartial;
00557 #else
00558 return getValue(grounding, db) * totalPartial;
00559 #endif
00560 }
00561 }
00562
00563 virtual double computeValue(const Array<int>& grounding, Database* db)
00564 {
00565 #if SIGMOID
00566 double totalValue = 0.0;
00567
00568 if (numGroundings_.size() == 0) {
00569 cacheNumGroundings(db);
00570 }
00571
00572 for (int i = 0; i < getNumChildren(); i++) {
00573 totalValue += weights_[i] * childGroundSum(i, grounding, db);
00574 }
00575
00576 return sigmoid(totalValue);
00577 #else
00578 return exp(computeLogValue(grounding, db));
00579 #endif
00580 }
00581
00582 virtual double computeLogValue(const Array<int>& grounding, Database* db)
00583 {
00584 double totalValue = 0.0;
00585
00586 if (numGroundings_.size() == 0) {
00587 cacheNumGroundings(db);
00588 }
00589
00590 for (int i = 0; i < getNumChildren(); i++) {
00591 totalValue += weights_[i] * childGroundSum(i, grounding, db);
00592 }
00593
00594 return totalValue - getLogZ();
00595 }
00596
00597 virtual inline GroundFeature* constructGroundFeature(GroundRRF* rrf,
00598 const Array<int>& grounding, Database* db);
00599
00600 virtual void print(ostream& out) const {
00601
00602 Feature::print(out);
00603 out << " = exp(";
00604
00605
00606
00607 char currFreeVar = 'a' + getNumTerms();
00608
00609
00610 for (int i = 0; i < getNumChildren(); i++) {
00611 if (i > 0) {
00612 out << " + ";
00613 }
00614
00615 assert(children_[i]->getName() != NULL);
00616 out << weights_[i] << " " << children_[i]->getName();
00617
00618 if (strlen(children_[i]->getName()) > 0) {
00619
00620 out << "(";
00621 for (int j = 0; j < termMap_[i].size(); j++) {
00622 if (j > 0) {
00623 out << ",";
00624 }
00625 if (termMap_[i][j] < 0) {
00626 out << currFreeVar;
00627
00628 } else {
00629 out << (char)('a' + termMap_[i][j]);
00630 }
00631 }
00632 out << ")";
00633 }
00634 }
00635 out << ")\n";
00636 }
00637
00638
00639
00640
00641
00642 double childGroundSum(int childIndex, const Array<int>& parentGrounding,
00643 Database* db)
00644 {
00645 double totalValue = 0.0;
00646 ArraysAccessor<int>* groundingIter =
00647 getChildGroundingIter(childIndex, parentGrounding, db);
00648
00649 #if NORM
00650
00651 int numGroundings = 0;
00652 #endif
00653 Array<int> childGrounding;
00654 do {
00655 groundingIter->getNextCombination(childGrounding);
00656 totalValue += children_[childIndex]->getValue(childGrounding, db);
00657 #if NORM
00658 numGroundings++;
00659 #endif
00660 } while (groundingIter->hasNextCombination());
00661
00662 releaseChildGroundingIter(childIndex, groundingIter);
00663 #if NORM
00664 return totalValue/numGroundings;
00665 #else
00666 return totalValue;
00667 #endif
00668 }
00669
00670
00671
00672
00673
00674 ArraysAccessor<int>* getChildGroundingIter(int childId,
00675 const Array<int>& grounding, Database* db)
00676 {
00677 const Domain* domain = db->getDomain();
00678 ArraysAccessor<int>* groundingIter = new ArraysAccessor<int>;
00679
00680
00681
00682
00683 for (int term = 0; term < children_[childId]->getNumTerms();
00684 term++) {
00685
00686
00687 if (termMap_[childId][term] < 0) {
00688
00689 int type = children_[childId]->getTermType(term);
00690 groundingIter->appendArray(domain->getConstantsByType(type));
00691 } else {
00692
00693 Array<int>* singleton = new Array<int>;
00694 singleton->append(grounding[termMap_[childId][term]]);
00695 groundingIter->appendArray(singleton);
00696 }
00697 }
00698
00699 return groundingIter;
00700 }
00701
00702
00703 void releaseChildGroundingIter(int childId,
00704 ArraysAccessor<int>* groundingIter)
00705 {
00706
00707 for (int j = 0; j < children_[childId]->getNumTerms(); j++) {
00708 if (termMap_[childId][j] >= 0) {
00709 delete groundingIter->getArray(j);
00710 }
00711 }
00712
00713 delete groundingIter;
00714 }
00715
00716 double getLogZ() const {
00717 #if SIGMOID
00718 return 0.0;
00719 #else
00720
00721 if (!normalize_) {
00722 return 0.0;
00723 }
00724
00725
00726 assert(numGroundings_.size() > 0);
00727
00728 if (cachedZinvalid_) {
00729 cachedLogZ_ = 0.0;
00730 for (int i = 0; i < getNumChildren(); i++) {
00731 #if SMOOTH_MAX
00732
00733
00734 double sigma_w = sigmoid(weights_[i]);
00735
00736 cachedLogZ_ += weights_[i] * sigma_w * numGroundings_[i];
00737 #elif SMOOTH_MAX2
00738 if (weights_[i] < -10.0) {
00739 cachedLogZ_ += 0.0;
00740 } else if (weights_[i] > 10.0) {
00741 cachedLogZ_ += weights_[i] * numGroundings_[i];
00742 } else {
00743 cachedLogZ_ += numGroundings_[i]
00744 * log(1.0 + exp(KVAL * weights_[i])) / KVAL;
00745 }
00746 #else
00747 if (weights_[i] > 0.0) {
00748 cachedLogZ_ += weights_[i] * numGroundings_[i];
00749 }
00750 #endif
00751 }
00752
00753 cachedZinvalid_ = false;
00754 }
00755 return cachedLogZ_;
00756 #endif
00757 }
00758
00759 virtual double getZ() const {
00760 return exp(getLogZ());
00761 }
00762
00763 virtual double getNorm(int idx) {
00764 #if SIGMOID
00765 return 0.0;
00766 #else
00767
00768 if (!normalize_) {
00769 return 0.0;
00770 }
00771
00772
00773 assert(numGroundings_.size() > 0);
00774
00775 if (cachedValuesInvalid_) {
00776 cachedNormalizers_.growToSize(getNumWeights());
00777 for (int i = 0; i < getNumWeights(); i++) {
00778
00779 #if SMOOTH_MAX
00780
00781
00782 double w_i = getWeight(i);
00783 double sigma_w = sigmoid(w_i);
00784
00785
00786 double norm = sigma_w * (1.0 + w_i * (1.0 - sigma_w))
00787 * numGroundings_[i];
00788 cachedNormalizers_[i] = norm;
00789 #elif SMOOTH_MAX2
00790 cachedNormalizers_[i] = sigmoid(getWeight(i))
00791 * numGroundings_[i];
00792 #else
00793 if (getWeight(i) < 0.0) {
00794 cachedNormalizers_[i] = 0.0;
00795 } else {
00796 cachedNormalizers_[i] = 1.0 * numGroundings_[i];
00797 }
00798 #endif
00799 }
00800 cachedValuesInvalid_ = false;
00801 }
00802 return cachedNormalizers_[idx];
00803 #endif
00804 }
00805
00806 void cacheNumGroundings(Database* db)
00807 {
00808 const Domain* domain = db->getDomain();
00809
00810 numGroundings_.clear();
00811 for (int i = 0; i < getNumChildren(); i++) {
00812 int numGroundings = 1;
00813 for (int term = 0; term < children_[i]->getNumTerms(); term++) {
00814 if (termMap_[i][term] < 0) {
00815 int type = children_[i]->getTermType(term);
00816 numGroundings *= domain->getNumConstantsByType(type);
00817 }
00818 }
00819 numGroundings_.append(numGroundings);
00820 }
00821 }
00822
00823
00824 protected:
00825 Array<Feature*> children_;
00826 Array<double> weights_;
00827 Array<ArraysAccessor<int>* > groundingIters_;
00828
00829
00830
00831
00832
00833
00834
00835
00836
00837 Array<Array<int> > termMap_;
00838
00839
00840
00841
00842 Array<int> numGroundings_;
00843
00844
00845
00846
00847 bool doDerivsOfLog_;
00848
00849 bool normalize_;
00850
00851 mutable bool cachedZinvalid_;
00852 mutable double cachedLogZ_;
00853 mutable bool cachedValuesInvalid_;
00854 mutable Array<double> cachedNormalizers_;
00855 };
00856
00857
00858
00859
00860
00861
00862
00863
00864
00865 class ClausalFeature : public RecursiveFeature
00866 {
00867 public:
00868
00869 ClausalFeature(const char* name = NULL)
00870 : RecursiveFeature(name), cachedZinvalid_(true),
00871 cachedValuesInvalid_(true)
00872 { }
00873
00874 virtual double computePartialDeriv(int fi, int wi,
00875 const Array<int>& grounding, Database* db)
00876 {
00877 if (fi != id_) {
00878 return 0.0;
00879 }
00880
00881
00882
00883
00884
00885 double totalValue =
00886 getChildValue(wi, grounding, db) - getNorm(wi);
00887
00888
00889 #if SIGMOID
00890 double val = getValue(grounding, db);
00891 return val * (1.0 - val) * totalValue;
00892 #else
00893 return getValue(grounding, db) * totalValue;
00894 #endif
00895 }
00896
00897 virtual double computeLogValue(const Array<int>& grounding, Database* db)
00898 {
00899 return computeLogValueRaw(grounding, db) - getLogZ();
00900 }
00901
00902 virtual double computeValue(const Array<int>& grounding, Database* db)
00903 {
00904 #if SIGMOID
00905 double total = 0.0;
00906 for (int childId = 0; childId < getNumChildren(); childId++) {
00907 total += weights_[childId]
00908 * getChildValue(childId, grounding, db);
00909 }
00910
00911 return sigmoid(total);
00912 #else
00913 double logValue = computeLogValue(grounding, db);
00914
00915 if (logValue < -100.0) {
00916 return 0.0;
00917 } else {
00918 return exp(logValue);
00919 }
00920 #endif
00921 }
00922
00923 virtual void setWeight(int idx, double weight) {
00924 RecursiveFeature::setWeight(idx, weight);
00925 cachedZinvalid_ = true;
00926 cachedValuesInvalid_ = true;
00927 }
00928
00929 virtual double getNorm(int idx) {
00930 #if SIGMOID || NO_NORMALIZE
00931 return 0.0;
00932 #else
00933 if (cachedValuesInvalid_) {
00934 cachedNormalizers_.growToSize(getNumWeights());
00935 for (int i = 0; i < getNumWeights(); i++) {
00936 #if 0
00937
00938 double norm = 1.0/(1.0 + exp(-getWeight(i)));
00939 #elif SMOOTH_MAX
00940
00941
00942 double w_i = getWeight(i);
00943 double norm;
00944 double sigma_w = sigmoid(w_i);
00945
00946
00947 norm = sigma_w * (1.0 + w_i * (1.0 - sigma_w));
00948 #elif SMOOTH_MAX2
00949
00950
00951 double w_i = getWeight(i);
00952 double norm;
00953 double sigma_w = sigmoid(w_i);
00954
00955 norm = sigma_w;
00956 #else
00957
00958 double norm;
00959 if (getWeight(i) < 0.0) {
00960 norm = 0.0;
00961 } else {
00962
00963 norm = 1.0;
00964 }
00965 #endif
00966 cachedNormalizers_[i] = norm;
00967 }
00968 cachedValuesInvalid_ = false;
00969 }
00970 return cachedNormalizers_[idx];
00971 #endif
00972 }
00973
00974 void setWeightsFromExample(const Array<int>& grounding, Database* db,
00975 const Array<int>& queryPreds, int numWeights)
00976 {
00977 if (numWeights < 0) {
00978 numWeights = getNumChildren();
00979 }
00980
00981
00982 bool nonZeroQueryPredicate = false;
00983 while (!nonZeroQueryPredicate) {
00984
00985 int numNonZero = 0;
00986 for (int i = 0; i < getNumChildren(); i++) {
00987
00988 if (frand() < ((double)numWeights - numNonZero)
00989 /((double)getNumChildren() - i)) {
00990
00991 if (getChildValue(i, grounding, db)) {
00992 setWeight(i, 1.0 + frand() );
00993 } else {
00994 setWeight(i, -1.0 - frand() );
00995 }
00996
00997 int featureId = children_[i]->getId();
00998 if (queryPreds.find(featureId) != -1) {
00999 nonZeroQueryPredicate = true;
01000 }
01001 numNonZero++;
01002 } else {
01003
01004 setWeight(i, 0.0);
01005 }
01006 }
01007 }
01008 }
01009
01010 virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
01011 const Array<int>& grounding, Database* db);
01012
01013 double getLogZ() const {
01014 #if SIGMOID || NO_NORMALIZE
01015 return 0.0;
01016 #else
01017 if (cachedZinvalid_) {
01018 cachedLogZ_ = 0.0;
01019 for (int i = 0; i < getNumChildren(); i++) {
01020 #if 0
01021
01022 if (weights_[i] > 100.0) {
01023
01024 cachedLogZ_ += weights_[i];
01025 } else if (weights_[i] > -100.0) {
01026
01027 cachedLogZ_ += log(1.0 + exp(weights_[i]));
01028 }
01029 #elif SMOOTH_MAX
01030
01031
01032 double sigma_w = sigmoid(weights_[i]);
01033
01034 cachedLogZ_ += weights_[i] * sigma_w;
01035 #elif SMOOTH_MAX2
01036 if (weights_[i] < -10.0) {
01037
01038 } else if (weights_[i] > 10.0) {
01039 cachedLogZ_ += weights_[i];
01040 } else {
01041 cachedLogZ_ += log(1.0 + exp(KVAL*weights_[i]))/KVAL;
01042 }
01043 #else
01044
01045 if (weights_[i] > 0.0) {
01046 cachedLogZ_ += weights_[i];
01047 }
01048 #endif
01049 }
01050 cachedZinvalid_ = false;
01051 }
01052 return cachedLogZ_;
01053 #endif
01054 }
01055
01056 virtual double getZ() const {
01057 return exp(getLogZ());
01058 }
01059
01060 protected:
01061
01062 double computeLogValueRaw(const Array<int>& grounding,
01063 Database* db) const {
01064
01065 double total = 0.0;
01066 for (int childId = 0; childId < getNumChildren(); childId++) {
01067 total += weights_[childId]
01068 * getChildValue(childId, grounding, db);
01069 }
01070 return total;
01071 }
01072
01073
01074 double getChildValue(int childId, const Array<int>& grounding,
01075 Database* db) const
01076 {
01077 #if 0
01078 Predicate* pred = ((PredicateFeature*)children_[childId])
01079 ->getPredicate();
01080 for (int i = 0; i < children_[childId]->getNumTerms(); i++) {
01081 pred->setTermToConstant(i, grounding[termMap_[childId][i]]);
01082 }
01083
01084
01085 return (int)db->getValue(pred);
01086 #else
01087 Array<int> cgrounding;
01088 for (int i = 0; i < children_[childId]->getNumTerms(); i++) {
01089 cgrounding.append(grounding[termMap_[childId][i]]);
01090 }
01091 return children_[childId]->getValue(cgrounding, db);
01092 #endif
01093 }
01094
01095 mutable bool cachedZinvalid_;
01096 mutable double cachedLogZ_;
01097 mutable bool cachedValuesInvalid_;
01098 mutable Array<double> cachedNormalizers_;
01099 };
01100
01101 #include "gfeature.h"
01102
01103 #endif