testrrf.cpp

00001 #include "mln.h"
00002 #include "fol.h"
00003 #include "rrf.h"
00004 #include "random.h"
00005 #include <math.h>
00006 #include "arguments.h"
00007 #include "infer.h"  // For extractPredNames()
00008 #include "timer.h"
00009 #include "gfeature.h"
00010 #include "lbfgsr.h"
00011 
00012 #define BURN_IN 100
00013 //#define BURN_IN 10
00014 
00015 // DEBUG
00016 int llcount = 0;
00017 
00018 //#define MAX_TRAIN_ITERS 10000
00019 #define CMP_RRF 0
00020 
00021 enum inferenceMethod { INF_ICM, INF_GIBBS, INF_EXACT, INF_PSEUDO, INF_MAX };
00022 
00023 char* afilename = NULL;
00024 char* nonEvidPredsStr = "all";
00025 char* aInferenceMethodStr = "p";
00026 int   aGibbsIters = 100;
00027 int   nf = 20;
00028 int   infMethod = INF_PSEUDO;
00029 double aAlpha = 0.1;
00030 bool aRec = false;
00031 bool aGround = false;
00032 int  aNumFeatures = 10;
00033 int  aPredsPerFeature = 5;
00034 double aSigmaSq = 100;
00035 char* atestConstant = NULL;
00036 char* amodelFilename = NULL;
00037 char* aoutputFilename = NULL;
00038 int aMaxTrainIters = 10000;
00039 bool aVerbose = false;
00040 int  aSeed = 1;
00041 bool aTrainTopLevel = false;
00042 bool aTrainBottomLevel = false;
00043 bool aInfer = false;
00044 bool aLPL = false;
00045 double aSamplingFrac = 1.0;
00046 bool aPseudoFast = false;
00047 
00048 
00049 ARGS ARGS::Args[] =
00050 {
00051     ARGS("i", ARGS::Req, afilename, "input .mln file"),
00052     ARGS("weights", ARGS::Opt, amodelFilename,
00053             "Filename containing initial weights\n"),
00054     ARGS("r", ARGS::Opt, aoutputFilename,
00055             "File for saving learned model\n"),
00056     ARGS("ne", ARGS::Opt, nonEvidPredsStr, 
00057        //"(for discriminative learning only) "
00058        "first-order non-evidence predicates (comma-separated with no space)"),
00059     ARGS("inf", ARGS::Opt, aInferenceMethodStr, "Inference method to use\n"
00060             "i = ICM; g = Gibbs sampling; p = pseudo-likelihood; e = Exact\n"),
00061     ARGS("frac", ARGS::Opt, aSamplingFrac, 
00062            "Fraction of ground predicates to consider in pseudo-likelihood\n"),
00063     ARGS("iters", ARGS::Opt, aGibbsIters, 
00064             "Number of iterations to use for Gibbs sampling\n"),
00065     ARGS("nf", ARGS::Opt, aNumFeatures, "Number of features to use\n"),
00066     ARGS("alpha", ARGS::Opt, aAlpha, "Learning rate\n"),
00067     ARGS("sigmasq", ARGS::Opt, aSigmaSq, "Variance of the Gaussian prior on weights (penalizes large weights)\n"),
00068     ARGS("rec", ARGS::Opt, aRec, "Compute counts recursively\n"),
00069     ARGS("ppf", ARGS::Opt, aPredsPerFeature, 
00070             "Number of ground predicates per feature\n"),
00071     ARGS("ground", ARGS::Opt, aGround, 
00072             "Compute counts using fully ground tree\n"),
00073     ARGS("test", ARGS::Opt, atestConstant,
00074             "Test constant, whose query preds are unknown\n"),
00075     ARGS("trainiters", ARGS::Opt, aMaxTrainIters,
00076             "Maximum number of training iterations\n"),
00077     ARGS("v", ARGS::Tog, aVerbose,
00078             "Verbose output\n"),
00079     ARGS("seed", ARGS::Opt, aSeed, "Random seed\n"),
00080     ARGS("top", ARGS::Tog, aTrainTopLevel,
00081             "Only train top-level weights (feature coefficients)\n"),
00082     ARGS("bottom", ARGS::Tog, aTrainBottomLevel,
00083             "Only train lower-level (within-feature) weights\n"),
00084     ARGS("infer", ARGS::Tog, aInfer, 
00085             "Run inference and produce per-constant probabilities\n"),
00086     ARGS("lpl", ARGS::Tog, aLPL, 
00087             "Produce the pseudo-log-likelihood of the query predicates\n"),
00088     ARGS("f", ARGS::Tog, aPseudoFast,
00089       "Use optimized (and broken) pseudo-likelihood gradient computation.\n"),
00090     ARGS()
00091 };
00092 
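// saveState()/restoreState() stash the current truth values of the given
// query predicates in a single global array and write them back later.
// Because the buffer is global, nested save/restore calls would clobber
// each other; it is only used for simple checkpointing below.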
00093 // HACKISH
00094 Array<TruthValue> truthValues;
00095 void saveState(Database* db, Array<Predicate*>& queryPreds)
00096 {
00097     truthValues.clear();
00098     for (int i = 0; i < queryPreds.size(); i++) {
00099         truthValues.append(db->getValue(queryPreds[i]));
00100     }
00101 }
00102 
00103 void restoreState(Database* db, Array<Predicate*>& queryPreds)
00104 {
00105     for (int i = 0; i < queryPreds.size(); i++) {
00106         db->setValue(queryPreds[i], truthValues[i]);
00107     }
00108 }
00109 
00110 
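// Iterated conditional modes (ICM) on the recursive RRF: start from a random
// assignment to the query predicates, then repeatedly flip each one, keeping
// a flip only if it increases the model's log value, until a full pass makes
// no improvement.  This converges to a local maximum, not necessarily the
// global MAP state.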
00111 void runICMInference(RRF* rrf, Database* db, Array<Predicate*>& queryPreds)
00112 {
00113     // Initialize randomly
00114     for (int i = 0; i < queryPreds.size(); i++) {
00115         TruthValue tv = (frand() < 0.5) ? TRUE : FALSE;
00116         db->setValue(queryPreds[i], tv);
00117     }
00118     rrf->invalidateAll();
00119 
00120     // Toggle predicates one by one, until convergence
00121     bool changed = true;
00122     while (changed) {
00123         changed = false;
00124 
00125         double currLikelihood = rrf->getLogValue(db);
00126         for (int i = 0; i < queryPreds.size(); i++) {
00127 
00128             // Try toggling the value.
00129             //cout << *queryPreds[i] << endl;
00130             TruthValue oldValue = db->getValue(queryPreds[i]);
00131             TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00132 
00133             db->setValue(queryPreds[i], newValue);
00134             rrf->changedPredicate(queryPreds[i], db);
00135             double newLikelihood = rrf->getLogValue(db);
00136             if (newLikelihood > currLikelihood) {
00137                 // If it yields higher likelihood, keep it.
00138                 currLikelihood = newLikelihood;
00139                 changed = true;
00140             } else {
00141                 // Otherwise, revert the change.
00142                 db->setValue(queryPreds[i], oldValue);
00143                 rrf->changedPredicate(queryPreds[i], db);
00144             }
00145         }
00146     }
00147 }
00148 
00149 
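// ICM on the fully ground RRF.  Ground atoms are addressed by (predicate id,
// grounding index); the (i-1) offset follows the convention used throughout
// this file that predicate 0 is equality and is skipped.  When CMP_RRF is
// enabled, every move is mirrored in the database so the ground and
// recursive log values can be cross-checked.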
00150 void runICMInference(GroundRRF* grrf, const Array<int>& queryPreds,
00151         Database* db, RRF* rrf)
00152 {
00153     // Initialize randomly
00154     for (int q = 0; q < queryPreds.size(); q++) {
00155 #if CMP_RRF
00156         Array<Predicate*> queryGroundings;
00157         Predicate::createAllGroundings(queryPreds[q], db->getDomain(), 
00158                 queryGroundings);
00159 #endif
00160         for (int j = 0; j < grrf->getNumGroundings(queryPreds[q]-1); j++) {
00161             bool tv = (frand() < 0.5);
00162             grrf->setPredicateValue(queryPreds[q], j, tv);
00163 #if CMP_RRF
00164             db->setValue(queryGroundings[j], tv ? TRUE : FALSE);
00165 #endif
00166         }
00167     }
00168 
00169     // Toggle predicates one by one, until convergence
00170     bool changed = true;
00171     while (changed) {
00172         changed = false;
00173 
00174         double currLikelihood = grrf->getLogValue();
00175 #if CMP_RRF
00176         double rrfLikelihood = rrf->getLogValue(db);
00177         if (fabs(rrfLikelihood - currLikelihood) > 0.00001) {
00178             cout << "Likelihoods differ!\n";
00179         }
00180 #endif
00181         for (int q = 0; q < queryPreds.size(); q++) {
00182 
00183             int i = queryPreds[q];
00184 #if CMP_RRF 
00185             Array<Predicate*> queryGroundings;
00186             Predicate::createAllGroundings(queryPreds[q], db->getDomain(), 
00187                     queryGroundings);
00188 #endif
00189             for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00190 
00191                 // Try toggling the value.
00192                 bool oldValue = grrf->getPredicateValue(i, j);
00193                 grrf->setPredicateAndUpdate(i, j, !oldValue);
00194                 //grrf->setPredicateValue(i, j, !oldValue);
00195 
00196                 double newLikelihood = grrf->getLogValue();
00197 #if CMP_RRF 
00198                 db->setValue(queryGroundings[j], oldValue ? FALSE : TRUE);
00199                 double rrfLikelihood = rrf->getLogValue(db);
00200                 if (fabs(rrfLikelihood - newLikelihood) > 0.00001) {
00201                     cout << "Toggled likelihoods differ!\n";
00202                 }
00203 #endif
00204                 if (newLikelihood > currLikelihood) {
00205                     // If it yields higher likelihood, keep it.
00206                     currLikelihood = newLikelihood;
00207                     changed = true;
00208                 } else {
00209                     // Otherwise, revert the change.
00210                     grrf->setPredicateAndUpdate(i, j, oldValue);
00211                     //grrf->setPredicateValue(i, j, oldValue);
00212 #if CMP_RRF
00213                     db->setValue(queryGroundings[j], oldValue ? TRUE : FALSE);
00214 #endif
00215                 }
00216             }
00217         }
00218     }
00219 }
00220 
00221 
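// Exact expected counts: enumerate every truth assignment to the query
// predicates, weight each world's counts by its (unnormalized) value, and
// divide by the accumulated partition function Z.  Exponential in the number
// of query predicates, so only usable for very small problems.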
00222 void getExactCounts(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00223         Array<double>& counts)
00224 {
00225     // Start with no counts
00226     rrf->getCounts(counts, db);
00227     for (int i = 0; i < counts.size(); i++) {
00228         counts[i] = 0.0;
00229     }
00230 
00231     // The code below was cut and pasted from RRF::getExactZ()
00232     ArraysAccessor<TruthValue> predValues;
00233     Array<TruthValue> truthValues;
00234     truthValues.append(TRUE);
00235     truthValues.append(FALSE);
00236 
00237     for (int i = 0; i < queryPreds.size(); i++) {
00238         predValues.appendArray(&truthValues);
00239     }
00240 
00241     double Z = 0.0;
00242 
00243     // Sum value for all worlds (keeping the evidence pred values the same)
00244     do {
00245         Array<TruthValue> truthValues;
00246         predValues.getNextCombination(truthValues);
00247         for (int i = 0; i < truthValues.size(); i++) {
00248             db->setValue(queryPreds[i], truthValues[i]);
00249         }
00250 
00251         double likelihood = rrf->getValue(db);
00252         Z += likelihood;
00253 
00254         // Collect counts
00255         Array<double> newCounts;
00256         rrf->getCounts(newCounts, db);
00257         for (int j = 0; j < newCounts.size(); j++) {
00258             counts[j] += likelihood * newCounts[j];
00259         }
00260 
00261         //cout << "truthValues.size() = " << truthValues.size() << endl;
00262         //cout << "Likelihood " << index << " = " << likelihood << endl;
00263     } while (predValues.hasNextCombination());
00264 
00265     // Normalize
00266     for (int j = 0; j < counts.size(); j++) {
00267         counts[j] /= Z;
00268     }
00269 }
00270 
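// Expected counts estimated by Gibbs sampling on the recursive RRF: each
// sweep proposes flipping every query predicate and accepts with probability
// P(new)/(P(old)+P(new)); feature counts are accumulated after every sweep
// and averaged over aGibbsIters+1 samples (including the initial state).
// Note that this works with raw (non-log) values, so it can underflow on
// large models.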
00271 void getGibbsCounts(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00272         Array<double>& counts)
00273 {
00274     // Get initial likelihood and counts
00275     double oldLikelihood = rrf->getValue(db);
00276     rrf->getCounts(counts, db);
00277 
00278     // TODO -- optimize...
00279     for (int iter = 0; iter < aGibbsIters; iter++) {
00280         for (int i = 0; i < queryPreds.size(); i++) {
00281 
00282             // Sample i'th predicate, conditioned on the rest
00283             TruthValue oldValue = db->getValue(queryPreds[i]);
00284             TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00285 
00286             db->setValue(queryPreds[i], newValue);
00287             rrf->changedPredicate(queryPreds[i], db);
00288             double newLikelihood = rrf->getValue(db);
00289             double prob = newLikelihood/(oldLikelihood + newLikelihood);
00290             if (frand() < prob) {
00291                 oldLikelihood = newLikelihood;
00292             } else {
00293                 db->setValue(queryPreds[i], oldValue);
00294                 rrf->changedPredicate(queryPreds[i], db);
00295             }
00296         }
00297 
00298         // Collect counts
00299         Array<double> newCounts;
00300         rrf->getCounts(newCounts, db);
00301         for (int j = 0; j < newCounts.size(); j++) {
00302             counts[j] += newCounts[j];
00303         }
00304     }
00305 
00306     // Normalize
00307     for (int j = 0; j < counts.size(); j++) {
00308         counts[j] /= ((double)aGibbsIters+1.0);
00309     }
00310 }
00311 
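// Gibbs sampling that tracks per-predicate marginals rather than feature
// counts: counts[i] is the number of sampled worlds (plus a 0.5 smoothing
// pseudocount) in which query predicate i is true, normalized by
// aGibbsIters+1.  No burn-in is performed here.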
00312 void runGibbs(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00313         Array<double>& counts)
00314 {
00315     // Get initial likelihood and counts
00316     double oldLikelihood = rrf->getValue(db);
00317     counts.clear();
00318 
00319     // TODO -- optimize...
00320     for (int iter = 0; iter < aGibbsIters; iter++) {
00321         for (int i = 0; i < queryPreds.size(); i++) {
00322 
00323             // Sample i'th predicate, conditioned on the rest
00324             TruthValue oldValue = db->getValue(queryPreds[i]);
00325             TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00326             TruthValue finalValue;
00327 
00328             db->setValue(queryPreds[i], newValue);
00329             rrf->changedPredicate(queryPreds[i], db);
00330             double newLikelihood = rrf->getValue(db);
00331             double prob = newLikelihood/(oldLikelihood + newLikelihood);
00332             if (frand() < prob) {
00333                 oldLikelihood = newLikelihood;
00334                 finalValue = newValue;
00335             } else {
00336                 db->setValue(queryPreds[i], oldValue);
00337                 rrf->changedPredicate(queryPreds[i], db);
00338                 finalValue = oldValue;
00339             }
00340 
00341             // Keep track of counts
00342             if (counts.size() <= i) {
00343                 counts.append(0.5);
00344             }
00345             if (finalValue == TRUE) {
00346                 counts[i]++;
00347             }
00348         }
00349     }
00350 
00351     // Normalize
00352     for (int j = 0; j < counts.size(); j++) {
00353         counts[j] /= ((double)aGibbsIters+1.0);
00354     }
00355 }
00356 
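// Estimates the conditional log-likelihood of the query predicates' true
// values: initialize with ICM, run Gibbs with BURN_IN discarded sweeps and
// aGibbsIters counting sweeps (log-space acceptance probability
// 1/(1+exp(oldLL-newLL)), with overflow guards), estimate each predicate's
// marginal from the smoothed counts, and sum log P(true value) over all
// query predicates.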
00357 double getGibbsLogLikelihood(RRF* rrf, Database* db, 
00358         Array<Predicate*>& queryPreds)
00359 {
00360     Array<double> counts;
00361 
00362     Array<TruthValue> trueValues;
00363 
00364     for (int i = 0; i < queryPreds.size(); i++) {
00365         trueValues.append(db->getValue(queryPreds[i]));
00366         //db->setValue(queryPreds[i], (frand() > 0.5) ? TRUE : FALSE);
00367         counts.append(0.1);
00368     }
00369     double total = 0.2;
00370 
00371     // Initialize using ICM
00372     runICMInference(rrf, db, queryPreds);
00373 
00374 
00375     rrf->invalidateAll();
00376     double oldLikelihood = rrf->getLogValue(db);
00377 
00378     for (int iter = 0; iter < aGibbsIters + BURN_IN; iter++) {
00379         for (int i = 0; i < queryPreds.size(); i++) {
00380 
00381             // Sample i'th predicate, conditioned on the rest
00382             TruthValue oldValue = db->getValue(queryPreds[i]);
00383             TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00384             TruthValue finalValue;
00385             
00386             db->setValue(queryPreds[i], newValue);
00387             rrf->changedPredicate(queryPreds[i], db);
00388             double newLikelihood = rrf->getLogValue(db);
00389             double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00390             if (oldLikelihood - newLikelihood > 100) {
00391                 prob = 0.0;
00392             } else if (oldLikelihood - newLikelihood < -100) {
00393                 prob = 1.0;
00394             }
00395             if (frand() < prob) {
00396                 oldLikelihood = newLikelihood;
00397                 finalValue = newValue;
00398             } else {
00399                 db->setValue(queryPreds[i], oldValue);
00400                 rrf->changedPredicate(queryPreds[i], db);
00401                 finalValue = oldValue;
00402             }
00403 
00404             if (iter >= BURN_IN && finalValue) {
00405                 counts[i]++;
00406             }
00407         }
00408 
00409         if (iter >= BURN_IN) {
00410             total++;
00411         }
00412     }
00413 
00414     double ll = 0.0;
00415 
00416     // Normalize
00417     for (int j = 0; j < counts.size(); j++) {
00418 
00419         if (trueValues[j]) {
00420             ll += log(counts[j]/total);
00421         } else {
00422             ll += log(1.0 - counts[j]/total);
00423         }
00424 
00425         cout << *(queryPreds[j]) << " = " << trueValues[j] 
00426             << " (" << (counts[j]/total) << ")\n";
00427     }
00428 
00429     //cout << right/(right + wrong) << endl;
00430     //cout << naive/(right + wrong) << endl;
00431 
00432     return ll;
00433 }
00434 
00435 
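// Same idea as getGibbsLogLikelihood(), but on the ground RRF: marginals are
// pooled over 10 independent trials, each of which builds a fresh GroundRRF,
// keeps the best of 10 ICM restarts as the starting state, and then runs
// Gibbs with burn-in.  Also prints each query predicate's true value,
// estimated marginal, and per-predicate log-likelihood.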
00436 double getGibbsLogLikelihood2(RRF* rrf, Database* db, 
00437         Array<Predicate*>& queryPreds)
00438 {
00439     Array<int> queryPredId;
00440     Array<int> queryGroundId;
00441 
00442     Array<double> counts;
00443     Array<TruthValue> trueValues;
00444 
00445     double ll = 0.0;
00446 
00447     for (int i = 0; i < queryPreds.size(); i++) {
00448         trueValues.append(db->getValue(queryPreds[i]));
00449 
00450         // Store identity of query predicates
00451         queryPredId.append(queryPreds[i]->getId());
00452         Array<int> grounding;
00453         for (int j = 0; j < queryPreds[i]->getNumTerms(); j++) {
00454             grounding.append(queryPreds[i]->getTerm(j)->getId());
00455         }
00456         queryGroundId.append(rrf->getFeature(queryPredId[i]-1)->
00457                 getGroundingIndex(grounding, db));
00458 
00459         counts.append(0.1);
00460     }
00461     double total = 0.2;
00462 
00463 
00464     for (int trial= 0; trial < 10; trial++) {
00465 
00466         Array<int> qpreds;
00467         qpreds.append(queryPreds[0]->getId());
00468         GroundRRF* grrf = new GroundRRF(rrf, db);
00469         grrf->dirtyAll();
00470 
00471         double best_ll = -1.0e100;
00472         for (int init = 0; init < 10; init++) {
00473             runICMInference(grrf, qpreds, db, rrf);
00474             if (grrf->getLogValue() > best_ll) {
00475                 best_ll = grrf->getLogValue();
00476                 saveState(db, queryPreds);
00477             }
00478         }
00479         restoreState(db, queryPreds);
00480 
00481         double oldLikelihood = grrf->getLogValue();
00482         for (int iter = 0; iter < aGibbsIters + BURN_IN; iter++) {
00483             for (int i = 0; i < queryPreds.size(); i++) {
00484 
00485                 // Sample i'th predicate, conditioned on the rest
00486                 bool oldValue = grrf->getPredicateValue(
00487                         queryPredId[i], queryGroundId[i]);
00488                 bool finalValue;
00489 
00490                 grrf->setPredicateAndUpdate(queryPredId[i], queryGroundId[i], 
00491                         !oldValue);
00492                 double newLikelihood = grrf->getLogValue();
00493                 //double prob = newLikelihood/(oldLikelihood + newLikelihood);
00494                 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00495                 if (oldLikelihood - newLikelihood > 100) {
00496                     prob = 0.0;
00497                 } else if (oldLikelihood - newLikelihood < -100) {
00498                     prob = 1.0;
00499                 }
00500                 //cout << prob << endl;
00501                 if (frand() < prob) {
00502                     oldLikelihood = newLikelihood;
00503                     finalValue = !oldValue;
00504                 } else {
00505                     grrf->setPredicateAndUpdate(queryPredId[i], 
00506                             queryGroundId[i], oldValue);
00507                     finalValue = oldValue;
00508                 }
00509 
00510                 if (iter >= BURN_IN && finalValue) {
00511                     counts[i]++;
00512                 }
00513             }
00514 
00515             if (iter >= BURN_IN) {
00516                 total++;
00517             }
00518         }
00519     }
00520 
00521     // Normalize
00522     for (int j = 0; j < counts.size(); j++) {
00523 
00524         double curr_ll;
00525         if (trueValues[j]) {
00526             curr_ll = log(counts[j]/total);
00527         } else {
00528             curr_ll = log(1.0 - counts[j]/total);
00529         }
00530         ll += curr_ll;
00531 
00532         queryPreds[j]->printWithStrVar(cout, db->getDomain());
00533         cout << " = " << trueValues[j] 
00534             << " (" << (counts[j]/total) << ")  " << curr_ll << endl;
00535 
00536     }
00537 
00538     return ll;
00539 }
00540 
00541 
00542 #define USE_LOGS 1
00543 
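// Expected counts by Gibbs sampling on the ground RRF.  USE_LOGS selects
// between log-space and raw-value acceptance probabilities.  Predicate 0
// (equality) is skipped, hence getNumGroundings(i-1).  With CMP_RRF enabled,
// both the likelihoods and the counts are cross-checked against the
// recursive RRF after every sweep.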
00544 void getGibbsCounts(GroundRRF* grrf, const Array<int>& queryPreds,
00545         Array<double>& counts, Database* db, RRF* rrf)
00546 {
00547     // Get initial likelihood and counts
00548 #if USE_LOGS
00549     double oldLikelihood = grrf->getLogValue();
00550 #else
00551     double oldLikelihood = grrf->getValue();
00552 #endif
00553     grrf->getCounts(counts);
00554 
00555 #if CMP_RRF
00556     Array<double> rrfCounts;
00557 
00558     for (int q = 0; q < queryPreds.size(); q++) {
00559         int i = queryPreds[q];
00560         Array<Predicate*> queryGroundings;
00561         Predicate::createAllGroundings(i, db->getDomain(), queryGroundings);
00562         for (int j = 0; j < queryGroundings.size(); j++) {
00563             bool value = grrf->getPredicateValue(i,j);
00564             db->setValue(queryGroundings[j], value ? TRUE : FALSE);
00565         }
00566     }
00567 #endif
00568 
00569     Array<double> newCounts;
00570     for (int iter = 0; iter < aGibbsIters; iter++) {
00571         for (int q = 0; q < queryPreds.size(); q++) {
00572             int i = queryPreds[q];
00573             // NOTE: assuming that predicate 0 is equality and therefore 
00574             // skipped.  This is the source of the (i-1).
00575             for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00576 
00577                 // Sample predicate, conditioned on the rest
00578                 bool oldValue = grrf->getPredicateValue(i,j);
00579                 grrf->setPredicateAndUpdate(i,j,!oldValue);
00580                 //grrf->setPredicateValue(i,j,!oldValue);
00581 #if CMP_RRF
00582                 db->setValue(queryGroundings[j], oldValue ? FALSE : TRUE);
00583 #endif
00584 
00585 #if USE_LOGS
00586                 double newLikelihood = grrf->getLogValue();
00587 #else
00588                 double newLikelihood = grrf->getValue();
00589 #endif
00590 
00591 #if CMP_RRF  // TODO -- fix this...
00592                 double rrfLikelihood = rrf->getValue(db);
00593                 if (fabs(rrfLikelihood - newLikelihood) > 0.00001) {
00594                     cout << "Different likelihoods: \n";
00595                     cout << rrfLikelihood << endl;
00596                     cout << newLikelihood << endl;
00597                 }
00598 #endif
00599 
00600 #if USE_LOGS
00601                 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00602                 if (oldLikelihood - newLikelihood > 100) {
00603                     prob = 0.0;
00604                 } else if (oldLikelihood - newLikelihood < -100) {
00605                     prob = 1.0;
00606                 }
00607 #else
00608                 double prob = newLikelihood/(oldLikelihood + newLikelihood);
00609 #endif
00610 
00611                 if (frand() < prob) {
00612                     oldLikelihood = newLikelihood;
00613                 } else {
00614                     grrf->setPredicateAndUpdate(i,j,oldValue);
00615                     //grrf->setPredicateValue(i,j,oldValue);
00616 #if CMP_RRF
00617                     db->setValue(queryGroundings[j], oldValue ? TRUE : FALSE);
00618 #endif
00619                 }
00620 
00621 #if CMP_RRF
00622                 db->setValue(queryGroundings[j], 
00623                         grrf->getPredicateValue(i,j) ? TRUE : FALSE);
00624 #endif
00625             }
00626         }
00627 
00628         if (grrf == NULL) {
00629             cout << "grrf is somehow NULL!\n";
00630         }
00631 
00632         // Collect counts
00633         grrf->dirtyAll();
00634         grrf->getCounts(newCounts);
00635 #if CMP_RRF
00636         rrf->getCounts(rrfCounts, db);
00637 #endif
00638         for (int c = 0; c < newCounts.size(); c++) {
00639 #if CMP_RRF
00640             if (fabs(newCounts[c] - rrfCounts[c]) > 0.000001) {
00641                 cout << "Different counts for " << c << ":\n";
00642                 cout << "Ground: " << newCounts[c] << endl;
00643                 cout << "True:   " << rrfCounts[c] << endl;
00644             } 
00645 #endif
00646             counts[c] += newCounts[c];
00647         }
00648     }
00649 
00650     // Normalize
00651     for (int c = 0; c < counts.size(); c++) {
00652         counts[c] /= ((double)aGibbsIters+1.0);
00653     }
00654 }
00655 
00656 
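// Gibbs sampling on the ground RRF that records per-grounding marginals
// (one entry per ground query atom, smoothed with a 0.5 pseudocount and
// normalized by aGibbsIters+1) rather than feature counts.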
00657 void runGibbs(GroundRRF* grrf, const Array<int>& queryPreds,
00658         Array<double>& counts, Database* db)
00659 {
00660     // Get initial likelihood and counts
00661     double oldLikelihood = grrf->getLogValue();
00662 
00663     counts.clear();
00664     for (int iter = 0; iter < aGibbsIters; iter++) {
00665         int countIndex = 0;
00666         for (int q = 0; q < queryPreds.size(); q++) {
00667             int i = queryPreds[q];
00668             // NOTE: assuming that predicate 0 is equality and therefore 
00669             // skipped.  This is the source of the (i-1).
00670             for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00671 
00672                 // Sample predicate, conditioned on the rest
00673                 bool oldValue = grrf->getPredicateValue(i,j);
00674                 grrf->setPredicateAndUpdate(i,j,!oldValue);
00675                 //grrf->setPredicateValue(i,j,!oldValue);
00676                 bool finalValue;
00677 
00678                 double newLikelihood = grrf->getLogValue();
00679                 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00680                 if (oldLikelihood - newLikelihood > 100) {
00681                     prob = 0.0;
00682                 } else if (oldLikelihood - newLikelihood < -100) {
00683                     prob = 1.0;
00684                 }
00685 
00686                 if (frand() < prob) {
00687                     oldLikelihood = newLikelihood;
00688                     finalValue = !oldValue;
00689                 } else {
00690                     grrf->setPredicateAndUpdate(i,j,oldValue);
00691                     finalValue = oldValue;
00692                 }
00693 
00694                 if (counts.size() <= countIndex) {
00695                     counts.append(0.5);
00696                 }
00697                 if (finalValue) {
00698                     counts[countIndex]++;
00699                 }
00700                 countIndex++;
00701             }
00702         }
00703 
00704         if (grrf == NULL) {
00705             cout << "grrf is somehow NULL!\n";
00706         }
00707     }
00708 
00709     // Normalize
00710     for (int c = 0; c < counts.size(); c++) {
00711         counts[c] /= ((double)aGibbsIters+1.0);
00712     }
00713 }
00714 
00715 void scrambleWeights(RRF* rrf) 
00716 {
00717     // Set initial (random) weights
00718     Array<Feature*> features = rrf->getFeatureArray();
00719     for (int i = 0; i < features.size(); i++) {
00720         for (int j = 0; j < features[i]->getNumWeights(); j++) {
00721             // Initialize to a random weight in (-0.5, +0.5)
00722             double randWt = frand() - 0.5;
00723             features[i]->setWeight(j, randWt);
00724         }
00725     }
00726 }
00727 
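// Weight learning.  For pseudo-likelihood (the default), the weights are
// optimized with the LBFGSR wrapper using the ground RRF's pseudo-counts.
// For the other inference methods, it runs plain gradient ascent: each
// iteration estimates expected counts with the chosen inference method,
// takes a step of size aAlpha along (true counts - expected counts -
// w/sigma^2), and periodically writes the model to the output file.  The
// -top and -bottom flags restrict which weights are updated.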
00728 void trainRRF(RRF* rrf, Database* db, const Array<int>& queryPreds)
00729 {
00730     // Set initial (random) weights
00731     // scrambleWeights(rrf);
00732 
00733     // Initial preparations for recursive counts computation
00734     Array<Predicate*> allGroundings;
00735     Array<TruthValue> truePredValues;
00736 
00737     if (aRec || CMP_RRF) {
00738         const Domain* domain = db->getDomain();
00739         for (int i = 0; i < queryPreds.size(); i++) {
00740 
00741             // Create all groundings for this predicate
00742             Array<Predicate*> predGroundings;
00743             Predicate::createAllGroundings(queryPreds[i], domain, predGroundings);
00744             allGroundings.append(predGroundings);
00745 
00746             // Save all original truth values (assuming complete data for now)
00747             for (int j = 0; j < predGroundings.size(); j++)  {
00748                 truePredValues.append(db->getValue(predGroundings[j]));
00749             }
00750         }
00751     }
00752 
00753     // Initial preparations for full ground tree counts computation
00754     GroundRRF* grrf = NULL;
00755 #if 1
00756     GroundRRF* trueGrrf = NULL;
00757 #else
00758     Array<Array<bool> > trueValues;
00759 #endif
00760 
00761     if (aGround) {
00762         grrf = new GroundRRF(rrf, db);
00763 #if 1
00764         trueGrrf = new GroundRRF(rrf, db);
00765 #else
00766         // Save true values of all query predicates
00767         for (int q = 0; q < queryPreds.size(); q++) {
00768             int i = queryPreds[q];
00769             trueValues.append(Array<bool>());
00770             for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00771                 trueValues[q].append(grrf->getPredicateValue(i,j));
00772             }
00773         }
00774 #endif
00775     }
00776 
00777 
00778     if (infMethod == INF_PSEUDO) {
00779 
00780         // Print out initial lpl
00781         cout << "ground lpl = " 
00782             << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00783 
00784         // Get weights
00785         int numWeights = rrf->getNumWeights();
00786         double wts[numWeights+1];
00787         for (int i = 0; i < numWeights; i++) {
00788             wts[i+1] = rrf->getWeight(i);
00789         }
00790 
00791 
00792         int minWt = 0;
00793         int maxWt = numWeights;
00794 
00795         if (aTrainBottomLevel) {
00796             minWt = rrf->getRoot()->getNumWeights();
00797         }
00798         if (aTrainTopLevel) {
00799             maxWt = rrf->getRoot()->getNumWeights();
00800         }
00801 
00802         // Solve
00803         int iter;
00804         bool error;
00805         LBFGSR solver(aMaxTrainIters, 1e-5, aSamplingFrac, rrf, trueGrrf, 
00806                 queryPreds, maxWt - minWt, minWt, aSigmaSq, aPseudoFast);
00807         solver.minimize(wts+minWt, iter, error);
00808         cout << "Iters required: " << iter << endl;
00809 
00810         // Save weights
00811         for (int i = 0; i < numWeights; i++) {
00812             rrf->setWeight(i, wts[i+1]);
00813         }
00814 
00815         trueGrrf->dirtyAll();
00816 
00817         // Print out final lpl
00818         cout << "ground lpl = " 
00819             << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00820     } else {
00821 
00822 
00823     // Loop until convergence
00824     Timer timer;
00825     double lastWriteSec = timer.time();
00826     for (int iter = 0; iter < aMaxTrainIters; iter++) {
00827 
00828         double begSec = timer.time();
00829 
00830 #if 0
00831         for (int i = 0; i < rrf->getNumFeatures(); i++) {
00832             cout << "F" << i << ":";
00833             Feature* feature = rrf->getFeature(i);
00834             for (int j = 0; j < feature->getNumWeights(); j++) {
00835                 cout << " " << feature->getWeight(j);
00836             }
00837             cout << "\n";
00838         }
00839 #endif
00840 
00841         // Print out log likelihood
00842         if (aRec) {
00843             cout << "rec lpl = " << rrf->getLogPseudoLikelihood(db, queryPreds) 
00844                 << endl;
00845         }
00846         if (aGround) {
00847 #if 0
00848             cout << "ground lpl = " << grrf->getLogPseudoLikelihood(queryPreds) 
00849                 << endl;
00850 #else
00851             cout << "ground lpl = " 
00852                 << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00853             cout << "ground ll = "
00854                 << trueGrrf->getLogValue() << endl;
00855 #if 0
00856             // DEBUG
00857             if (llcount++ % 100 == 0) {
00858                 cout << "ll = " << log(rrf->getExactConditionalLikelihood(db, queryPreds)) << endl;
00859             }
00860 #endif
00861 #endif
00862 #if CMP_RRF
00863             cout << "rec lpl = " << rrf->getLogPseudoLikelihood(db, queryPreds) 
00864                 << endl;
00865 #endif
00866         }
00867         cout << "wll = " << rrf->getWeightLogLikelihood(aSigmaSq) << endl;
00868 
00869         // Run inference
00870         if (aRec) {
00871             if (infMethod != INF_EXACT) {
00872                 runICMInference(rrf, db, allGroundings);
00873             }
00874         }
00875         if (aGround) {
00876             if (infMethod == INF_ICM) {
00877                 runICMInference(grrf, queryPreds, db, rrf);
00878             }
00879         }
00880         
00881         // Get inferred counts
00882         Array<double> groundInfCounts;
00883         Array<double> recInfCounts;
00884 
00885         if (infMethod == INF_GIBBS) {
00886             if (aRec)    { getGibbsCounts(rrf, db, allGroundings, recInfCounts); }
00887             if (aGround) { getGibbsCounts(grrf, queryPreds, groundInfCounts, 
00888                     db, rrf); }
00889         } else if (infMethod == INF_ICM) {
00890             if (aRec)    { rrf->getCounts(recInfCounts, db); }
00891             if (aGround) { grrf->getCounts(groundInfCounts); }
00892         } else if (infMethod == INF_EXACT) {
00893             if (aRec)    { getExactCounts(rrf, db, allGroundings, recInfCounts); }
00894             // TODO -- implement this
00895             if (aGround) { cout << "ERROR: exact inference not supported for ground network.\n";  }
00896         } else if (infMethod == INF_PSEUDO) {
00897             // No need to do any expectation!
00898         } else {
00899             cout << "ERROR: unknown inference method " << infMethod << endl;
00900         }
00901 
00902         // Reset database
00903 #if 0
00904         if (aGround) {
00905             for (int q = 0; q < queryPreds.size(); q++) {
00906                 for (int j = 0; j < trueValues[q].size(); j++) {
00907                     grrf->setPredicateValue(queryPreds[q], j, trueValues[q][j]);
00908                 }
00909             }
00910         }
00911 #endif
00912 
00913         if (aRec || CMP_RRF) 
00914         {
00915             // Reset database
00916             for (int i = 0; i < allGroundings.size(); i++) {
00917                 db->setValue(allGroundings[i], truePredValues[i]);
00918             }
00919             rrf->invalidateAll();
00920         }
00921 
00922         // Get true counts
00923         Array<double> groundTrueCounts;
00924         Array<double> recTrueCounts;
00925 
00926         if (aGround) {
00927 #if 0
00928             grrf->getCounts(groundTrueCounts);
00929 #else
00930             if (infMethod == INF_PSEUDO) {
00931                 if (aPseudoFast) {
00932                     trueGrrf->getPseudoCountsFast(groundTrueCounts, 
00933                             queryPreds, aSamplingFrac);
00934                 } else {
00935                     trueGrrf->getPseudoCounts(groundTrueCounts, queryPreds,
00936                             aSamplingFrac);
00937                 }
00938             } else {
00939                 trueGrrf->getCounts(groundTrueCounts);
00940             }
00941 #endif
00942         }
00943         if (aRec) {
00944             if (infMethod == INF_PSEUDO) {
00945                 // TODO... implement this properly
00946                 cout << "ERROR: pseudo-likelihood requires ground RRF.\n";
00947             } else {
00948                 rrf->getCounts(recTrueCounts, db);
00949             }
00950         }
00951 
00952         // Compute and follow gradient for one step
00953         int iMin = 0;
00954         int iMax = aRec ? recTrueCounts.size() : groundTrueCounts.size();
00955 
00956         if (aTrainBottomLevel) {
00957             iMin = rrf->getRoot()->getNumWeights();
00958         }
00959         if (aTrainTopLevel) {
00960             iMax = rrf->getRoot()->getNumWeights();
00961         }
00962 
00963 
00964         Array<double> gradient;
00965         for (int i = 0; i < iMin; i++) {
00966             gradient.append(0.0);
00967         }
00968 
00969         // DEBUG
00970         double sqlength = 0.0;
00971         double weightsum = 0.0;
00972         int maxIndex = 0;
00973         double maxTotal = 0.0;
00974         for (int i = iMin; i < iMax; i++) 
00975         {
00976 #if CMP_RRF
00977             // NOTE: no support for pseudo-log-likelihood
00978             if (aRec && aGround) {
00979                 if (2.0*fabs((recTrueCounts[i] - groundTrueCounts[i])/
00980                         (recTrueCounts[i] + groundTrueCounts[i])) > 0.01) {
00981                     cout << "True counts of " << i << " differ.\n";
00982                 }
00983                 if (2.0*fabs(recInfCounts[i] - groundInfCounts[i]) > 0.0001) {
00984                     cout << "Inferred counts of " << i << " differ:\n";
00985                     cout << "Rec:    " << recInfCounts[i] << endl;
00986                     cout << "Ground: " << groundInfCounts[i] << endl;
00987                 }
00988 
00989                 double recDiff = recInfCounts[i] - recTrueCounts[i];
00990                 double groundDiff = groundInfCounts[i] - groundTrueCounts[i];
00991                 if (2.0*fabs(recDiff - groundDiff) > 0.0001) {
00992                     cout << "Gradient differs on " << i << endl;
00993                     cout << "Rec:    " << recDiff << endl;
00994                     cout << "Ground: " << groundDiff << endl;
00995                 }
00996             }
00997 #endif
00998             double diff;
00999             if (infMethod == INF_PSEUDO) {
01000                 diff = groundTrueCounts[i];
01001             } else if (aRec) {
01002                 diff = recTrueCounts[i] - recInfCounts[i];
01003             } else {
01004                 diff = groundTrueCounts[i] - groundInfCounts[i];
01005             }
01006             double weight = rrf->getWeight(i);
01007             double total = diff - weight/aSigmaSq;
01008             gradient.append(total);
01009             sqlength += total * total;
01010 
01011             // DEBUG
01012             if (fabs(total) > fabs(maxTotal)) {
01013                 maxTotal = total;
01014                 maxIndex = i;
01015             }
01016             weightsum += fabs(weight);
01017         }
01018         //double norm = sqrt(sqlength);
01019 
01020 #if 0
01021         // DEBUG
01022         cout << "largest change = " << maxTotal*aAlpha << " (" << maxIndex << ")\n";
01023         cout << "sqlength = " << sqlength*aAlpha*aAlpha << endl;
01024         cout << "length = " << sqrt(sqlength*aAlpha*aAlpha) << endl;
01025         cout << "sum of weights = " << weightsum << endl;
01026 #endif
01027 
01028         for (int i = iMin; i < iMax; i++) 
01029         {
01030             // Take current weight, add normalized(?) gradient
01031             double weight = rrf->getWeight(i);
01032             rrf->setWeight(i, weight + gradient[i]*aAlpha
01033                    // *weightsum
01034                    // /norm
01035                     );
01036         }
01037 
01038         // Weights changed; cache invalid!
01039         if (aGround) {
01040             grrf->dirtyAll();
01041 #if 1
01042             trueGrrf->dirtyAll();
01043 #endif
01044         }
01045 
01046         rrf->invalidateAll();
01047 
01048         // Print out time for this iteration
01049         cout << "Time: ";
01050         Timer::printTime(cout, timer.time() - begSec); cout << endl;
01051 
01052         if (aVerbose) {
01053             cout << rrf;
01054         }
01055 
01056         if (timer.time() - lastWriteSec > 5) {
01057             ofstream fout(aoutputFilename);
01058             fout << rrf;
01059             fout.close();
01060             lastWriteSec = timer.time();
01061         }
01062     }
01063     }
01064 
01065     // One final save.
01066     ofstream fout(aoutputFilename);
01067     fout << rrf;
01068 
01069     if (aVerbose) {
01070         cout << "Final model:\n";
01071         cout << rrf;
01072     }
01073     fout.close();
01074 }
01075 
01076 
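// Inference entry point: grounds the query predicates, optionally runs ICM
// on the recursive and/or ground RRF, and then estimates and prints query
// marginals via getGibbsLogLikelihood2().  (The older per-atom probability
// printout is disabled with #if 0.)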
01077 void inferRRF(RRF* rrf, Database* db, const Array<int>& queryPreds)
01078 {
01079     // Initial preparations for recursive counts computation
01080     Array<Predicate*> allGroundings;
01081     Array<TruthValue> truePredValues;
01082     const Domain* domain = db->getDomain();
01083 
01084     //if (aRec || CMP_RRF) 
01085     if (1) {
01086         for (int i = 0; i < queryPreds.size(); i++) {
01087 
01088             // Create all groundings for this predicate
01089             Array<Predicate*> predGroundings;
01090             Predicate::createAllGroundings(queryPreds[i], domain, predGroundings);
01091             allGroundings.append(predGroundings);
01092 
01093             // Save all original truth values (assuming complete data for now)
01094             for (int j = 0; j < predGroundings.size(); j++)  {
01095                 truePredValues.append(db->getValue(predGroundings[j]));
01096             }
01097         }
01098     }
01099 
01100     // Initial preparations for full ground tree counts computation
01101     GroundRRF* grrf = NULL;
01102 #if 1
01103     GroundRRF* trueGrrf = NULL;
01104 #else
01105     Array<Array<bool> > trueValues;
01106 #endif
01107 
01108     if (aGround) {
01109         grrf = new GroundRRF(rrf, db);
01110 #if 1
01111         trueGrrf = new GroundRRF(rrf, db);
01112 #else
01113         // Save true values of all query predicates
01114         for (int q = 0; q < queryPreds.size(); q++) {
01115             int i = queryPreds[q];
01116             trueValues.append(Array<bool>());
01117             for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
01118                 trueValues[q].append(grrf->getPredicateValue(i,j));
01119             }
01120         }
01121 #endif
01122     }
01123 
01124     // Run inference
01125     if (aRec) {
01126         runICMInference(rrf, db, allGroundings);
01127     }
01128 
01129     if (aGround) {
01130         runICMInference(grrf, queryPreds, db, rrf);
01131     }
01132 
01133     Array<double> recInfCounts;
01134     Array<double> groundInfCounts;
01135     
01136     getGibbsLogLikelihood2(rrf, db, allGroundings);
01137 
01138 #if 0
01139     if (infMethod == INF_GIBBS) {
01140         if (aRec)    { runGibbs(rrf, db, allGroundings, recInfCounts); }
01141         if (aGround) { runGibbs(grrf, queryPreds, groundInfCounts, db); }
01142     } else if (infMethod == INF_ICM) {
01143         //if (aRec)    { rrf->getCounts(recInfCounts, db); }
01144         //if (aGround) { grrf->getCounts(groundInfCounts); }
01145         // TODO!
01146     } 
01147 
01148     // Print out all probabilities
01149     for (int i = 0; i < allGroundings.size(); i++) {
01150         allGroundings[i]->printWithStrVar(cout, domain);
01151         if (aGround) {
01152             cout << " " << groundInfCounts[i];
01153         }
01154         if (aRec) {
01155             cout << " " << recInfCounts[i];
01156         }
01157         cout << endl;
01158     }
01159 #endif
01160 }
01161 
01162 
01163 
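// Driver: parse the command line, run the MLN parser to build the domain and
// database, collect the query predicate ids (skipping the equality
// predicate), load an RRF model from -weights (building one from scratch is
// no longer supported), and dispatch to inference (-infer),
// pseudo-log-likelihood evaluation (-lpl), or weight learning (the default).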
01164 int main(int argc, char* argv[]) 
01165 {
01166     ARGS::parse(argc, argv, &cout);
01167 
01168     switch(aInferenceMethodStr[0]) {
01169         case 'i': infMethod = INF_ICM;
01170                   break;
01171         case 'g': infMethod = INF_GIBBS;
01172                   break;
01173         case 'e': infMethod = INF_EXACT;
01174                   break;
01175         case 'p': infMethod = INF_PSEUDO;
01176                   break;
01177         default:  cout << "ERROR: Unknown inference method "
01178                   << aInferenceMethodStr << endl;
01179                   exit(-1);
01180     }
01181 
01182     if (!aRec && !aGround) {
01183         aGround = true;
01184     }
01185 
01186     srand(aSeed);
01187 
01189     MLN* mln = new MLN; // required, but unused...
01190     Domain* domain = new Domain;
01191     StringHashArray* openWorldPredNames = new StringHashArray();
01192     //StringHashArray* closedWorldPredNames = new StringHashArray();
01193     StringHashArray* queryPredNames = new StringHashArray();
01194     runYYParser(mln, domain, afilename, true, 
01195             openWorldPredNames, queryPredNames, false, 
01196             false, 0, false, 
01197             NULL, true, false);
01198     //delete mln; mln = NULL;
01199 
01200     // Get a list of query predicates
01201     Array<int> queryPreds;
01202     if (!strcmp(nonEvidPredsStr, "all")) {
01203 
01204         // Skip the equality predicate
01205         for (int i = 1; i < domain->getNumPredicates(); i++) {
01206             queryPreds.append(i);
01207         }
01208     } else {
01209         StringHashArray nonEvidPredNames;
01210         if(!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames)) {
01211             cout << "ERROR: failed to extract non-evidence predicate names.\n";
01212             return -1;
01213         }
01214 
01215         for (int i = 0; i < nonEvidPredNames.size(); i++) {
01216             int predId = domain->getPredicateId( nonEvidPredNames[i].c_str() );
01217             if (predId < 0) {
01218                 cout << "Ignoring unrecognized query predicate \""
01219                     << nonEvidPredNames[i] << "\".\n";
01220             } else {
01221                 queryPreds.append(predId);
01222             }
01223         }
01224     }
01225 
01226 
01227     RRF* rrf;
01228 
01229     if (amodelFilename) {
01230         // DEBUG
01231         cout << "Loading weights from file.\n";
01232         rrf = new RRF();
01233         ifstream modelIn(amodelFilename);
01234         rrf->load(modelIn, domain);
01235         modelIn.close();
01236 
01237     } else {
01238 #if 0
01240         rrf = RRF::generateTwoLevel(domain);
01241         Array<int> typeArity;
01242         typeArity.append(0);
01243         typeArity.append(1);
01244         for (int i = 0; i < aNumFeatures; i++) {
01245             rrf->addCompleteFeature(typeArity, domain, queryPreds, 
01246                     aPredsPerFeature);
01247         }
01248 #else
01249         cout << "ERROR: no longer supported!\n";
01250         exit(-1);
01251 #endif
01252     }
01253 
01254 #if 0
01255     if (atestConstant) {
01256         // HACK -- do leave-one-out experiments
01257         int testId = domain->getConstantId(atestConstant);
01258         if (testId < 0) {
01259             cout << "Error: unknown constant \"" << atestConstant << "\"\n";
01260         }
01261 
01262         // Make list of all query preds involving the held-out constant
01263         Array<Predicate*> groundQueryPreds;
01264         Array<int> req;
01265         req.append(testId);
01266 
01267         // For each query predicate...
01268         for (int i = 0; i < queryPreds.size(); i++) {
01269 
01270             // Iterate through all groundings that include the held-out
01271             // constant.
01272 
01273             cout << "req = ";
01274             for (int j = 0; j < req.size(); j++) {
01275                 cout << req[j] << " ";
01276             }
01277             cout << endl;
01278             // Broken...
01279            // ParentIter2 iter(req, domain->getConstantsByType(), 
01280             //    domain->getPredicateTemplate(queryPreds[i])->getNumTerms());
01281             Array<int> grounding;
01282             while (iter.hasNextGrounding()) {
01283 
01284                 // Copy the groundings to an instance of a predicate
01285                 iter.getNextGrounding(grounding);
01286                 Predicate* pred = new Predicate(
01287                         domain->getPredicateTemplate(queryPreds[i]));
01288                 for (int j = 0; j < grounding.size(); j++) {
01289                     pred->setTermToConstant(j, grounding[j]);
01290                 }
01291 
01292                 // Add the predicate to the list of all test preds
01293                 groundQueryPreds.append(pred);
01294             }
01295         }
01296 
01297         double lpl = rrf->getLogPseudoLikelihood(domain->getDB(),
01298                 groundQueryPreds);
01299 
01300         // Compute log likelihood of test preds
01301         double ll;
01302         if (aGround) {
01303             ll = getGibbsLogLikelihood2(rrf, domain->getDB(), 
01304                 groundQueryPreds);
01305         } else {
01306             ll = getGibbsLogLikelihood(rrf, domain->getDB(), 
01307                 groundQueryPreds);
01308         }
01309         cout << "lpl = " << lpl << endl;
01310         cout << "ll = " << ll << endl;
01311 
01312     } else 
01313 #endif
01314     if (aInfer) {
01315         inferRRF(rrf, domain->getDB(), queryPreds);
01316     } else if (aLPL) {
01317         if (aGround) {
01318             GroundRRF* grrf = new GroundRRF(rrf, domain->getDB());
01319             cout << "lpl = " << grrf->getLogPseudoLikelihood(queryPreds) 
01320                 << endl;
01321         } else {
01322             cout << "ERROR: pll for non-ground RRF is not yet supported!\n";
01323         }
01324     } else {
01325         if (aoutputFilename == NULL) {
01326             cout << "Warning: you should specify an output model filename "
01327                 << "using the -r option.\n";
01328         }
01329         trainRRF(rrf, domain->getDB(), queryPreds);
01330     }
01331 
01332     return 0;
01333 }
