00001 #ifndef RRF_H
00002 #define RRF_H
00003 #define SIGMASQ 10000
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #include "feature.h"
00038 #include "predicate.h"
00039 #include "predicatetemplate.h"
00040 #include <math.h>
00041 #include <iostream>
00042 using namespace __gnu_cxx;
00043
00044 class RRF
00045 {
00046 private:
00047
00048
00049 RRF(const RRF& other);
00050
00051 public:
00052
00053 RRF() { maxFeatureId_ = -1; topFeature_ = NULL; }
00054
00055 ~RRF() {
00056 featureArray_.deleteItemsAndClear();
00057 }
00058
00059
00060
00061 void load(istream& in, Domain* domain);
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072 static RRF* generateTwoLevel(Domain* domain);
00073
00074
00075 void addRandomFeature(int numChildren, Array<int> queryPreds,
00076 Array<int> typeArity, Domain* domain);
00077
00078
00079 void addCompleteFeature(Array<int> typeArity, Domain* domain,
00080 const Array<int>& queryPreds, int numChildren);
00081
00082 double getValue(Database* db) {
00083 Array<int> emptyGrounding;
00084
00085
00086 return topFeature_->getValue(emptyGrounding, db);
00087 }
00088
00089 double getLogValue(Database* db) {
00090 Array<int> emptyGrounding;
00091
00092
00093 return topFeature_->getLogValue(emptyGrounding, db);
00094 }
00095
00096
00097 double getExactLikelihood(Database* db) {
00098 return getValue(db) / getExactZ(db->getDomain());
00099 }
00100
00101 double getExactConditionalLikelihood(Database* db,
00102 const Array<int>& queryPreds) {
00103 return getValue(db) / getExactZ(db->getDomain(), queryPreds, db);
00104 }
00105
00106
00107 double getLogPseudoLikelihood(Database* db, const Array<int>& queryPreds);
00108
00109 double getLogPseudoLikelihood(Database* db, const Array<Predicate*>&
00110 queryPreds);
00111
00112 #if 0
00113
00114
00115 double getGibbsLogLikelihood(Database* db, const Array<int>& queryPreds);
00116 #endif
00117
00118 double getWeightLogLikelihood(double sigmaSq) const {
00119
00120 double ll = 0.0;
00121 for (int i = 0; i < getNumWeights(); i++) {
00122 ll -= getWeight(i) * getWeight(i) / sigmaSq;
00123 }
00124 return ll;
00125 }
00126
00127 void changedPredicate(const Predicate* pred, Database* db) {
00128 Array<int> grounding;
00129 for (int i = 0; i < pred->getNumTerms(); i++) {
00130 grounding.append(pred->getTerm(i)->getId());
00131 }
00132 featureArray_[pred->getId()-1]->invalidate(grounding, db);
00133 }
00134
00135 void invalidateAll() {
00136
00137 for (int i = 0; i < getNumFeatures(); i++) {
00138 featureArray_[i]->invalidateAll();
00139 }
00140 }
00141
00142 #if 0
00143 double getPredicateLogLikelihood(const Predicate* pred, Database* db)
00144 {
00145
00146 const Domain* domain = db->getDomain();
00147 TruthValue originalValue = db->getValue(pred);
00148
00149 RecursiveFeature* root = (RecursiveFeature*)topFeature_;
00150
00151 double posWeightSum = 0.0;
00152 double negWeightSum = 0.0;
00153 Array<int> nullGrounding;
00154 for (int wi = 0; wi < topFeature_->getNumWeights(); wi++)
00155 {
00156 ArraysAccessor<int>* iter
00157 = root->getChildGroundingIter(wi, nullGrounding, db);
00158
00159 #if 0
00160 Array<int> grounding;
00161 #else
00162
00163 Array<int> grounding(root->getChild(wi)->getNumTerms());
00164 for (int i = 0; i < root->getChild(wi)->getNumTerms(); i++) {
00165 grounding.append(-1);
00166 }
00167 #endif
00168 while (iter->hasNextCombination()) {
00169 #if 0
00170 iter->getNextCombination(grounding);
00171 #else
00172 for (int i = 0; i < grounding.size(); i++) {
00173 grounding[i] = iter->getItem(i);
00174 }
00175 iter->next();
00176 #endif
00177
00178
00179
00180
00181 bool allTermsPresent = true;
00182 for (int i = 0; i < pred->getNumTerms(); i++) {
00183 int id = pred->getTerm(i)->getId();
00184 bool termPresent = false;
00185 for (int j = 0; j < grounding.size(); j++) {
00186 if (grounding[j] == id) {
00187 termPresent = true;
00188 break;
00189 }
00190 }
00191
00192 if (!termPresent) {
00193 allTermsPresent = false;
00194 break;
00195 }
00196 }
00197
00198 if (allTermsPresent) {
00199 db->setValue(pred, TRUE);
00200 posWeightSum += root->getWeight(wi) *
00201 featureArray_[domain->getNumPredicates()+wi]
00202 ->computeValue(grounding, db);
00203 db->setValue(pred, FALSE);
00204 negWeightSum += root->getWeight(wi) *
00205 featureArray_[domain->getNumPredicates()+wi]
00206 ->computeValue(grounding, db);
00207 }
00208 }
00209 }
00210
00211 db->setValue(pred, originalValue);
00212
00213 if (originalValue == TRUE) {
00214 return -log(1.0 + exp(negWeightSum - posWeightSum));
00215 } else {
00216 return -log(1.0 + exp(posWeightSum - negWeightSum));
00217 }
00218 }
00219 #endif
00220
00221
00222
00223
00224
00225 double getExactZ(const Domain* domain);
00226
00227 double getExactZ(const Domain* domain, const Array<int>& queryPreds,
00228 Database* origDb);
00229
00230
00231 void getCounts(Array<double>& counts, Database* db)
00232 {
00233 Array<int> emptyGrounding;
00234 counts.clear();
00235
00236
00237 for (int i = 0; i < featureArray_.size(); i++) {
00238 for (int j = 0; j < featureArray_[i]->getNumWeights(); j++) {
00239 counts.append(topFeature_->getPartialDeriv(
00240 i, j, emptyGrounding, db));
00241 }
00242 }
00243 }
00244
00245 void getPseudoCounts(Array<double>& counts, Database* db,
00246 Array<Predicate*> queryPreds)
00247 {
00248
00249 }
00250
00251 void setWeight(int idx, double wt) {
00252
00253 for (int i = 0; i < featureArray_.size(); i++) {
00254 int numWeights = featureArray_[i]->getNumWeights();
00255 if (idx < numWeights) {
00256 featureArray_[i]->setWeight(idx, wt);
00257 return;
00258 }
00259 idx -= numWeights;
00260 }
00261
00262
00263
00264 assert(false);
00265 }
00266
00267 double getWeight(int idx) const {
00268
00269 for (int i = 0; i < featureArray_.size(); i++) {
00270 int numWeights = featureArray_[i]->getNumWeights();
00271 if (idx < numWeights) {
00272 return featureArray_[i]->getWeight(idx);
00273 }
00274 idx -= numWeights;
00275 }
00276
00277
00278
00279 assert(false);
00280 return 0;
00281 }
00282
00283 Feature* getRoot() {
00284 return topFeature_;
00285 }
00286
00287 Feature* getFeature(int idx) {
00288 assert(idx < featureArray_.size());
00289 return featureArray_[idx];
00290 }
00291
00292 const Feature* getFeature(int idx) const {
00293 assert(idx < featureArray_.size());
00294 return featureArray_[idx];
00295 }
00296
00297 Feature* getFeature(const char* name) {
00298 for (int i = 0; i < featureArray_.size(); i++) {
00299 if (!strcmp(featureArray_[i]->getName(), name)) {
00300 return featureArray_[i];
00301 }
00302 }
00303 return NULL;
00304 }
00305
00306 const Feature* getFeature(const char* name) const {
00307 for (int i = 0; i < featureArray_.size(); i++) {
00308 if (!strcmp(featureArray_[i]->getName(), name)) {
00309 return featureArray_[i];
00310 }
00311 }
00312 return NULL;
00313 }
00314
00315 int getNumFeatures() const {
00316 return featureArray_.size();
00317 }
00318
00319 int getNumWeights() const {
00320
00321 int numWeights = 0;
00322 for (int i = 0; i < featureArray_.size(); i++) {
00323 numWeights += featureArray_[i]->getNumWeights();
00324 }
00325 return numWeights;
00326 }
00327
00328
00329
00330
00331
00332
00333
00334 Array<Feature*> getFeatureArray() { return featureArray_; }
00335
00336 protected:
00337 Feature* topFeature_;
00338 int maxFeatureId_;
00339 Array<Feature*> featureArray_;
00340 };
00341
00342
00343 inline ostream& operator<<(ostream& out, const RRF* rrf)
00344 {
00345 out << "RRF {\n";
00346 for (int i = 0;
00347 i < rrf->getNumFeatures(); i++) {
00348 rrf->getFeature(i)->print(out);
00349 }
00350 out << "}\n";
00351 return out;
00352 }
00353
00354 #endif