// feature.h

00001 #ifndef FEATURE_H_NOV_08
00002 #define FEATURE_H_NOV_08
00003 
00004 #define KVAL 100
00005 #define NORM 0
00006 #define REC_HACK 0
00007 
00008 #define NO_NORMALIZE 1
00009 #define SIGMOID 0
00010 #define SMOOTH_MAX 0
00011 #define SMOOTH_MAX2 0
00012 
00013 #include <stdlib.h>
// Uniform pseudo-random double in [0.0, 1.0], derived from rand().
inline double frand() {
    return static_cast<double>(rand()) / RAND_MAX;
}
00017 
00018 #include "predicate.h"
00019 #include "predicatetemplate.h"
00020 #include "database.h"
00021 #include <math.h>
00022 #include "parentiter2.h"
00023 using namespace __gnu_cxx;
00024 
00025 // Forward declaration
00026 class RRF;
00027 class GroundRRF;
00028 class GroundFeature;
00029 class PredicateGroundFeature;
00030 class ClausalGroundFeature;
00031 class ConstantGroundFeature;
00032 class RecursiveGroundFeature;
00033 
// Numerically-guarded logistic function, 1/(1 + e^-x).
// Arguments beyond +/-100 are clamped to 0/1 so exp() cannot overflow.
inline double sigmoid(double x)
{
    if (x < -100.0) return 0.0;
    if (x > 100.0)  return 1.0;
    return 1.0 / (1.0 + exp(-x));
}
00044 
00045 
00046 /* Feature -- an abstract class to represent an RRF feature.
00047  * This could reference other, child features, or be a ground predicate,
00048  * depending on the derived class.
00049  */
00050 class Feature
00051 {
00052 public:
00053     Feature(const char* name = NULL) : name_(NULL)
00054     { 
00055         if (name != NULL) {
00056             setName(name);
00057         }
00058     }
00059 
00060     virtual ~Feature() {
00061         delete [] name_;
00062     }
00063 
00064     const char* getName() const {
00065         return name_;
00066     }
00067 
00068     void setName(const char* name) {
00069         delete [] name_;
00070         name_ = new char[strlen(name)+1]; 
00071         strcpy(name_, name); 
00072     }
00073 
00074     void setId(int id) {   id_ = id; }
00075     int  getId()       { return id_; }
00076 
00077     /* Compute the value of one grounding of this feature, given
00078      * a particular world.
00079      */
00080     virtual double getValue(const Array<int>& grounding, Database* db) {
00081 #if 1
00082         return getCachedValue(grounding, db);
00083 #else
00084         return computeValue(grounding, db);
00085 #endif
00086     }
00087 
00088     virtual double getLogValue(const Array<int>& grounding, Database* db) {
00089 #if 1
00090         return getCachedLogValue(grounding, db);
00091 #else
00092         return computeLogValue(grounding, db);
00093 #endif
00094     }
00095 
00096     double getCachedValue(const Array<int>& grounding, Database* db) {
00097         int index = getGroundingIndex(grounding, db);
00098 
00099         // Expand the array as necessary
00100         while (cacheValid_.size() <= index) {
00101             cacheValid_.append(false);
00102             cacheValue_.append(1.0);
00103         }
00104 
00105         if (!cacheValid_[index]) {
00106             cacheValue_[index] = computeValue(grounding, db);
00107             cacheValid_[index] = true;
00108         }
00109 
00110         return cacheValue_[index];
00111     }
00112 
00113 
00114     double getCachedLogValue(const Array<int>& grounding, Database* db) {
00115         int index = getGroundingIndex(grounding, db);
00116 
00117         // Expand the array as necessary
00118         while (cacheLogValid_.size() <= index) {
00119             cacheLogValid_.append(false);
00120             cacheLogValue_.append(1.0);
00121         }
00122 
00123         if (!cacheLogValid_[index]) {
00124             cacheLogValue_[index] = computeLogValue(grounding, db);
00125             cacheLogValid_[index] = true;
00126         }
00127 
00128         return cacheLogValue_[index];
00129     }
00130 
00131     // For caching and updating partial derivatives of each weight
00132     // Used by getPsuedoCounts in gfeature.h/cpp.
00133     double getCount(int w) {
00134         if (w < counts_.size()) {
00135             return counts_[w];
00136         } else {
00137             return 0.0;
00138         }
00139     }
00140 
00141     void setCount(int w, double val) {
00142         while (w >= counts_.size()) {
00143             counts_.append(0.0);
00144         }
00145         counts_[w] = val;
00146     }
00147 
00148 
00149     void invalidateAll() {
00150         for (int i = 0; i < cacheValid_.size(); i++) {
00151             cacheValid_[i] = false;
00152         }
00153         for (int i = 0; i < cacheLogValid_.size(); i++) {
00154             cacheLogValid_[i] = false;
00155         }
00156     }
00157 
00158     virtual void invalidate(const Array<int>& fgrounding, Database* db) {
00159 
00160         // Invalidate this grounding of this feature, and all
00161         // dependent groundings of its parents
00162         int index = getGroundingIndex(fgrounding, db);
00163 
00164         bool wasValid = (cacheValid_.size() > index && cacheValid_[index])
00165             || (cacheLogValid_.size() > index && cacheLogValid_[index]);
00166 
00167         if (cacheValid_.size() > index) { 
00168             cacheValid_[index] = false; 
00169         }
00170         if (cacheLogValid_.size() > index) { 
00171             cacheLogValid_[index] = false; 
00172         }
00173 
00174         if (wasValid) {
00175             for (int i = 0; i < parents_.size(); i++) {
00176                 parents_[i]->invalidateChild(id_, fgrounding, db);
00177             }
00178         }
00179     }
00180 
00181     virtual void invalidateChild(int feature, const Array<int>& grounding,
00182             Database* db) {
00183         // This should never be called.  It should only be called for
00184         // recursive features, which override this definition.
00185         assert(false);
00186         return;
00187     }
00188     
00189     virtual double computeValue(const Array<int>& grounding, Database* db) = 0;
00190     virtual double computeLogValue(const Array<int>& grounding, Database* db) = 0;
00191 
00192     /* Compute the partial derivative of this feature with respect
00193      * to a particular weight coefficient in a particular feature.
00194      *
00195      * As in getValue(), there's no caching.
00196      */
00197     virtual double getPartialDeriv(int featureIndex, int weightIndex,
00198             const Array<int>& grounding, Database* db) {
00199 
00200         return computePartialDeriv(featureIndex, weightIndex, grounding, db);
00201     }
00202 
00203     // Access term types
00204     int getNumTerms() const         { return termTypes_.size(); }
00205     void addTermType(int type)      { termTypes_.append(type); }
00206     void setTermType(int idx, int type) { termTypes_[idx] = type; }
00207     int getTermType(int idx) const  { return termTypes_[idx]; }
00208     Array<int> getTermTypes() const { return termTypes_; }
00209 
00210     // Get and set weights (only applies to RecursiveFeatures)
00211     virtual int getNumWeights() const  { return 0; }
00212     virtual double getWeight(int idx) { assert(false); return 0.0; }
00213     virtual void setWeight(int idx, double weight) { assert(false); }
00214 
00215 
00216     /* The following two methods are used by GroundRRF for constructing
00217      * a tree (or graph) of ground features.
00218      */
00219 
00220     // Utility method
00221     int getGroundingIndex(const Array<int>& grounding, const Database* db) 
00222         const {
00223 
00224         const Domain* domain = db->getDomain();
00225 
00226         int index = 0;
00227         for (int i = 0; i < getNumTerms(); i++) {
00228             index *= domain->getNumConstantsByType(getTermType(i));
00229             index += grounding[i];
00230         }
00231 
00232         return index;
00233     }
00234 
00235     // For building a ground feature tree
00236     virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
00237             const Array<int>& grounding, Database* db) 
00238     { return NULL; }
00239 
00240     virtual void print(ostream& out) const
00241     { 
00242         // List feature name
00243         if (name_ == NULL) {
00244             out << "f" << id_ << "(";
00245         } else {
00246             out << name_ <<  "(";
00247         }
00248 
00249         // List all terms
00250         for (int i = 0; i < getNumTerms(); i++) {
00251             if (i > 0) {
00252                 out << ",";
00253             }
00254             out << (char)('a' + i);
00255             out << ':' << getTermType(i);
00256         }
00257         out << ")";
00258     }
00259 
00260     void addParent(Feature* parent) {
00261         parents_.append(parent);
00262     }
00263 
00264 
00265 protected:
00266     virtual double computePartialDeriv(int featureIndex, int weightIndex,
00267             const Array<int>& grounding, Database* db) 
00268     { return 0; }
00269 
00270 
00271 protected:
00272     // Type for each of our terms 
00273     Array<int> termTypes_;
00274 
00275     // Unique index associated with this feature. 
00276     int id_;
00277 
00278     // Name of the feature
00279     char* name_;
00280 
00281     Array<Feature*> parents_;
00282     Array<bool>   cacheValid_;
00283     Array<double> cacheValue_;
00284     Array<bool>   cacheLogValid_;
00285     Array<double> cacheLogValue_;
00286 
00287     // For caching and updating partial derivatives of each weight
00288     // Used by getPsuedoCounts in gfeature.h/cpp.
00289     Array<double> counts_;
00290 };
00291 
00292 
00293 
00294 /* A feature whose value is simply that of a predicate.
00295  * Reference no other features.
00296  */
00297 class PredicateFeature : public Feature
00298 {
00299 public:
00300     PredicateFeature(const PredicateTemplate* predTemplate) 
00301         : pred_(predTemplate)
00302     { 
00303         // Set up placeholder terms and term types
00304         for (int i = 0; i < predTemplate->getNumTerms(); i++) {
00305             pred_.appendTerm(new Term(-1));
00306             addTermType(predTemplate->getTermTypeAsInt(i));
00307         }
00308 
00309         setName(pred_.getName());
00310     }
00311 
00312     virtual double computeValue(const Array<int>& grounding, Database* db) 
00313     { 
00314         assert(grounding.size() == pred_.getNumTerms());
00315 
00316         // Substitute in constants 
00317         // (all terms in the predicate are independent)
00318         for (int i = 0; i < grounding.size(); i++) {
00319             pred_.setTermToConstant(i, grounding[i]);
00320         }
00321         // Query database
00322         return db->getValue(&pred_); 
00323     }
00324 
00325     virtual double computeLogValue(const Array<int>& grounding, Database* db)
00326     {
00327         // This shouldn't be called!
00328         assert(false);
00329         return 0.0;
00330     }
00331 
00332     virtual double getPartialDeriv(int fi, int wi, 
00333             const Array<int>& grounding, Database* db) 
00334         { return 0; }
00335 
00336     virtual void print(ostream& out) const {
00337 #if 0
00338         Feature::print(out);
00339         out << " = " << pred_.getName() << "(";
00340 
00341         for (int i = 0; i < getNumTerms(); i++) {
00342             if (i > 0) {
00343                 out << ",";
00344             }
00345             out << (char)('a' + i);
00346         }
00347         out << ")";
00348 #endif
00349     }
00350 
00351 
00352     // EFFICIENCY HACK!
00353     Predicate* getPredicate() const { return &pred_; }
00354 
00355     inline virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
00356             const Array<int>& grounding, Database* db);
00357 
00358 private:
00359     // mutable -- for efficiency hack above
00360     mutable Predicate pred_;
00361 };
00362 
00363 
00364 
00365 /* A feature with a fixed, numerical value.
00366  */
00367 class ConstantFeature : public Feature
00368 {
00369 public:
00370     ConstantFeature(double value=1) 
00371         : value_(value)
00372     { setName(""); }
00373 
00374     // Return fixed value
00375     virtual double getValue(const Array<int>& grounding, Database* db) 
00376     { return value_; }
00377 
00378     virtual double computeValue(const Array<int>& grounding, Database* db) 
00379     { return value_; }
00380 
00381     virtual double computeLogValue(const Array<int>& grounding, Database* db) 
00382     { return log(value_); }
00383 
00384     // Derivative of a constant is always zero
00385     virtual double getPartialDeriv(int fi, int wi, 
00386             const Array<int>& grounding, Database* db) 
00387     { return 0; }
00388 
00389     virtual inline GroundFeature* constructGroundFeature(GroundRRF* rrf,
00390             const Array<int>& grounding, Database* db);
00391 
00392     virtual void print(ostream& out) const {
00393         /*
00394         Feature::print(out);
00395         out << " = " << value_;
00396         out << endl;
00397         */
00398     }
00399 
00400 private:
00401     double value_;
00402 };
00403 
00404 
00405 
00406 /* An RRF feature whose value is a log linear function of other features.
00407  */
class RecursiveFeature : public Feature
{
public:
    // logDerivs: when true, partial derivatives are of log f rather than f
    //     (intended for the top-level feature only; see doDerivsOfLog_).
    // normalize: request division by the normalizer Z (ignored — forced
    //     off — when the NO_NORMALIZE macro is set).
    RecursiveFeature(const char* name, bool logDerivs = false, 
            bool normalize = true) 
        : doDerivsOfLog_(logDerivs), 
#if NO_NORMALIZE
        normalize_(false), 
#else
        normalize_(normalize), 
#endif
          cachedZinvalid_(true), cachedValuesInvalid_(true)
    {
        setName(name);
    }

    // TODO: copy constructor, operator=, and destructor...

    // Get and set weights (one weight per child feature)
    virtual int getNumWeights() const { return getNumChildren(); }
    virtual double getWeight(int idx) { 
        assert(idx >= 0 && idx < getNumWeights());
        return weights_[idx]; 
    }

    // Changing a weight invalidates the cached Z and normalizers.
    virtual void setWeight(int idx, double weight) { 
        assert(idx >= 0 && idx < getNumWeights());
        cachedZinvalid_ = true;
        cachedValuesInvalid_ = true;
        weights_[idx] = weight; 
    }

    // Attach a child feature with its weight and term mapping (see
    // termMap_ below), and register ourselves as the child's parent so
    // invalidations propagate back up.
    void addChild(Feature* feature, double weight, const Array<int>& map) {
        children_.append(feature);
        weights_.append(weight);
        termMap_.append(map);
        feature->addParent(this);
    }

    Feature* getChild(int i) {
        return children_[i];
    }

    int getNumChildren() const { return children_.size(); }


    // Called by a child when one of its groundings changed: invalidate
    // every grounding of this feature that could depend on it.
    virtual void invalidateChild(int feature, const Array<int>& fgrounding, 
            Database* db) {

        // Hack to make the top-level work (the top feature has no terms,
        // so it has exactly one, empty grounding)
        if (getNumTerms() == 0) {
            Array<int> nullGrounding;
            invalidate(nullGrounding, db);
            return;
        }

        const Domain* domain = db->getDomain();

        // Iterate over all groundings of this feature, to see
        // which ones are invalid.  uniqVals[type] collects the distinct
        // constants of each type that appear in the child's grounding.
        // NOTE(review): the type is taken from *this* feature's term list
        // via getTermType(i) while i indexes the child's grounding —
        // confirm the intended alignment of these positions.
        Array<Array<int> > uniqVals;
        for (int i = 0; i < fgrounding.size(); i++) {
            int type = getTermType(i);
            while (uniqVals.size() <= type) {
                uniqVals.append(Array<int>());
            }
            if (!uniqVals[type].contains(fgrounding[i])) {
                uniqVals[type].append(fgrounding[i]);
            }
        }
        //ParentIter iter(uniqVals[type], *(domain->getConstantsByType(1)), 
        //        getNumTerms());
        ParentIter2 iter(uniqVals, domain->getConstantsByType(), 
                getTermTypes());

        Array<int> grounding;

        // Mark each relevant grounding and all its parents as invalid
        while (iter.hasNextGrounding()) {
            iter.getNextGrounding(grounding);
            invalidate(grounding, db);
        }
    }


    // Partial derivative of this feature's value with respect to weight
    // wi of feature fi, at the given grounding.
    virtual double computePartialDeriv(int fi, int wi, 
            const Array<int>& grounding, Database* db)
    {
        // TODO: optimize by caching...
        double totalPartial = 0.0;

        if (numGroundings_.size() == 0) {
            cacheNumGroundings(db);
        }

        // First, compute the partial derivative of the weight sum
        if (fi == id_) {

            // We own the weight; partial is simply the sum of the 
            // corresponding child values.  (i.e., d w*x/ dw = x)
            totalPartial = childGroundSum(wi, grounding, db);

            if (normalize_) {
                totalPartial -= getNorm(wi);
            }
            // getNorm(wi) results from taking the derivative of the 
            // 1/Z component.
        } else {

            // Sum partial derivatives for each grounding of each child,
            // according to the chain rule of partial derivatives.
            for (int i = 0; i < getNumChildren(); i++) 
            {

                ArraysAccessor<int>* groundingIter = 
                    getChildGroundingIter(i, grounding, db);

                Array<int> childGrounding;
                int j = 0;
                do {
                    groundingIter->getNextCombination(childGrounding);
                    double childValue = children_[i]->getPartialDeriv(fi, wi, 
                                childGrounding, db);
#if 0
                    // DEBUG
                    cout << "RRF Child " << i << "," << j <<  "," << fi 
                        << "," << wi << ": " << childValue << endl;
#endif
                    totalPartial += weights_[i] * childValue;
                    j++;
                } while (groundingIter->hasNextCombination());

#if NORM
                // Normalize by the number of groundings just summed
                totalPartial /= j;
#endif

                releaseChildGroundingIter(i, groundingIter);
            }
        }

        // Multiply by value to get true partial:
        // d e^f(x)/dx = e^f(x) * d f(x)/dx
        if (doDerivsOfLog_) {
            return totalPartial;
        } else {
#if SIGMOID
            // d sigmoid(f)/dx = sigmoid(f) * (1 - sigmoid(f)) * df/dx
            double val = getValue(grounding, db);
            return val * (1.0 - val) * totalPartial;
#else
            return getValue(grounding, db) * totalPartial;
#endif
        }
    }

    // Value of this grounding: exp (or sigmoid) of the weighted sum of
    // child ground sums.
    virtual double computeValue(const Array<int>& grounding, Database* db) 
    {
#if SIGMOID
        double totalValue = 0.0;

        if (numGroundings_.size() == 0) {
            cacheNumGroundings(db);
        }

        for (int i = 0; i < getNumChildren(); i++) {
            totalValue += weights_[i] * childGroundSum(i, grounding, db);
        }

        return sigmoid(totalValue);
#else
        return exp(computeLogValue(grounding, db));
#endif
    }

    // Log value: weighted sum of child ground sums, minus log Z.
    virtual double computeLogValue(const Array<int>& grounding, Database* db) 
    {
        double totalValue = 0.0;

        if (numGroundings_.size() == 0) {
            cacheNumGroundings(db);
        }

        for (int i = 0; i < getNumChildren(); i++) {
            totalValue += weights_[i] * childGroundSum(i, grounding, db);
        }

        return totalValue - getLogZ();
    }

    virtual inline GroundFeature* constructGroundFeature(GroundRRF* rrf,
            const Array<int>& grounding, Database* db);

    // Pretty-print as "name(a,b) = exp(w1 child1(...) + w2 child2(...))".
    virtual void print(ostream& out) const {

        Feature::print(out);
        out << " = exp(";

        // TODO: we may want to allow special groundings like f12(A,A),
        // even when A is a free parameter.  Currently, that's broken.
        char currFreeVar = 'a' + getNumTerms();

        // Print out weighted sum of child features
        for (int i = 0; i < getNumChildren(); i++) {
            if (i > 0) {
                out << " + ";
            }
            
            assert(children_[i]->getName() != NULL);
            out << weights_[i] << " " << children_[i]->getName();
            
            if (strlen(children_[i]->getName()) > 0) {

                out << "(";
                for (int j = 0; j < termMap_[i].size(); j++) {
                    if (j > 0) {
                        out << ",";
                    }
                    // Negative map entries are free (summed-over) vars
                    if (termMap_[i][j] < 0) {
                        out << currFreeVar;
                        //currFreeVar++;
                    } else {
                        out << (char)('a' + termMap_[i][j]);
                    }
                }
                out << ")";
            }
        }
        out << ")\n";
    }

//protected:
    // Utility method: compute the value of the given child over
    // all groundings of the child consistent with the parent grounding
    // and the term mapping.
    double childGroundSum(int childIndex, const Array<int>& parentGrounding,
            Database* db)
    {
        double totalValue = 0.0;
        ArraysAccessor<int>* groundingIter = 
            getChildGroundingIter(childIndex, parentGrounding, db);

#if NORM
        // Normalize over the number of groundings
        int numGroundings = 0;
#endif
        Array<int> childGrounding;
        do {
            groundingIter->getNextCombination(childGrounding);
            totalValue += children_[childIndex]->getValue(childGrounding, db);
#if NORM
            numGroundings++;
#endif
        } while (groundingIter->hasNextCombination());

        releaseChildGroundingIter(childIndex, groundingIter);
#if NORM
        return totalValue/numGroundings;
#else
        return totalValue;
#endif
    }

    // Returns an iterator through all groundings of the specified
    // child that are consistent with the provided grounding of the 
    // parent (this).  Uses db to obtain a list of constants for each
    // type.  Must be freed by releaseChildGroundingIter().
    ArraysAccessor<int>* getChildGroundingIter(int childId, 
            const Array<int>& grounding, Database* db) 
    {
        const Domain* domain = db->getDomain();
        ArraysAccessor<int>* groundingIter = new ArraysAccessor<int>;

        // Setup object to iterate over all groundings of this child.
        // Respect mappings from parent terms to child terms;
        // all others are free to be any constant of the proper type.
        for (int term = 0; term < children_[childId]->getNumTerms(); 
                term++) {

            // TODO: Make this work with multiple types!
            if (termMap_[childId][term] < 0) {
                // Iterate over all constants of the appropriate type
                // (array owned by the domain; not freed by us)
                int type = children_[childId]->getTermType(term);
                groundingIter->appendArray(domain->getConstantsByType(type));
            } else {
                // Use the passed in value (singleton array owned by us;
                // freed in releaseChildGroundingIter())
                Array<int>* singleton = new Array<int>;
                singleton->append(grounding[termMap_[childId][term]]);
                groundingIter->appendArray(singleton);
            }
        }

        return groundingIter;
    }

    // Frees a grounding iterator created by the above function.
    void releaseChildGroundingIter(int childId,
            ArraysAccessor<int>* groundingIter) 
    {
        // Delete singleton arrays we created for the groundingIter;
        // mapped terms (termMap_ >= 0) are exactly the ones we allocated.
        for (int j = 0; j < children_[childId]->getNumTerms(); j++) {
            if (termMap_[childId][j] >= 0) {
                delete groundingIter->getArray(j);
            }
        }

        delete groundingIter;
    }

    // Log of the normalizer Z.  Cached; recomputed lazily after any
    // setWeight() call.  The exact form depends on the SMOOTH_MAX /
    // SMOOTH_MAX2 compile-time switches.
    double getLogZ() const {
#if SIGMOID
        return 0.0;
#else

        if (!normalize_) { 
            return 0.0; 
        }

        // This array must already be set up.  HACK-ish
        assert(numGroundings_.size() > 0);

        if (cachedZinvalid_) {
            cachedLogZ_ = 0.0;
            for (int i = 0; i < getNumChildren(); i++) {
#if SMOOTH_MAX
                // Similar to dividing by the largest possible value,
                // but softened so the derivative is continuous.
                double sigma_w = sigmoid(weights_[i]);
                // Compute sigma(weights_[i]) robustly, to avoid NaN
                cachedLogZ_ += weights_[i] * sigma_w * numGroundings_[i];
#elif SMOOTH_MAX2
                if (weights_[i] < -10.0) {
                    cachedLogZ_ += 0.0;
                } else if (weights_[i] > 10.0) {
                    cachedLogZ_ += weights_[i] * numGroundings_[i];
                } else {
                    cachedLogZ_ += numGroundings_[i]
                        * log(1.0 + exp(KVAL * weights_[i])) / KVAL;
                }
#else
                // Default: normalize by the maximum possible value, which
                // only positive weights contribute to.
                if (weights_[i] > 0.0) {
                    cachedLogZ_ += weights_[i] * numGroundings_[i];
                }
#endif
            }

            cachedZinvalid_ = false;
        }
        return cachedLogZ_;
#endif
    }

    virtual double getZ() const {
        return exp(getLogZ());
    }

    // Per-weight normalizer term: d log Z / d w_idx.  Cached; recomputed
    // lazily after any setWeight() call.
    virtual double getNorm(int idx) {
#if SIGMOID
        return 0.0;
#else

        if (!normalize_) {
            return 0.0;
        }

        // This array must already be set up.  HACK-ish
        assert(numGroundings_.size() > 0);

        if (cachedValuesInvalid_) {
            cachedNormalizers_.growToSize(getNumWeights());
            for (int i = 0; i < getNumWeights(); i++) {

#if SMOOTH_MAX
                // Continuous hack to avoid the discontinuity with using
                // the maximum value
                double w_i = getWeight(i);
                double sigma_w = sigmoid(w_i);
                
                //norm = sigma_w * (1.0 - sigma_w) * numGroundings_[i];
                double norm = sigma_w * (1.0 + w_i * (1.0 - sigma_w)) 
                    * numGroundings_[i];
                cachedNormalizers_[i] = norm;
#elif SMOOTH_MAX2
                cachedNormalizers_[i] = sigmoid(getWeight(i)) 
                    * numGroundings_[i];
#else
                // Default: derivative of the max-value normalizer —
                // zero for negative weights, numGroundings otherwise.
                if (getWeight(i) < 0.0) {
                    cachedNormalizers_[i] = 0.0;
                } else {
                    cachedNormalizers_[i] = 1.0 * numGroundings_[i];
                }
#endif
            }
            cachedValuesInvalid_ = false;
        }
        return cachedNormalizers_[idx];
#endif
    }

    // Precompute, per child, the number of free groundings (product of
    // domain sizes of unmapped terms).  Used by the normalizers above.
    void cacheNumGroundings(Database* db)
    {
        const Domain* domain = db->getDomain();

        numGroundings_.clear();
        for (int i = 0; i < getNumChildren(); i++) {
            int numGroundings = 1;
            for (int term = 0; term < children_[i]->getNumTerms(); term++) {
                if (termMap_[i][term] < 0) {
                    int type = children_[i]->getTermType(term);
                    numGroundings *= domain->getNumConstantsByType(type);
                } 
            }
            numGroundings_.append(numGroundings);
        }
    }


protected:
    Array<Feature*>    children_;
    Array<double>      weights_;
    Array<ArraysAccessor<int>* > groundingIters_;

    /* Map of parent feature's terms to child feature's terms.
     *
     * The outer array indexes over all children.
     * The inner array indexes over the terms of a child.
     * A value i >= 0 maps the child's terms to the parent's ith
     * term.  A value i < 0 maps the child's term to a new 
     * term.  This has the effect of summing over all values.
     */
    Array<Array<int> > termMap_;

    // Keep track of the number of groundings for each feature.
    // TODO: this may vary if we switch databases halfway through...
    // ...how can we avoid this?
    Array<int> numGroundings_;

    // When taking derivatives, use the log of this feature value, not
    // the feature value itself.  This should be true for the top feature
    // (and only the top feature).
    bool doDerivsOfLog_; 

    bool normalize_;

    // Lazily-maintained caches for Z and the per-weight normalizers;
    // mutable so const getLogZ() can refresh them.
    mutable bool   cachedZinvalid_;
    mutable double cachedLogZ_;
    mutable bool   cachedValuesInvalid_;
    mutable Array<double> cachedNormalizers_;
};
00856 
00857 
00858 /* A ClausalFeature is like a RecursiveFeature
00859  * except it only has PredicateFeatures as its children.
00860  *
00861  * WARNING: inheritance used for code sharing, but if you
00862  * add non-predicate features as children of a ClausalFeature...
00863  * everything will break.
00864  */
00865 class ClausalFeature : public RecursiveFeature
00866 {
00867 public:
00868     // BUG: segfaults when name is NULL... why?
00869     ClausalFeature(const char* name = NULL) 
00870         : RecursiveFeature(name), cachedZinvalid_(true), 
00871           cachedValuesInvalid_(true)
00872     { /* NOP */ }
00873         
00874     virtual double computePartialDeriv(int fi, int wi,
00875             const Array<int>& grounding, Database* db) 
00876     {
00877         if (fi != id_) {
00878             return 0.0;
00879         }
00880         
00881         // Find the ground value of the specified term.
00882         // There had better only be a *single* grounding, 
00883         // or we're in trouble... 
00884 
00885         double totalValue = 
00886             getChildValue(wi, grounding, db) - getNorm(wi);
00887         // getNorm(wi) results from taking the derivative of the 1/Z component.
00888 
00889 #if SIGMOID
00890         double val = getValue(grounding, db);
00891         return val * (1.0 - val) * totalValue;
00892 #else
00893         return getValue(grounding, db) * totalValue;
00894 #endif
00895     }
00896 
00897     virtual double computeLogValue(const Array<int>& grounding, Database* db)
00898     {
00899         return computeLogValueRaw(grounding, db) - getLogZ();
00900     }
00901 
00902     virtual double computeValue(const Array<int>& grounding, Database* db) 
00903     {
00904 #if SIGMOID
00905         double total = 0.0;
00906         for (int childId = 0; childId < getNumChildren(); childId++) {
00907             total += weights_[childId] 
00908                 * getChildValue(childId, grounding, db);
00909         }
00910 
00911         return sigmoid(total);
00912 #else
00913         double logValue = computeLogValue(grounding, db);
00914         // HACK to avoid underflow
00915         if (logValue < -100.0) {
00916             return 0.0;
00917         } else {
00918             return exp(logValue); 
00919         }
00920 #endif
00921     }
00922 
00923     virtual void setWeight(int idx, double weight) { 
00924         RecursiveFeature::setWeight(idx, weight);
00925         cachedZinvalid_ = true;
00926         cachedValuesInvalid_ = true;
00927     }
00928 
00929     virtual double getNorm(int idx) {
00930 #if SIGMOID || NO_NORMALIZE
00931         return 0.0;
00932 #else
00933         if (cachedValuesInvalid_) {
00934             cachedNormalizers_.growToSize(getNumWeights());
00935             for (int i = 0; i < getNumWeights(); i++) {
00936 #if 0
00937                 // Normalize based on all possible values
00938                 double norm = 1.0/(1.0 + exp(-getWeight(i)));
00939 #elif SMOOTH_MAX
00940                 // Continuous hack to avoid the discontinuity with using
00941                 // the maximum value
00942                 double w_i = getWeight(i);
00943                 double norm;
00944                 double sigma_w = sigmoid(w_i);
00945 
00946                 //norm = sigma_w * (1.0 - sigma_w);
00947                 norm = sigma_w * (1.0 + w_i * (1.0 - sigma_w));
00948 #elif SMOOTH_MAX2
00949                 // Continuous hack to avoid the discontinuity with using
00950                 // the maximum value
00951                 double w_i = getWeight(i);
00952                 double norm;
00953                 double sigma_w = sigmoid(w_i);
00954 
00955                 norm = sigma_w;
00956 #else
00957                 // Normalize based on maximum possible value
00958                 double norm;
00959                 if (getWeight(i) < 0.0) {
00960                     norm = 0.0;
00961                 } else {
00962                     //norm = 1.0/getZ();
00963                     norm = 1.0;
00964                 }
00965 #endif
00966                 cachedNormalizers_[i] = norm;
00967             }
00968             cachedValuesInvalid_ = false;
00969         }
00970         return cachedNormalizers_[idx];
00971 #endif
00972     }
00973 
00974     void setWeightsFromExample(const Array<int>& grounding, Database* db, 
00975             const Array<int>& queryPreds, int numWeights)
00976     {
00977         if (numWeights < 0) {
00978             numWeights = getNumChildren();
00979         }
00980 
00981         // Keep going until one of the query predicates has a decent weight
00982         bool nonZeroQueryPredicate = false;
00983         while (!nonZeroQueryPredicate) {
00984 
00985             int numNonZero = 0;
00986             for (int i = 0; i < getNumChildren(); i++) {
00987 
00988                 if (frand() < ((double)numWeights - numNonZero)
00989                         /((double)getNumChildren() - i)) {
00990 
00991                     if (getChildValue(i, grounding, db)) {
00992                         setWeight(i, 1.0 + frand() );
00993                     } else {
00994                         setWeight(i, -1.0 - frand() );
00995                     }
00996                     
00997                     int featureId = children_[i]->getId();
00998                     if (queryPreds.find(featureId) != -1) {
00999                         nonZeroQueryPredicate = true;
01000                     }
01001                     numNonZero++;
01002                 } else {
01003                     //setWeight(i, (frand() - 0.5)/100.0 );
01004                     setWeight(i, 0.0);
01005                 }
01006             }
01007         }
01008     }
01009 
01010     virtual GroundFeature* constructGroundFeature(GroundRRF* rrf,
01011             const Array<int>& grounding, Database* db);
01012 
01013     double getLogZ() const {
01014 #if SIGMOID || NO_NORMALIZE
01015         return 0.0;
01016 #else
01017         if (cachedZinvalid_) {
01018             cachedLogZ_ = 0.0;
01019             for (int i = 0; i < getNumChildren(); i++) {
01020 #if 0
01021                 // Normalize based on all possible values
01022                 if (weights_[i] > 100.0) {
01023                     // Very large weights dominate the 1.0 term
01024                     cachedLogZ_ += weights_[i];
01025                 } else if (weights_[i] > -100.0) {
01026                     // Very tiny weights contribute nothing
01027                     cachedLogZ_ += log(1.0 + exp(weights_[i]));
01028                 }
01029 #elif SMOOTH_MAX
01030                 // Similar to dividing by the largest possible value,
01031                 // but softened so the derivative is continuous.
01032                 double sigma_w = sigmoid(weights_[i]);
01033                 // Compute sigma(weights_[i]) robustly, to avoid NaN
01034                 cachedLogZ_ += weights_[i] * sigma_w;
01035 #elif SMOOTH_MAX2
01036                 if (weights_[i] < -10.0) {
01037                     /* NOP */
01038                 } else if (weights_[i] > 10.0) {
01039                     cachedLogZ_ += weights_[i];
01040                 } else {
01041                     cachedLogZ_ += log(1.0 + exp(KVAL*weights_[i]))/KVAL;
01042                 }
01043 #else
01044                 // Normalize by dividing by the largest possible value
01045                 if (weights_[i] > 0.0) {
01046                     cachedLogZ_ += weights_[i];
01047                 }
01048 #endif
01049             }
01050             cachedZinvalid_ = false;
01051         }
01052         return cachedLogZ_;
01053 #endif
01054     }
01055 
01056     virtual double getZ() const {
01057         return exp(getLogZ());
01058     }
01059 
01060 protected:
01061 
01062     double computeLogValueRaw(const Array<int>& grounding, 
01063             Database* db) const {
01064 
01065         double total = 0.0;
01066         for (int childId = 0; childId < getNumChildren(); childId++) {
01067             total += weights_[childId] 
01068                 * getChildValue(childId, grounding, db);
01069         }
01070         return total;
01071     }
01072 
01073 
01074     double getChildValue(int childId, const Array<int>& grounding, 
01075             Database* db) const
01076     {
01077 #if 0
01078         Predicate* pred = ((PredicateFeature*)children_[childId])
01079                               ->getPredicate();
01080         for (int i = 0; i < children_[childId]->getNumTerms(); i++) {
01081             pred->setTermToConstant(i, grounding[termMap_[childId][i]]);
01082         }
01083 
01084         // NOTE: This had better be 1 or 0, not 2 (UNKNOWN)
01085         return (int)db->getValue(pred);
01086 #else
01087         Array<int> cgrounding;
01088         for (int i = 0; i < children_[childId]->getNumTerms(); i++) {
01089             cgrounding.append(grounding[termMap_[childId][i]]);
01090         }
01091         return children_[childId]->getValue(cgrounding, db);
01092 #endif
01093     }
01094 
01095     mutable bool   cachedZinvalid_;
01096     mutable double cachedLogZ_;
01097     mutable bool   cachedValuesInvalid_;
01098     mutable Array<double> cachedNormalizers_;
01099 };
01100 
01101 #include "gfeature.h"
01102 
01103 #endif

Generated on Sun Jun 7 11:55:19 2009 for Alchemy by  doxygen 1.5.1