rrf.h

00001 #ifndef RRF_H
00002 #define RRF_H
00003 #define SIGMASQ 10000
00004 /*
00005  * rrf.h -- Defines the RRF class.
00006  *
00007  * RRF stands for Recursive Random Field, a probabilistic model that
00008  * generalizes Markov Logic Networks (MLNs).  In an RRF,
00009  *
00010  *     P(World) = 1/Z exp( sum_i w_i f_i )
00011  *
00012  * That is, an RRF is a log linear model (just like an MLN).  The
00013  * difference is in how the features, f_i are defined.  In an MLN,
00014  * each is typically a clause, e.g.:  (X_1 v X_2 v X_4)
00015  * In an RRF, each feature is itself an RRF:
00016  *
00017  *     f_i(x) = 1/Z_i exp( sum_j w_j f_j )
00018  *
00019  * ...or a ground predicate...
00020  *   
00021  *     f_i(x) = Pred(x)
00022  *
00023  * We can therefore think of the entire probability distribution as
00024  * one top-level feature:
00025  *
00026  *     P(World) = f_0
00027  *
00028  * Each feature may have one or more terms, represented here as x.
00029  * Each term in a child feature may be bound to a term of the
00030  * parent feature, or may be left unspecified, in which case we sum over
00031  * all groundings.  (This is similar to summing over all groundings in an
00032  * MLN.)
00033  */
00034 
00035 // TODO: const correctness
00036 
00037 #include "feature.h"
00038 #include "predicate.h"
00039 #include "predicatetemplate.h"
00040 #include <math.h>
00041 #include <iostream>
00042 using namespace __gnu_cxx;
00043 
00044 class RRF
00045 {
00046 private:
00047     // Private, unimplemented copy-constructor so it won't get called
00048     // accidentally.
00049     RRF(const RRF& other);
00050 
00051 public:
00052 
00053     RRF() { maxFeatureId_ = -1; topFeature_ = NULL; }
00054 
00055     ~RRF() {
00056         featureArray_.deleteItemsAndClear();
00057     }
00058 
00059     // To be called *after* all features have been allocated.
00060     // TODO -- mmake this more flexible and intelligent
00061     void load(istream& in, Domain* domain);
00062 
00063     // Create a two-level RRF with specified number of features.
00064     // 
00065     // of each arity.  The first dimension is 1.  So, for the
00066     // array {2, 1, 4}, this will create 2 features with arity 1
00067     // (\forall X), 1 of arity 2 (\forall X,Y), and 4 of arity 3
00068     // (\forall X,Y,Z).
00069     //
00070     // HACK: For now, we assume only a single type.
00071     // Caller takes ownership over returned RRF
00072     static RRF* generateTwoLevel(Domain* domain);
00073 
00074     // Add a feature with randomly selected term mappings
00075     void addRandomFeature(int numChildren, Array<int> queryPreds,
00076         Array<int> typeArity, Domain* domain);
00077 
00078     // Add a feature with every possible term mapping of every predicate
00079     void addCompleteFeature(Array<int> typeArity, Domain* domain, 
00080             const Array<int>& queryPreds, int numChildren);
00081 
00082     double getValue(Database* db) { 
00083         Array<int> emptyGrounding;
00084         // DEBUG
00085         //invalidateAll();
00086         return topFeature_->getValue(emptyGrounding, db); 
00087     }
00088 
00089     double getLogValue(Database* db) { 
00090         Array<int> emptyGrounding;
00091         // DEBUG
00092         //invalidateAll();
00093         return topFeature_->getLogValue(emptyGrounding, db); 
00094     }
00095 
00096     // TODO: should all of this be log likelihood...?
00097     double getExactLikelihood(Database* db) {
00098         return getValue(db) / getExactZ(db->getDomain());
00099     }
00100 
00101     double getExactConditionalLikelihood(Database* db, 
00102             const Array<int>& queryPreds) {
00103         return getValue(db) / getExactZ(db->getDomain(), queryPreds, db);
00104     }
00105 
00106     // Get log pseudo likelihood
00107     double getLogPseudoLikelihood(Database* db, const Array<int>& queryPreds);
00108 
00109     double getLogPseudoLikelihood(Database* db, const Array<Predicate*>& 
00110             queryPreds);
00111 
00112 #if 0
00113     // Get log likelihood via Gibbs sampling 
00114     // (estimated as product of marginals)
00115     double getGibbsLogLikelihood(Database* db, const Array<int>& queryPreds);
00116 #endif
00117 
00118     double getWeightLogLikelihood(double sigmaSq) const {
00119 
00120         double ll = 0.0;
00121         for (int i = 0; i < getNumWeights(); i++) {
00122             ll -= getWeight(i) * getWeight(i) / sigmaSq;
00123         }
00124         return ll;
00125     }
00126 
00127     void changedPredicate(const Predicate* pred, Database* db) {
00128         Array<int> grounding;
00129         for (int i = 0; i < pred->getNumTerms(); i++) {
00130             grounding.append(pred->getTerm(i)->getId());
00131         }
00132         featureArray_[pred->getId()-1]->invalidate(grounding, db);
00133     }
00134 
00135     void invalidateAll() {
00136 
00137         for (int i = 0; i < getNumFeatures(); i++) {
00138             featureArray_[i]->invalidateAll();
00139         }
00140     }
00141 
00142 #if 0
00143     double getPredicateLogLikelihood(const Predicate* pred, Database* db)
00144     {
00145 
00146         const Domain* domain = db->getDomain();
00147         TruthValue originalValue = db->getValue(pred);
00148 
00149         RecursiveFeature* root = (RecursiveFeature*)topFeature_;
00150 
00151         double posWeightSum = 0.0;
00152         double negWeightSum = 0.0;
00153         Array<int> nullGrounding;
00154         for (int wi = 0; wi < topFeature_->getNumWeights(); wi++)
00155         {
00156             ArraysAccessor<int>* iter 
00157                 = root->getChildGroundingIter(wi, nullGrounding, db);
00158 
00159 #if 0
00160             Array<int> grounding;
00161 #else
00162             // Prepare array once.  Efficiency hack.
00163             Array<int> grounding(root->getChild(wi)->getNumTerms());
00164             for (int i = 0; i < root->getChild(wi)->getNumTerms(); i++) {
00165                 grounding.append(-1);
00166             }
00167 #endif
00168             while (iter->hasNextCombination()) {
00169 #if 0
00170                 iter->getNextCombination(grounding);
00171 #else
00172                 for (int i = 0; i < grounding.size(); i++) {
00173                     grounding[i] = iter->getItem(i);
00174                 }
00175                 iter->next();
00176 #endif
00177                 // HACK DEBUG
00178                 //posWeightSum += root->getWeight(wi) * 1.0;
00179                 //continue;
00180 
00181                 bool allTermsPresent = true;
00182                 for (int i = 0; i < pred->getNumTerms(); i++) {
00183                     int id = pred->getTerm(i)->getId();
00184                     bool termPresent = false;
00185                     for (int j = 0; j < grounding.size(); j++) {
00186                         if (grounding[j] == id) {
00187                             termPresent = true;
00188                             break;
00189                         }
00190                     }
00191 
00192                     if (!termPresent) {
00193                         allTermsPresent = false;
00194                         break;
00195                     }
00196                 }
00197 
00198                 if (allTermsPresent) {
00199                     db->setValue(pred, TRUE);
00200                     posWeightSum += root->getWeight(wi) * 
00201                         featureArray_[domain->getNumPredicates()+wi]
00202                         ->computeValue(grounding, db);
00203                     db->setValue(pred, FALSE);
00204                     negWeightSum += root->getWeight(wi) * 
00205                         featureArray_[domain->getNumPredicates()+wi]
00206                         ->computeValue(grounding, db);
00207                 }
00208             }
00209         }
00210 
00211         db->setValue(pred, originalValue);
00212 
00213         if (originalValue == TRUE) {
00214             return -log(1.0 + exp(negWeightSum - posWeightSum));
00215         } else {
00216             return -log(1.0 + exp(posWeightSum - negWeightSum));
00217         }
00218     }
00219 #endif
00220 
00221 
00222     // Note: SLOW!  This computes Z, the partition function, by looping
00223     // over all 2^n possible worlds, where n is the number of ground 
00224     // predicates.
00225     double getExactZ(const Domain* domain); 
00226 
00227     double getExactZ(const Domain* domain, const Array<int>& queryPreds,
00228         Database* origDb);
00229 
00230     
00231     void getCounts(Array<double>& counts, Database* db) 
00232     {
00233         Array<int> emptyGrounding;
00234         counts.clear();
00235 
00236         // For each weight of each feature, gather counts
00237         for (int i = 0; i < featureArray_.size(); i++) {
00238             for (int j = 0; j < featureArray_[i]->getNumWeights(); j++) {
00239                 counts.append(topFeature_->getPartialDeriv(
00240                         i, j, emptyGrounding, db));
00241             }
00242         }
00243     }
00244 
00245     void getPseudoCounts(Array<double>& counts, Database* db, 
00246             Array<Predicate*> queryPreds)
00247     {
00248         // TODO...
00249     }
00250 
00251     void setWeight(int idx, double wt) {
00252         // TODO: optimize?
00253         for (int i = 0; i < featureArray_.size(); i++) {
00254             int numWeights = featureArray_[i]->getNumWeights();
00255             if (idx < numWeights) {
00256                 featureArray_[i]->setWeight(idx, wt);
00257                 return;
00258             }
00259             idx -= numWeights;
00260         }
00261 
00262         // We should not reach here unless given invalid input.
00263         // TODO: better error reporting?
00264         assert(false);
00265     }
00266 
00267     double getWeight(int idx) const {
00268         // TODO: optimize?
00269         for (int i = 0; i < featureArray_.size(); i++) {
00270             int numWeights = featureArray_[i]->getNumWeights();
00271             if (idx < numWeights) {
00272                 return featureArray_[i]->getWeight(idx);
00273             }
00274             idx -= numWeights;
00275         }
00276 
00277         // We should not reach here unless given invalid input.
00278         // TODO: better error reporting?
00279         assert(false);
00280         return 0;
00281     }
00282 
00283     Feature* getRoot() {
00284         return topFeature_;
00285     }
00286 
00287     Feature* getFeature(int idx) {
00288         assert(idx < featureArray_.size());
00289         return featureArray_[idx];
00290     }
00291 
00292     const Feature* getFeature(int idx) const {
00293         assert(idx < featureArray_.size());
00294         return featureArray_[idx];
00295     }
00296 
00297     Feature* getFeature(const char* name) {
00298         for (int i = 0; i < featureArray_.size(); i++) {
00299             if (!strcmp(featureArray_[i]->getName(), name)) {
00300                 return featureArray_[i];
00301             }
00302         }
00303         return NULL;
00304     }
00305 
00306     const Feature* getFeature(const char* name) const {
00307         for (int i = 0; i < featureArray_.size(); i++) {
00308             if (!strcmp(featureArray_[i]->getName(), name)) {
00309                 return featureArray_[i];
00310             }
00311         }
00312         return NULL;
00313     }
00314 
00315     int getNumFeatures() const {
00316         return featureArray_.size();
00317     }
00318 
00319     int getNumWeights() const {
00320 
00321         int numWeights = 0;
00322         for (int i = 0; i < featureArray_.size(); i++) {
00323             numWeights += featureArray_[i]->getNumWeights();
00324         }
00325         return numWeights;
00326     }
00327 
00328    /*
00329     void learnWeights();
00330     void inferMissing(Database* db);
00331     ...
00332    */
00333 
00334     Array<Feature*> getFeatureArray() { return featureArray_; }
00335 
00336 protected:
00337     Feature* topFeature_;
00338     int maxFeatureId_;
00339     Array<Feature*> featureArray_;
00340 };
00341 
00342 
00343 inline ostream& operator<<(ostream& out, const RRF* rrf)
00344 {
00345     out << "RRF {\n";
00346     for (int i = 0;
00347             i < rrf->getNumFeatures(); i++) {
00348         rrf->getFeature(i)->print(out);
00349     }
00350     out << "}\n";
00351     return out;
00352 }
00353 
00354 #endif

Generated on Sun Jun 7 11:55:20 2009 for Alchemy by  doxygen 1.5.1