rrf.cpp

00001 #include "rrf.h"
00002 #include "domain.h"
00003 
00004 RRF* RRF::generateTwoLevel(Domain* domain) 
00005 {
00006     RRF* rrf = new RRF();
00007 
00008     // Create predicate (bottom level) features
00009     // HACK -- ignore equality predicate, 0
00010     for (int pred = 1; pred < domain->getNumPredicates(); pred++) {
00011         PredicateFeature* predFeat = new PredicateFeature(
00012                 domain->getPredicateTemplate(pred));
00013         rrf->maxFeatureId_++;
00014         predFeat->setId(rrf->maxFeatureId_);
00015         rrf->featureArray_.append(predFeat);
00016     }
00017 
00018     // Create top-level feature
00019     rrf->topFeature_ = new RecursiveFeature("root", true, false);
00020     rrf->maxFeatureId_++;
00021     rrf->topFeature_->setId(rrf->maxFeatureId_);
00022     rrf->featureArray_.append(rrf->topFeature_);
00023 
00024     return rrf;
00025 }
00026 
00027 void RRF::addCompleteFeature(Array<int> typeArity, Domain* domain, 
00028         const Array<int>& queryPreds, int numChildren)
00029 {
00030     char clausalName[100];
00031     sprintf(clausalName, "f%d", ++maxFeatureId_);
00032     ClausalFeature* feature = new ClausalFeature(clausalName);
00033     //RecursiveFeature* feature = new RecursiveFeature(clausalName);
00034     feature->setId(maxFeatureId_);
00035     Array<Array<int> > termIndicesByType;
00036 
00037     // For each type...
00038     for (int i = 0; i < typeArity.size(); i++) {
00039 
00040         // Keep a list of all indices for terms of the current type
00041         Array<int> currTypeTermIndices;
00042         for (int j = 0; j < typeArity[i]; j++) {
00043 
00044             // Add a term of the appropriate type to the current
00045             // feature, and record its index in the array.
00046             feature->addTermType(i);
00047             currTypeTermIndices.append(feature->getNumTerms() - 1);
00048         }
00049 
00050         // Append this array to the list that includes all types
00051         termIndicesByType.append(currTypeTermIndices);
00052     }
00053 
00054     // Add all groundings of each predicate
00055     // HACK: start at 1 to avoid equality predicate
00056     for (int pred = 1; pred < domain->getNumPredicates(); pred++) {
00057 
00058         const PredicateTemplate* predTemp = 
00059                 domain->getPredicateTemplate(pred);
00060 
00061         // Construct iterator over all term mappings of 
00062         // this predicate.
00063         // DEBUG
00064         //cout << predTemp->getNumTerms() << endl;
00065         ArraysAccessor<int> mapIter;
00066         for (int i = 0; i < predTemp->getNumTerms(); i++) {
00067             int termType = predTemp->getTermTypeAsInt(i);
00068             mapIter.appendArray(&(termIndicesByType[termType]));
00069             //Array<int>* copiedArray 
00070             //    = new Array<int>(termIndicesByType[termType]);
00071             //mapIter.appendArray(copiedArray);
00072         }
00073 
00074         // Add the predicate feature with each possible mapping.
00075         // Assumption: feature 0 is the top level feature;
00076         //   next features are the base predicate features.
00077         Feature* predFeature = featureArray_[pred-1];
00078         Array<int> mapping;
00079         do {
00080             mapIter.getNextCombination(mapping);
00081             feature->addChild(predFeature, 0, mapping);
00082         } while (mapIter.hasNextCombination());
00083     }
00084 
00085     Array<int> mapping;
00086     for (int i = 0; i < feature->getNumTerms(); i++) {
00087         mapping.append(-i - 1);
00088     }
00089     ((RecursiveFeature*)topFeature_)->addChild(feature, 0.0, mapping);
00090 
00091 
00092     // Use random example to set weights
00093     Array<int> grounding;
00094     for (int i = 0; i < feature->getNumTerms(); i++) {
00095         int type = feature->getTermType(i);
00096         
00097         const Array<int>* constants = domain->getConstantsByType(type);
00098         grounding.append((*constants)[rand() % constants->size()]);
00099     }
00100     feature->setWeightsFromExample(grounding, domain->getDB(), queryPreds, numChildren);
00101     topFeature_->setWeight(topFeature_->getNumWeights() - 1, frand());
00102 
00103     featureArray_.append(feature);
00104 }
00105 
00106 void RRF::addRandomFeature(int numChildren, Array<int> queryPreds,
00107         Array<int> typeArity, Domain* domain)
00108 {
00109     ClausalFeature* feature = NULL;
00110 
00111     Array<Array<int> > termIndicesByType;
00112 
00113     bool queryPredChosen = false;
00114     while (!queryPredChosen) {
00115 
00116         feature = new ClausalFeature;
00117 
00118         // For each type...
00119         for (int i = 0; i < typeArity.size(); i++) {
00120 
00121             // Keep a list of all indices for terms of the current type
00122             Array<int> currTypeTermIndices;
00123             for (int j = 0; j < typeArity[i]; j++) {
00124 
00125                 // Add a term of the appropriate type to the current
00126                 // feature, and record its index in the array.
00127                 feature->addTermType(i);
00128                 currTypeTermIndices.append(feature->getNumTerms() - 1);
00129             }
00130 
00131             // Append this array to the list that includes all types
00132             termIndicesByType.append(currTypeTermIndices);
00133         }
00134 
00135         // Add the specified number of random predicates
00136         for (int i = 0; i < numChildren; i++) {
00137 
00138             int pred = (int)(rand() % (domain->getNumPredicates()-1)) + 1; 
00139             const PredicateTemplate* predTemp = 
00140                     domain->getPredicateTemplate(pred);
00141 
00142             // Add the feature with a random mapping
00143             Array<int> mapping;
00144             for (int i = 0; i < predTemp->getNumTerms(); i++) {
00145                 int termType = predTemp->getTermTypeAsInt(i);
00146                 int numTerms = termIndicesByType[termType].size();
00147                 mapping.append(termIndicesByType[termType][rand() % numTerms]);
00148             }
00149 
00150             feature->addChild(featureArray_[pred], 0.0, mapping);
00151 
00152             // Check to see if we have a query predicate
00153             if (queryPreds.find(pred) != -1) {
00154                 queryPredChosen = true;
00155             }
00156         }
00157     }
00158 
00159     feature->setId(++maxFeatureId_);
00160 
00161     Array<int> mapping;
00162     for (int i = 0; i < feature->getNumTerms(); i++) {
00163         mapping.append(-i - 1);
00164     }
00165     ((RecursiveFeature*)topFeature_)->addChild(feature, 0.0, mapping);
00166 
00167     // Use random example to set weights
00168     Array<int> grounding;
00169     for (int i = 0; i < feature->getNumTerms(); i++) {
00170         int type = feature->getTermType(i);
00171         
00172         const Array<int>* constants = domain->getConstantsByType(type);
00173         grounding.append((*constants)[rand() % constants->size()]);
00174     }
00175     feature->setWeightsFromExample(grounding, domain->getDB(), queryPreds, 
00176             numChildren);
00177     topFeature_->setWeight(topFeature_->getNumWeights() - 1, 1.0 + frand());
00178 
00179     featureArray_.append(feature);
00180 }
00181 
00182 double RRF::getLogPseudoLikelihood(Database* db, const Array<int>& queryPreds)
00183 {
00184     Array<Predicate*> allPreds;
00185     for (int i = 0; i < queryPreds.size(); i++) {
00186         Array<Predicate*> predArray;
00187         Predicate::createAllGroundings(queryPreds[i], db->getDomain(), 
00188                 predArray);
00189         allPreds.append(predArray);
00190     }
00191 
00192     return getLogPseudoLikelihood(db, allPreds);
00193 }
00194 
00195 
00196 double RRF::getLogPseudoLikelihood(Database* db, 
00197         const Array<Predicate*>& allPreds)
00198 {
00199 
00200     // LPL = log pseudo-likelihood, 
00201     //       the conditional probability of each query predicate 
00202     //       conditioned on all data.
00203     double lpl = 0.0;
00204 #if 1
00205     invalidateAll();
00206     double logTruthProb = getLogValue(db);
00207 #endif
00208 
00209     for (int i = 0; i < allPreds.size(); i++) {
00210 #if 1
00211         TruthValue currValue = db->getValue(allPreds[i]);
00212         assert(currValue != UNKNOWN);
00213         TruthValue newValue = (currValue == TRUE) ? FALSE : TRUE;
00214         db->setValue(allPreds[i], newValue);
00215         changedPredicate(allPreds[i], db);
00216         double logUntruthProb = getLogValue(db);
00217         db->setValue(allPreds[i], currValue);
00218         changedPredicate(allPreds[i], db);
00219 
00220         //cout << log(truthProb / (truthProb + untruthProb));
00221 
00222         if (fabs(logUntruthProb - logTruthProb) > 100) {
00223             lpl += logUntruthProb - logTruthProb;
00224         } else {
00225             lpl += -log(1.0 + exp(logUntruthProb - logTruthProb));
00226         }
00227 #else
00228         lpl += getPredicateLogLikelihood(allPreds[i], db);
00229 #endif
00230     }
00231 
00232     return lpl/allPreds.size();
00233 }
00234 
00235 
00236 double RRF::getExactZ(const Domain* domain, const Array<int>& queryPreds,
00237         Database* origDb)
00238 {
00239     ArraysAccessor<TruthValue> predValues;
00240     Array<TruthValue> truthValues;
00241     truthValues.append(TRUE);
00242     truthValues.append(FALSE);
00243 
00244     Array<Predicate*> allPreds;
00245     for (int i = 0; i < queryPreds.size(); i++) {
00246         Array<Predicate*> predArray;
00247         Predicate::createAllGroundings(queryPreds[i], domain, predArray);
00248         allPreds.append(predArray);
00249     }
00250 
00251     for (int i = 0; i < allPreds.size(); i++) {
00252         predValues.appendArray(&truthValues);
00253     }
00254 
00255     double Z = 0.0;
00256 
00257     // Create database, using closed-world assumption
00258     Array<bool> closedWorld;
00259     for (int i = 0; i < domain->getNumPredicates(); i++) {
00260         closedWorld.append(true);
00261     }
00262     Database* db = new Database(domain, closedWorld, true);
00263 
00264     // Fill in values from the original db
00265     if (origDb != NULL) {
00266         for (int i = 1; i < domain->getNumPredicates(); i++) {
00267             Array<Predicate*> predArray;
00268             Predicate::createAllGroundings(i, domain, predArray);
00269             for (int j = 0; j < predArray.size(); j++) {
00270                 db->setValue(predArray[j], origDb->getValue(predArray[j]));
00271             }
00272             predArray.deleteItemsAndClear();
00273         }
00274     }
00275 
00276     // Sum value for all worlds (keeping the evidence pred values the same)
00277     do {
00278         Array<TruthValue> truthValues;
00279         predValues.getNextCombination(truthValues);
00280         for (int i = 0; i < truthValues.size(); i++) {
00281             db->setValue(allPreds[i], truthValues[i]);
00282         }
00283 
00284         Array<int> emptyGrounding;
00285         Z += topFeature_->getValue(emptyGrounding, db);
00286     } while (predValues.hasNextCombination());
00287 
00288     allPreds.deleteItemsAndClear();
00289 
00290     delete db;
00291 
00292     return Z;
00293 }
00294 
00295 double RRF::getExactZ(const Domain* domain) 
00296 {
00297     Array<int> allPreds;
00298     for (int i = 1; i < domain->getNumPredicates(); i++) {
00299         allPreds.append(i);
00300     }
00301     return getExactZ(domain, allPreds, NULL);
00302 }
00303 
// True if c may start a name in the RRF text format: a letter (per isalpha
// in the current locale) or '_'. The character is converted to unsigned char
// first, because passing a negative char value (e.g. a non-ASCII byte) to
// isalpha is undefined behavior.
inline bool isnamebegin(char c) {
    unsigned char uc = static_cast<unsigned char>(c);
    return (isalpha(uc) || c == '_');
}
00307 
// True if c may appear inside a name in the RRF text format: a letter or
// digit (per isalnum in the current locale) or '_'. The character is
// converted to unsigned char first, because passing a negative char value
// to isalnum is undefined behavior.
inline bool isname(char c) {
    unsigned char uc = static_cast<unsigned char>(c);
    return (isalnum(uc) || c == '_');
}
00311 
00312 char* nextChild(char* buf, const Array<char*>& vars, 
00313         double& weight, string& name, Array<int>& grounding)
00314 {
00315     char* pbuf;
00316     bool isConstant = false;
00317 
00318     // Read weight
00319     while (isspace(*buf)) { buf++; }
00320     pbuf = buf;
00321     while (!isspace(*pbuf) && *pbuf != ')' && *pbuf) { pbuf++; }
00322     if (*pbuf == ')') {
00323         isConstant = true;
00324     } 
00325     *pbuf = '\0';
00326     weight = atof(buf);
00327     buf = pbuf + 1;
00328 
00329 #if PARSE_DEBUG
00330     cout << "Read weight: " << weight << endl;
00331 #endif
00332 
00333     while (!isConstant && isspace(*buf)) { buf++; }
00334     if (*buf == ')') {
00335         isConstant = true;
00336     }
00337     if (!isnamebegin(*buf)) {
00338         isConstant = true;
00339     }
00340 
00341     if (!isConstant) {
00342         // Read name
00343         char* pbuf = buf;
00344         while (isname(*pbuf)) { pbuf++; }
00345         if (*pbuf != '(') {
00346             cout << "ERROR: expected '('; found '" << *pbuf << "'.\n";
00347             return NULL;
00348         }
00349         // Mark the end of the name with a NULL and save it to a string
00350         *pbuf = '\0';
00351         name = buf;
00352         buf = pbuf + 1;
00353 
00354 #if PARSE_DEBUG
00355         cout << "Read name: " << name << endl;
00356 #endif
00357 
00358         // Read arguments.  No spaces allowed in between!
00359         while (isspace(*buf)) { buf++; }
00360         if (*buf != ')') {
00361             pbuf = buf;
00362             bool argsDone = false;
00363             while (!argsDone) {
00364                 while (isname(*pbuf)) { pbuf++; }
00365                 if (*pbuf != ',' && *pbuf != ')') {
00366                     cout << "ERROR: expected ','; found '" << *pbuf << "'.\n";
00367                     return NULL;
00368                 }
00369 
00370                 if (*pbuf == ')') {
00371                     argsDone = true;
00372                 }
00373 
00374                 // Mark the end of the argument with a NULL
00375                 *pbuf = '\0';
00376                 int varIndex = -1;
00377                 for (int i = 0; i < vars.size(); i++) {
00378                     if (!strcmp(buf, vars[i])) {
00379                         varIndex = i;
00380                         break;
00381                     }
00382                 }
00383                 grounding.append(varIndex);
00384 
00385                 // Advance past the NULL, to the next argument
00386                 pbuf++;
00387                 buf = pbuf;
00388             }
00389         } else {
00390             // Advance past ')'
00391             buf++;
00392         }
00393     } else {
00394         name = "";
00395     }
00396 
00397     // Advance to next child, so we're ready to parse it.
00398     // Return NULL if there are no more children.
00399     while (isspace(*buf)) { buf++; }
00400     if (*buf == '+') {
00401         return ++buf;
00402     } else if (*buf == ')' || *buf == '\n' || *buf == '\r' || *buf == '\0') {
00403         return NULL;
00404     } else {
00405         // TODO: distinguish between NULL and error?
00406         cout << "ERROR: read \'" << *buf << "\'; expected '+' or ')'.\n";
00407         return NULL;
00408     }
00409 }
00410 
00411 
00412 // Utility function
00413 void readFeature(Array<string>& types, Array<double>& weights, 
00414         Array<string>& children, Array<Array<int> >& mapping, 
00415         string& name, char* buf)
00416 {
00417     types.clear();
00418     weights.clear();
00419     children.clear();
00420     mapping.clear();
00421 
00422     // ASSUMPTION: no spaces in between arguments
00423     char* pbuf;
00424 
00425     // Get feature name
00426     while (isspace(*buf)) { buf++; }
00427     pbuf = buf;
00428     while (isname(*pbuf)) { pbuf++; }
00429     if (*pbuf != '(') { 
00430         cout << "ERROR: expected ')'; read '" << *pbuf << "'.\n";
00431         return;
00432     }
00433     *pbuf = '\0';
00434     name = buf;
00435 
00436 #if PARSE_DEBUG
00437     cout << "Read feature name: " << name << endl;
00438 #endif
00439 
00440     // Read in feature arguments
00441     pbuf++;
00442     buf = pbuf;
00443     Array<char*> vars;
00444     bool done = false;
00445     while (isspace(*pbuf)) { ++pbuf; }
00446     if (*pbuf == ')') {
00447         done = true;
00448     }
00449 
00450     // Read name/type pairs
00451     while (*buf && !done) {
00452 
00453         char* p2buf;
00454 
00455         //
00456         // READ NAME
00457         //
00458         while (isname(*pbuf))  { ++pbuf; }
00459 
00460         // Read spaces, ':', and more spaces
00461         p2buf = pbuf;
00462         while (isspace(*p2buf)) { ++p2buf; }
00463         if (*p2buf != ':') {
00464             cout << "ERROR: expected ':'; read '" << *p2buf << "'\n";
00465             return;
00466         }
00467         p2buf++;
00468         while (isspace(*p2buf)) { ++p2buf; }
00469 
00470         // Save name in array and update buffer pointers
00471         *pbuf = '\0';
00472         vars.append(buf);
00473         pbuf = buf = p2buf;
00474 
00475         //
00476         // READ TYPE
00477         //
00478         while(isname(*pbuf))  { ++pbuf; }
00479 
00480         // Read spaces, ',' or ')', and more spaces
00481         p2buf = pbuf;
00482         while(isspace(*p2buf)) { ++p2buf; }
00483         if ((*p2buf != ',') && (*p2buf != ')')) {
00484             cout << "ERROR: expected ',' or ')'; read '" << *p2buf << "'.\n";
00485             return;
00486         }
00487         if (*p2buf == ')') {
00488             done = true;
00489         }
00490         p2buf++;
00491         while (isspace(*p2buf)) { ++p2buf; }
00492 
00493         // Save type in array and update buffer pointers
00494         *pbuf = '\0';
00495         types.append(buf);
00496         pbuf = buf = p2buf;
00497     }
00498 
00499     // Skip the " = exp(" part
00500     while (*buf && *buf != '(') { buf++; }
00501     buf++;
00502 
00503     while (buf) {
00504 
00505         double weight;
00506         Array<int> grounding;
00507         string childName;
00508 
00509         // Read the next child
00510         buf = nextChild(buf, vars, weight, childName, grounding);
00511 
00512         weights.append(weight);
00513         children.append(childName);
00514         mapping.append(grounding);
00515     }
00516 }
00517 
00518 
00519 
00520 void RRF::load(istream& in, Domain* domain)
00521 {
00522 #if PARSE_DEBUG
00523     cout << "Loading RRF from file...\n";
00524 #endif
00525 
00526     // HACK -- this is fragile, and presumes a very rigid input format.
00527     // TODO -- fix and generalize.
00528 
00529     // Create predicate (bottom level) features
00530     // HACK -- ignore equality predicate, 0
00531     for (int pred = 1; pred < domain->getNumPredicates(); pred++) {
00532         PredicateFeature* predFeat = new PredicateFeature(
00533                 domain->getPredicateTemplate(pred));
00534         maxFeatureId_++;
00535         predFeat->setId(maxFeatureId_);
00536         featureArray_.append(predFeat);
00537     }
00538 
00539     // Create a single constant feature (only one is needed)
00540     {
00541         ConstantFeature* conFeature = new ConstantFeature(1.0);
00542         maxFeatureId_++;
00543         conFeature->setId(maxFeatureId_);
00544         featureArray_.append(conFeature);
00545     }
00546 
00547 #if PARSE_DEBUG
00548     cout << "Created predicate features.\n";
00549 #endif
00550     int firstNonPredFeature = maxFeatureId_+1;
00551 
00552 
00553     char buffer[10240];
00554     char* buf = buffer;
00555 
00556     // Skip first line, which should be "RRF {\n"
00557     in.getline(buf, 10240);
00558 
00559     // Keep track of children, so feature can refer to later features
00560     // as necessary, rather than being defined in a specific order.
00561     Array<Array<string> > allChildren;
00562     Array<Array<double> > allWeights;
00563     Array<Array<Array<int> > > allGroundings;
00564 
00565     // Each of the following lines should contain a feature.
00566     // Read in and set all feature weights.
00567     while (1) {
00568 
00569         Array<string> terms;
00570         Array<double> weights;
00571         Array<string> children;
00572         Array<Array<int> > groundings;
00573         string name;
00574 
00575         in.getline(buf, 10240);
00576 
00577         // Stop if we hit a closing '}'
00578         while (isspace(*buf))     { buf++; }
00579         if (buf[0] == '}' || !in) { break; }
00580 
00581         readFeature(terms, weights, children, groundings, name, buf);
00582         RecursiveFeature* feat;
00583         if (!name.compare("root") || !name.compare("f0")) {
00584             feat = new RecursiveFeature(name.c_str(), true, false);
00585             topFeature_ = feat;
00586         } else {
00587             feat = new RecursiveFeature(name.c_str(), false, true);
00588             //feat = new ClausalFeature(name.c_str());
00589         }
00590         feat->setId(++maxFeatureId_);
00591         featureArray_.append(feat);
00592         for (int j = 0; j < terms.size(); j++) {
00593             int type;
00594             type = atoi(terms[j].c_str());
00595             if (type < 1) {
00596                 type = domain->getTypeId(terms[j].c_str());
00597             }
00598             if (type < 1) {
00599                 cout << "ERROR: Unknown term type '" << terms[j] 
00600                     << "' for feature '" << name << "'.\n";
00601                 type = 1;
00602             }
00603             feat->addTermType(type);
00604         }
00605 
00606 #if PARSE_DEBUG
00607         cout << "Read feature: " << name << endl;
00608         cout << "Terms: ";
00609         for (int j = 0; j < terms.size(); j++) {
00610             cout << terms[j] << endl;
00611         }
00612 #endif
00613 
00614         // We have to postpone this, because the referenced features may not
00615         // have been defined yet.
00616         allChildren.append(children);
00617         allWeights.append(weights);
00618         allGroundings.append(groundings);
00619     }
00620 
00621     // After reading in all features, construct parent/child relationships
00622     // TODO: check for cycles!
00623     for (int i = 0; i < getNumFeatures() - firstNonPredFeature; i++) {
00624         for (int j = 0; j < allChildren[i].size(); j++) {
00625             Feature* child = getFeature(allChildren[i][j].c_str());
00626             if (child == NULL) {
00627                 cout << "ERROR: unknown feature " << allChildren[i][j] <<
00628                     " referenced by feature " << 
00629                     getFeature(i + firstNonPredFeature)->getName() << ".\n";
00630             } 
00631             RecursiveFeature* parent = (RecursiveFeature*)getFeature(
00632                     i + firstNonPredFeature);
00633             if (parent == NULL) {
00634                 cout << "ERROR: could not find index " 
00635                     << (i + firstNonPredFeature) << " feature!\n";
00636             }
00637             parent->addChild(child, allWeights[i][j], allGroundings[i][j]);
00638         }
00639     }
00640 
00641     // DEBUG
00642     cout << this << endl;
00643     return;
00644 }

Generated on Sun Jun 7 11:55:20 2009 for Alchemy by  doxygen 1.5.1