#include <math.h>
#include <vector>

#include "mln.h"
#include "fol.h"
#include "rrf.h"
#include "random.h"
#include "arguments.h"
#include "infer.h"
#include "timer.h"
#include "gfeature.h"
#include "lbfgsr.h"
00011
// Number of initial Gibbs sweeps discarded before samples are tallied.
#define BURN_IN 100

// Counter referenced only from disabled (#if 0) debugging code, if at all.
int llcount = 0;

// Set to 1 to cross-check the ground RRF against the recursive RRF (debug).
#define CMP_RRF 0

// Inference methods selectable with -inf; INF_MAX is a sentinel.
enum inferenceMethod { INF_ICM, INF_GIBBS, INF_EXACT, INF_PSEUDO, INF_MAX };

// Writable storage for the default values of the string options below.
// Binding a string literal to a non-const char* is ill-formed since C++11,
// and these globals must stay char* because the ARGS table binds them.
static char defaultNonEvidPredsStr[] = "all";
static char defaultInferenceMethodStr[] = "p";

// Option values, filled in by ARGS::parse() (see the ARGS table below).
char* afilename = NULL;                               // -i
char* nonEvidPredsStr = defaultNonEvidPredsStr;       // -ne
char* aInferenceMethodStr = defaultInferenceMethodStr; // -inf
int aGibbsIters = 100;                                // -iters
int nf = 20;
int infMethod = INF_PSEUDO;                           // derived from -inf in main()
double aAlpha = 0.1;                                  // -alpha
bool aRec = false;                                    // -rec
bool aGround = false;                                 // -ground
int aNumFeatures = 10;                                // -nf
int aPredsPerFeature = 5;                             // -ppf
double aSigmaSq = 100;                                // -sigmasq
char* atestConstant = NULL;                           // -test
char* amodelFilename = NULL;                          // -weights
char* aoutputFilename = NULL;                         // -r
int aMaxTrainIters = 10000;                           // -trainiters
bool aVerbose = false;                                // -v
int aSeed = 1;                                        // -seed
bool aTrainTopLevel = false;                          // -top
bool aTrainBottomLevel = false;                       // -bottom
bool aInfer = false;                                  // -infer
bool aLPL = false;                                    // -lpl
double aSamplingFrac = 1.0;                           // -frac
bool aPseudoFast = false;                             // -f
00047
00048
00049 ARGS ARGS::Args[] =
00050 {
00051 ARGS("i", ARGS::Req, afilename, "input .mln file"),
00052 ARGS("weights", ARGS::Opt, amodelFilename,
00053 "Filename containing initial weights\n"),
00054 ARGS("r", ARGS::Opt, aoutputFilename,
00055 "File for saving learned model\n"),
00056 ARGS("ne", ARGS::Opt, nonEvidPredsStr,
00057
00058 "first-order non-evidence predicates (comma-separated with no space)"),
00059 ARGS("inf", ARGS::Opt, aInferenceMethodStr, "Inference method to use\n"
00060 "i = ICM; g = Gibbs sampling; p = pseudo-likelihood; e = Exact\n"),
00061 ARGS("frac", ARGS::Opt, aSamplingFrac,
00062 "Fraction of ground predicates to consider in psuedo-likelihood\n"),
00063 ARGS("iters", ARGS::Opt, aGibbsIters,
00064 "Number of iterations to use for Gibbs sampling\n"),
00065 ARGS("nf", ARGS::Opt, aNumFeatures, "Number of features to use\n"),
00066 ARGS("alpha", ARGS::Opt, aAlpha, "Learning rate\n"),
00067 ARGS("sigmasq", ARGS::Opt, aSigmaSq, "Large-weight penalty\n"),
00068 ARGS("rec", ARGS::Opt, aRec, "Compute counts recursively\n"),
00069 ARGS("ppf", ARGS::Opt, aPredsPerFeature,
00070 "Number of ground predicates per feature\n"),
00071 ARGS("ground", ARGS::Opt, aGround,
00072 "Compute counts using fully ground tree\n"),
00073 ARGS("test", ARGS::Opt, atestConstant,
00074 "Test constant, whose query preds are unknown\n"),
00075 ARGS("trainiters", ARGS::Opt, aMaxTrainIters,
00076 "Maximum number of training iterations\n"),
00077 ARGS("v", ARGS::Tog, aVerbose,
00078 "Verbose output\n"),
00079 ARGS("seed", ARGS::Opt, aSeed, "Random seed\n"),
00080 ARGS("top", ARGS::Tog, aTrainTopLevel,
00081 "Only train top-level weights (feature coefficients)\n"),
00082 ARGS("bottom", ARGS::Tog, aTrainBottomLevel,
00083 "Only train lower-level (within-feature) weights\n"),
00084 ARGS("infer", ARGS::Tog, aInfer,
00085 "Run inference and produce per-constant probabilities\n"),
00086 ARGS("lpl", ARGS::Tog, aLPL,
00087 "Produces pseudo-log-likelihood for query predicates\n"),
00088 ARGS("f", ARGS::Tog, aPseudoFast,
00089 "Use optimized (and broken) pseudo-likelihood gradient computation.\n"),
00090 ARGS()
00091 };
00092
00093
00094 Array<TruthValue> truthValues;
00095 void saveState(Database* db, Array<Predicate*>& queryPreds)
00096 {
00097 truthValues.clear();
00098 for (int i = 0; i < queryPreds.size(); i++) {
00099 truthValues.append(db->getValue(queryPreds[i]));
00100 }
00101 }
00102
00103 void restoreState(Database* db, Array<Predicate*>& queryPreds)
00104 {
00105 for (int i = 0; i < queryPreds.size(); i++) {
00106 db->setValue(queryPreds[i], truthValues[i]);
00107 }
00108 }
00109
00110
00111 void runICMInference(RRF* rrf, Database* db, Array<Predicate*>& queryPreds)
00112 {
00113
00114 for (int i = 0; i < queryPreds.size(); i++) {
00115 TruthValue tv = (frand() < 0.5) ? TRUE : FALSE;
00116 db->setValue(queryPreds[i], tv);
00117 }
00118 rrf->invalidateAll();
00119
00120
00121 bool changed = true;
00122 while (changed) {
00123 changed = false;
00124
00125 double currLikelihood = rrf->getLogValue(db);
00126 for (int i = 0; i < queryPreds.size(); i++) {
00127
00128
00129
00130 TruthValue oldValue = db->getValue(queryPreds[i]);
00131 TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00132
00133 db->setValue(queryPreds[i], newValue);
00134 rrf->changedPredicate(queryPreds[i], db);
00135 double newLikelihood = rrf->getLogValue(db);
00136 if (newLikelihood > currLikelihood) {
00137
00138 currLikelihood = newLikelihood;
00139 changed = true;
00140 } else {
00141
00142 db->setValue(queryPreds[i], oldValue);
00143 rrf->changedPredicate(queryPreds[i], db);
00144 }
00145 }
00146 }
00147 }
00148
00149
00150 void runICMInference(GroundRRF* grrf, const Array<int>& queryPreds,
00151 Database* db, RRF* rrf)
00152 {
00153
00154 for (int q = 0; q < queryPreds.size(); q++) {
00155 #if CMP_RRF
00156 Array<Predicate*> queryGroundings;
00157 Predicate::createAllGroundings(queryPreds[q], db->getDomain(),
00158 queryGroundings);
00159 #endif
00160 for (int j = 0; j < grrf->getNumGroundings(queryPreds[q]-1); j++) {
00161 bool tv = (frand() < 0.5);
00162 grrf->setPredicateValue(queryPreds[q], j, tv);
00163 #if CMP_RRF
00164 db->setValue(queryGroundings[j], tv ? TRUE : FALSE);
00165 #endif
00166 }
00167 }
00168
00169
00170 bool changed = true;
00171 while (changed) {
00172 changed = false;
00173
00174 double currLikelihood = grrf->getLogValue();
00175 #if CMP_RRF
00176 double rrfLikelihood = rrf->getLogValue(db);
00177 if (fabs(rrfLikelihood - currLikelihood) > 0.00001) {
00178 cout << "Likelihoods differ!\n";
00179 }
00180 #endif
00181 for (int q = 0; q < queryPreds.size(); q++) {
00182
00183 int i = queryPreds[q];
00184 #if CMP_RRF
00185 Array<Predicate*> queryGroundings;
00186 Predicate::createAllGroundings(queryPreds[q], db->getDomain(),
00187 queryGroundings);
00188 #endif
00189 for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00190
00191
00192 bool oldValue = grrf->getPredicateValue(i, j);
00193 grrf->setPredicateAndUpdate(i, j, !oldValue);
00194
00195
00196 double newLikelihood = grrf->getLogValue();
00197 #if CMP_RRF
00198 db->setValue(queryGroundings[j], oldValue ? FALSE : TRUE);
00199 double rrfLikelihood = rrf->getLogValue(db);
00200 if (fabs(rrfLikelihood - newLikelihood) > 0.00001) {
00201 cout << "Toggled likelihoods differ!\n";
00202 }
00203 #endif
00204 if (newLikelihood > currLikelihood) {
00205
00206 currLikelihood = newLikelihood;
00207 changed = true;
00208 } else {
00209
00210 grrf->setPredicateAndUpdate(i, j, oldValue);
00211
00212 #if CMP_RRF
00213 db->setValue(queryGroundings[j], oldValue ? TRUE : FALSE);
00214 #endif
00215 }
00216 }
00217 }
00218 }
00219 }
00220
00221
00222 void getExactCounts(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00223 Array<double>& counts)
00224 {
00225
00226 rrf->getCounts(counts, db);
00227 for (int i = 0; i < counts.size(); i++) {
00228 counts[i] = 0.0;
00229 }
00230
00231
00232 ArraysAccessor<TruthValue> predValues;
00233 Array<TruthValue> truthValues;
00234 truthValues.append(TRUE);
00235 truthValues.append(FALSE);
00236
00237 for (int i = 0; i < queryPreds.size(); i++) {
00238 predValues.appendArray(&truthValues);
00239 }
00240
00241 double Z = 0.0;
00242
00243
00244 do {
00245 Array<TruthValue> truthValues;
00246 predValues.getNextCombination(truthValues);
00247 for (int i = 0; i < truthValues.size(); i++) {
00248 db->setValue(queryPreds[i], truthValues[i]);
00249 }
00250
00251 double likelihood = rrf->getValue(db);
00252 Z += likelihood;
00253
00254
00255 Array<double> newCounts;
00256 rrf->getCounts(newCounts, db);
00257 for (int j = 0; j < newCounts.size(); j++) {
00258 counts[j] += likelihood * newCounts[j];
00259 }
00260
00261
00262
00263 } while (predValues.hasNextCombination());
00264
00265
00266 for (int j = 0; j < counts.size(); j++) {
00267 counts[j] /= Z;
00268 }
00269 }
00270
00271 void getGibbsCounts(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00272 Array<double>& counts)
00273 {
00274
00275 double oldLikelihood = rrf->getValue(db);
00276 rrf->getCounts(counts, db);
00277
00278
00279 for (int iter = 0; iter < aGibbsIters; iter++) {
00280 for (int i = 0; i < queryPreds.size(); i++) {
00281
00282
00283 TruthValue oldValue = db->getValue(queryPreds[i]);
00284 TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00285
00286 db->setValue(queryPreds[i], newValue);
00287 rrf->changedPredicate(queryPreds[i], db);
00288 double newLikelihood = rrf->getValue(db);
00289 double prob = newLikelihood/(oldLikelihood + newLikelihood);
00290 if (frand() < prob) {
00291 oldLikelihood = newLikelihood;
00292 } else {
00293 db->setValue(queryPreds[i], oldValue);
00294 rrf->changedPredicate(queryPreds[i], db);
00295 }
00296 }
00297
00298
00299 Array<double> newCounts;
00300 rrf->getCounts(newCounts, db);
00301 for (int j = 0; j < newCounts.size(); j++) {
00302 counts[j] += newCounts[j];
00303 }
00304 }
00305
00306
00307 for (int j = 0; j < counts.size(); j++) {
00308 counts[j] /= ((double)aGibbsIters+1.0);
00309 }
00310 }
00311
00312 void runGibbs(RRF* rrf, Database* db, Array<Predicate*>& queryPreds,
00313 Array<double>& counts)
00314 {
00315
00316 double oldLikelihood = rrf->getValue(db);
00317 counts.clear();
00318
00319
00320 for (int iter = 0; iter < aGibbsIters; iter++) {
00321 for (int i = 0; i < queryPreds.size(); i++) {
00322
00323
00324 TruthValue oldValue = db->getValue(queryPreds[i]);
00325 TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00326 TruthValue finalValue;
00327
00328 db->setValue(queryPreds[i], newValue);
00329 rrf->changedPredicate(queryPreds[i], db);
00330 double newLikelihood = rrf->getValue(db);
00331 double prob = newLikelihood/(oldLikelihood + newLikelihood);
00332 if (frand() < prob) {
00333 oldLikelihood = newLikelihood;
00334 finalValue = newValue;
00335 } else {
00336 db->setValue(queryPreds[i], oldValue);
00337 rrf->changedPredicate(queryPreds[i], db);
00338 finalValue = oldValue;
00339 }
00340
00341
00342 if (counts.size() <= i) {
00343 counts.append(0.5);
00344 }
00345 if (finalValue == TRUE) {
00346 counts[i]++;
00347 }
00348 }
00349 }
00350
00351
00352 for (int j = 0; j < counts.size(); j++) {
00353 counts[j] /= ((double)aGibbsIters+1.0);
00354 }
00355 }
00356
00357 double getGibbsLogLikelihood(RRF* rrf, Database* db,
00358 Array<Predicate*>& queryPreds)
00359 {
00360 Array<double> counts;
00361
00362 Array<TruthValue> trueValues;
00363
00364 for (int i = 0; i < queryPreds.size(); i++) {
00365 trueValues.append(db->getValue(queryPreds[i]));
00366
00367 counts.append(0.1);
00368 }
00369 double total = 0.2;
00370
00371
00372 runICMInference(rrf, db, queryPreds);
00373
00374
00375 rrf->invalidateAll();
00376 double oldLikelihood = rrf->getLogValue(db);
00377
00378 for (int iter = 0; iter < aGibbsIters + BURN_IN; iter++) {
00379 for (int i = 0; i < queryPreds.size(); i++) {
00380
00381
00382 TruthValue oldValue = db->getValue(queryPreds[i]);
00383 TruthValue newValue = (oldValue == TRUE) ? FALSE : TRUE;
00384 TruthValue finalValue;
00385
00386 db->setValue(queryPreds[i], newValue);
00387 rrf->changedPredicate(queryPreds[i], db);
00388 double newLikelihood = rrf->getLogValue(db);
00389 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00390 if (oldLikelihood - newLikelihood > 100) {
00391 prob = 0.0;
00392 } else if (oldLikelihood - newLikelihood < -100) {
00393 prob = 1.0;
00394 }
00395 if (frand() < prob) {
00396 oldLikelihood = newLikelihood;
00397 finalValue = newValue;
00398 } else {
00399 db->setValue(queryPreds[i], oldValue);
00400 rrf->changedPredicate(queryPreds[i], db);
00401 finalValue = oldValue;
00402 }
00403
00404 if (iter >= BURN_IN && finalValue) {
00405 counts[i]++;
00406 }
00407 }
00408
00409 if (iter >= BURN_IN) {
00410 total++;
00411 }
00412 }
00413
00414 double ll = 0.0;
00415
00416
00417 for (int j = 0; j < counts.size(); j++) {
00418
00419 if (trueValues[j]) {
00420 ll += log(counts[j]/total);
00421 } else {
00422 ll += log(1.0 - counts[j]/total);
00423 }
00424
00425 cout << *(queryPreds[j]) << " = " << trueValues[j]
00426 << " (" << (counts[j]/total) << ")\n";
00427 }
00428
00429
00430
00431
00432 return ll;
00433 }
00434
00435
00436 double getGibbsLogLikelihood2(RRF* rrf, Database* db,
00437 Array<Predicate*>& queryPreds)
00438 {
00439 Array<int> queryPredId;
00440 Array<int> queryGroundId;
00441
00442 Array<double> counts;
00443 Array<TruthValue> trueValues;
00444
00445 double ll = 0.0;
00446
00447 for (int i = 0; i < queryPreds.size(); i++) {
00448 trueValues.append(db->getValue(queryPreds[i]));
00449
00450
00451 queryPredId.append(queryPreds[i]->getId());
00452 Array<int> grounding;
00453 for (int j = 0; j < queryPreds[i]->getNumTerms(); j++) {
00454 grounding.append(queryPreds[i]->getTerm(j)->getId());
00455 }
00456 queryGroundId.append(rrf->getFeature(queryPredId[i]-1)->
00457 getGroundingIndex(grounding, db));
00458
00459 counts.append(0.1);
00460 }
00461 double total = 0.2;
00462
00463
00464 for (int trial= 0; trial < 10; trial++) {
00465
00466 Array<int> qpreds;
00467 qpreds.append(queryPreds[0]->getId());
00468 GroundRRF* grrf = new GroundRRF(rrf, db);
00469 grrf->dirtyAll();
00470
00471 double best_ll = -1.0e100;
00472 for (int init = 0; init < 10; init++) {
00473 runICMInference(grrf, qpreds, db, rrf);
00474 if (grrf->getLogValue() > best_ll) {
00475 best_ll = grrf->getLogValue();
00476 saveState(db, queryPreds);
00477 }
00478 }
00479 restoreState(db, queryPreds);
00480
00481 double oldLikelihood = grrf->getLogValue();
00482 for (int iter = 0; iter < aGibbsIters + BURN_IN; iter++) {
00483 for (int i = 0; i < queryPreds.size(); i++) {
00484
00485
00486 bool oldValue = grrf->getPredicateValue(
00487 queryPredId[i], queryGroundId[i]);
00488 bool finalValue;
00489
00490 grrf->setPredicateAndUpdate(queryPredId[i], queryGroundId[i],
00491 !oldValue);
00492 double newLikelihood = grrf->getLogValue();
00493
00494 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00495 if (oldLikelihood - newLikelihood > 100) {
00496 prob = 0.0;
00497 } else if (oldLikelihood - newLikelihood < -100) {
00498 prob = 1.0;
00499 }
00500
00501 if (frand() < prob) {
00502 oldLikelihood = newLikelihood;
00503 finalValue = !oldValue;
00504 } else {
00505 grrf->setPredicateAndUpdate(queryPredId[i],
00506 queryGroundId[i], oldValue);
00507 finalValue = oldValue;
00508 }
00509
00510 if (iter >= BURN_IN && finalValue) {
00511 counts[i]++;
00512 }
00513 }
00514
00515 if (iter >= BURN_IN) {
00516 total++;
00517 }
00518 }
00519 }
00520
00521
00522 for (int j = 0; j < counts.size(); j++) {
00523
00524 double curr_ll;
00525 if (trueValues[j]) {
00526 curr_ll = log(counts[j]/total);
00527 } else {
00528 curr_ll = log(1.0 - counts[j]/total);
00529 }
00530 ll += curr_ll;
00531
00532 queryPreds[j]->printWithStrVar(cout, db->getDomain());
00533 cout << " = " << trueValues[j]
00534 << " (" << (counts[j]/total) << ") " << curr_ll << endl;
00535
00536 }
00537
00538 return ll;
00539 }
00540
00541
00542 #define USE_LOGS 1
00543
00544 void getGibbsCounts(GroundRRF* grrf, const Array<int>& queryPreds,
00545 Array<double>& counts, Database* db, RRF* rrf)
00546 {
00547
00548 #if USE_LOGS
00549 double oldLikelihood = grrf->getLogValue();
00550 #else
00551 double oldLikelihood = grrf->getValue();
00552 #endif
00553 grrf->getCounts(counts);
00554
00555 #if CMP_RRF
00556 Array<double> rrfCounts;
00557
00558 for (int q = 0; q < queryPreds.size(); q++) {
00559 int i = queryPreds[q];
00560 Array<Predicate*> queryGroundings;
00561 Predicate::createAllGroundings(i, db->getDomain(), queryGroundings);
00562 for (int j = 0; j < queryGroundings.size(); j++) {
00563 bool value = grrf->getPredicateValue(i,j);
00564 db->setValue(queryGroundings[j], value ? TRUE : FALSE);
00565 }
00566 }
00567 #endif
00568
00569 Array<double> newCounts;
00570 for (int iter = 0; iter < aGibbsIters; iter++) {
00571 for (int q = 0; q < queryPreds.size(); q++) {
00572 int i = queryPreds[q];
00573
00574
00575 for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00576
00577
00578 bool oldValue = grrf->getPredicateValue(i,j);
00579 grrf->setPredicateAndUpdate(i,j,!oldValue);
00580
00581 #if CMP_RRF
00582 db->setValue(queryGroundings[j], oldValue ? FALSE : TRUE);
00583 #endif
00584
00585 #if USE_LOGS
00586 double newLikelihood = grrf->getLogValue();
00587 #else
00588 double newLikelihood = grrf->getValue();
00589 #endif
00590
00591 #if CMP_RRF // TODO -- fix this...
00592 double rrfLikelihood = rrf->getValue(db);
00593 if (fabs(rrfLikelihood - newLikelihood) > 0.00001) {
00594 cout << "Different likelihoods: \n";
00595 cout << rrfLikelihood << endl;
00596 cout << newLikelihood << endl;
00597 }
00598 #endif
00599
00600 #if USE_LOGS
00601 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00602 if (oldLikelihood - newLikelihood > 100) {
00603 prob = 0.0;
00604 } else if (oldLikelihood - newLikelihood < -100) {
00605 prob = 1.0;
00606 }
00607 #else
00608 double prob = newLikelihood/(oldLikelihood + newLikelihood);
00609 #endif
00610
00611 if (frand() < prob) {
00612 oldLikelihood = newLikelihood;
00613 } else {
00614 grrf->setPredicateAndUpdate(i,j,oldValue);
00615
00616 #if CMP_RRF
00617 db->setValue(queryGroundings[j], oldValue ? TRUE : FALSE);
00618 #endif
00619 }
00620
00621 #if CMP_RRF
00622 db->setValue(queryGroundings[j],
00623 grrf->getPredicateValue(i,j) ? TRUE : FALSE);
00624 #endif
00625 }
00626 }
00627
00628 if (grrf == NULL) {
00629 cout << "grrf is somehow NULL!\n";
00630 }
00631
00632
00633 grrf->dirtyAll();
00634 grrf->getCounts(newCounts);
00635 #if CMP_RRF
00636 rrf->getCounts(rrfCounts, db);
00637 #endif
00638 for (int c = 0; c < newCounts.size(); c++) {
00639 #if CMP_RRF
00640 if (fabs(newCounts[c] - rrfCounts[c]) > 0.000001) {
00641 cout << "Different counts for " << c << ":\n";
00642 cout << "Ground: " << newCounts[c] << endl;
00643 cout << "True: " << rrfCounts[c] << endl;
00644 }
00645 #endif
00646 counts[c] += newCounts[c];
00647 }
00648 }
00649
00650
00651 for (int c = 0; c < counts.size(); c++) {
00652 counts[c] /= ((double)aGibbsIters+1.0);
00653 }
00654 }
00655
00656
00657 void runGibbs(GroundRRF* grrf, const Array<int>& queryPreds,
00658 Array<double>& counts, Database* db)
00659 {
00660
00661 double oldLikelihood = grrf->getLogValue();
00662
00663 counts.clear();
00664 for (int iter = 0; iter < aGibbsIters; iter++) {
00665 int countIndex = 0;
00666 for (int q = 0; q < queryPreds.size(); q++) {
00667 int i = queryPreds[q];
00668
00669
00670 for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00671
00672
00673 bool oldValue = grrf->getPredicateValue(i,j);
00674 grrf->setPredicateAndUpdate(i,j,!oldValue);
00675
00676 bool finalValue;
00677
00678 double newLikelihood = grrf->getLogValue();
00679 double prob = 1.0/(1.0 + exp(oldLikelihood - newLikelihood));
00680 if (oldLikelihood - newLikelihood > 100) {
00681 prob = 0.0;
00682 } else if (oldLikelihood - newLikelihood < -100) {
00683 prob = 1.0;
00684 }
00685
00686 if (frand() < prob) {
00687 oldLikelihood = newLikelihood;
00688 finalValue = !oldValue;
00689 } else {
00690 grrf->setPredicateAndUpdate(i,j,oldValue);
00691 finalValue = oldValue;
00692 }
00693
00694 if (counts.size() <= countIndex) {
00695 counts.append(0.5);
00696 }
00697 if (finalValue) {
00698 counts[countIndex]++;
00699 }
00700 countIndex++;
00701 }
00702 }
00703
00704 if (grrf == NULL) {
00705 cout << "grrf is somehow NULL!\n";
00706 }
00707 }
00708
00709
00710 for (int c = 0; c < counts.size(); c++) {
00711 counts[c] /= ((double)aGibbsIters+1.0);
00712 }
00713 }
00714
00715 void scrambleWeights(RRF* rrf)
00716 {
00717
00718 Array<Feature*> features = rrf->getFeatureArray();
00719 for (int i = 0; i < features.size(); i++) {
00720 for (int j = 0; j < features[i]->getNumWeights(); j++) {
00721
00722 double randWt = frand() - 0.5;
00723 features[i]->setWeight(j, randWt);
00724 }
00725 }
00726 }
00727
00728 void trainRRF(RRF* rrf, Database* db, const Array<int>& queryPreds)
00729 {
00730
00731
00732
00733
00734 Array<Predicate*> allGroundings;
00735 Array<TruthValue> truePredValues;
00736
00737 if (aRec || CMP_RRF) {
00738 const Domain* domain = db->getDomain();
00739 for (int i = 0; i < queryPreds.size(); i++) {
00740
00741
00742 Array<Predicate*> predGroundings;
00743 Predicate::createAllGroundings(queryPreds[i], domain, predGroundings);
00744 allGroundings.append(predGroundings);
00745
00746
00747 for (int j = 0; j < predGroundings.size(); j++) {
00748 truePredValues.append(db->getValue(predGroundings[j]));
00749 }
00750 }
00751 }
00752
00753
00754 GroundRRF* grrf = NULL;
00755 #if 1
00756 GroundRRF* trueGrrf = NULL;
00757 #else
00758 Array<Array<bool> > trueValues;
00759 #endif
00760
00761 if (aGround) {
00762 grrf = new GroundRRF(rrf, db);
00763 #if 1
00764 trueGrrf = new GroundRRF(rrf, db);
00765 #else
00766
00767 for (int q = 0; q < queryPreds.size(); q++) {
00768 int i = queryPreds[q];
00769 trueValues.append(Array<bool>());
00770 for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
00771 trueValues[q].append(grrf->getPredicateValue(i,j));
00772 }
00773 }
00774 #endif
00775 }
00776
00777
00778 if (infMethod == INF_PSEUDO) {
00779
00780
00781 cout << "ground lpl = "
00782 << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00783
00784
00785 int numWeights = rrf->getNumWeights();
00786 double wts[numWeights+1];
00787 for (int i = 0; i < numWeights; i++) {
00788 wts[i+1] = rrf->getWeight(i);
00789 }
00790
00791
00792 int minWt = 0;
00793 int maxWt = numWeights;
00794
00795 if (aTrainBottomLevel) {
00796 minWt = rrf->getRoot()->getNumWeights();
00797 }
00798 if (aTrainTopLevel) {
00799 maxWt = rrf->getRoot()->getNumWeights();
00800 }
00801
00802
00803 int iter;
00804 bool error;
00805 LBFGSR solver(aMaxTrainIters, 1e-5, aSamplingFrac, rrf, trueGrrf,
00806 queryPreds, maxWt - minWt, minWt, aSigmaSq, aPseudoFast);
00807 solver.minimize(wts+minWt, iter, error);
00808 cout << "Iters required: " << iter << endl;
00809
00810
00811 for (int i = 0; i < numWeights; i++) {
00812 rrf->setWeight(i, wts[i+1]);
00813 }
00814
00815 trueGrrf->dirtyAll();
00816
00817
00818 cout << "ground lpl = "
00819 << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00820 } else {
00821
00822
00823
00824 Timer timer;
00825 double lastWriteSec = timer.time();
00826 for (int iter = 0; iter < aMaxTrainIters; iter++) {
00827
00828 double begSec = timer.time();
00829
00830 #if 0
00831 for (int i = 0; i < rrf->getNumFeatures(); i++) {
00832 cout << "F" << i << ":";
00833 Feature* feature = rrf->getFeature(i);
00834 for (int j = 0; j < feature->getNumWeights(); j++) {
00835 cout << " " << feature->getWeight(j);
00836 }
00837 cout << "\n";
00838 }
00839 #endif
00840
00841
00842 if (aRec) {
00843 cout << "rec lpl = " << rrf->getLogPseudoLikelihood(db, queryPreds)
00844 << endl;
00845 }
00846 if (aGround) {
00847 #if 0
00848 cout << "ground lpl = " << grrf->getLogPseudoLikelihood(queryPreds)
00849 << endl;
00850 #else
00851 cout << "ground lpl = "
00852 << trueGrrf->getLogPseudoLikelihood(queryPreds) << endl;
00853 cout << "ground ll = "
00854 << trueGrrf->getLogValue() << endl;
00855 #if 0
00856
00857 if (llcount++ % 100 == 0) {
00858 cout << "ll = " << log(rrf->getExactConditionalLikelihood(db, queryPreds)) << endl;
00859 }
00860 #endif
00861 #endif
00862 #if CMP_RRF
00863 cout << "rec lpl = " << rrf->getLogPseudoLikelihood(db, queryPreds)
00864 << endl;
00865 #endif
00866 }
00867 cout << "wll = " << rrf->getWeightLogLikelihood(aSigmaSq) << endl;
00868
00869
00870 if (aRec) {
00871 if (infMethod != INF_EXACT) {
00872 runICMInference(rrf, db, allGroundings);
00873 }
00874 }
00875 if (aGround) {
00876 if (infMethod == INF_ICM) {
00877 runICMInference(grrf, queryPreds, db, rrf);
00878 }
00879 }
00880
00881
00882 Array<double> groundInfCounts;
00883 Array<double> recInfCounts;
00884
00885 if (infMethod == INF_GIBBS) {
00886 if (aRec) { getGibbsCounts(rrf, db, allGroundings, recInfCounts); }
00887 if (aGround) { getGibbsCounts(grrf, queryPreds, groundInfCounts,
00888 db, rrf); }
00889 } else if (infMethod == INF_ICM) {
00890 if (aRec) { rrf->getCounts(recInfCounts, db); }
00891 if (aGround) { grrf->getCounts(groundInfCounts); }
00892 } else if (infMethod == INF_EXACT) {
00893 if (aRec) { getExactCounts(rrf, db, allGroundings, recInfCounts); }
00894
00895 if (aGround) { cout << "ERROR: exact inference not supported for ground network.\n"; }
00896 } else if (infMethod == INF_PSEUDO) {
00897
00898 } else {
00899 cout << "ERROR: unknown inference method " << infMethod << endl;
00900 }
00901
00902
00903 #if 0
00904 if (aGround) {
00905 for (int q = 0; q < queryPreds.size(); q++) {
00906 for (int j = 0; j < trueValues[q].size(); j++) {
00907 grrf->setPredicateValue(queryPreds[q], j, trueValues[q][j]);
00908 }
00909 }
00910 }
00911 #endif
00912
00913 if (aRec || CMP_RRF)
00914 {
00915
00916 for (int i = 0; i < allGroundings.size(); i++) {
00917 db->setValue(allGroundings[i], truePredValues[i]);
00918 }
00919 rrf->invalidateAll();
00920 }
00921
00922
00923 Array<double> groundTrueCounts;
00924 Array<double> recTrueCounts;
00925
00926 if (aGround) {
00927 #if 0
00928 grrf->getCounts(groundTrueCounts);
00929 #else
00930 if (infMethod == INF_PSEUDO) {
00931 if (aPseudoFast) {
00932 trueGrrf->getPseudoCountsFast(groundTrueCounts,
00933 queryPreds, aSamplingFrac);
00934 } else {
00935 trueGrrf->getPseudoCounts(groundTrueCounts, queryPreds,
00936 aSamplingFrac);
00937 }
00938 } else {
00939 trueGrrf->getCounts(groundTrueCounts);
00940 }
00941 #endif
00942 }
00943 if (aRec) {
00944 if (infMethod == INF_PSEUDO) {
00945
00946 cout << "ERROR: pseudo-likelihood requires ground RRF.\n";
00947 } else {
00948 rrf->getCounts(recTrueCounts, db);
00949 }
00950 }
00951
00952
00953 int iMin = 0;
00954 int iMax = aRec ? recTrueCounts.size() : groundTrueCounts.size();
00955
00956 if (aTrainBottomLevel) {
00957 iMin = rrf->getRoot()->getNumWeights();
00958 }
00959 if (aTrainTopLevel) {
00960 iMax = rrf->getRoot()->getNumWeights();
00961 }
00962
00963
00964 Array<double> gradient;
00965 for (int i = 0; i < iMin; i++) {
00966 gradient.append(0.0);
00967 }
00968
00969
00970 double sqlength = 0.0;
00971 double weightsum = 0.0;
00972 int maxIndex = 0;
00973 double maxTotal = 0.0;
00974 for (int i = iMin; i < iMax; i++)
00975 {
00976 #if CMP_RRF
00977
00978 if (aRec && aGround) {
00979 if (2.0*fabs((recTrueCounts[i] - groundTrueCounts[i])/
00980 (recTrueCounts[i] + groundTrueCounts[i])) > 0.01) {
00981 cout << "True counts of " << i << " differ.\n";
00982 }
00983 if (2.0*fabs(recInfCounts[i] - groundInfCounts[i]) > 0.0001) {
00984 cout << "Inferred counts of " << i << " differ:\n";
00985 cout << "Rec: " << recInfCounts[i] << endl;
00986 cout << "Ground: " << groundInfCounts[i] << endl;
00987 }
00988
00989 double recDiff = recInfCounts[i] - recTrueCounts[i];
00990 double groundDiff = groundInfCounts[i] - groundTrueCounts[i];
00991 if (2.0*fabs(recDiff - groundDiff) > 0.0001) {
00992 cout << "Gradient differs on " << i << endl;
00993 cout << "Rec: " << recDiff << endl;
00994 cout << "Ground: " << groundDiff << endl;
00995 }
00996 }
00997 #endif
00998 double diff;
00999 if (infMethod == INF_PSEUDO) {
01000 diff = groundTrueCounts[i];
01001 } else if (aRec) {
01002 diff = recTrueCounts[i] - recInfCounts[i];
01003 } else {
01004 diff = groundTrueCounts[i] - groundInfCounts[i];
01005 }
01006 double weight = rrf->getWeight(i);
01007 double total = diff - weight/aSigmaSq;
01008 gradient.append(total);
01009 sqlength += total * total;
01010
01011
01012 if (fabs(total) > fabs(maxTotal)) {
01013 maxTotal = total;
01014 maxIndex = i;
01015 }
01016 weightsum += fabs(weight);
01017 }
01018
01019
01020 #if 0
01021
01022 cout << "largest change = " << maxTotal*aAlpha << " (" << maxIndex << ")\n";
01023 cout << "sqlength = " << sqlength*aAlpha*aAlpha << endl;
01024 cout << "length = " << sqrt(sqlength*aAlpha*aAlpha) << endl;
01025 cout << "sum of weights = " << weightsum << endl;
01026 #endif
01027
01028 for (int i = iMin; i < iMax; i++)
01029 {
01030
01031 double weight = rrf->getWeight(i);
01032 rrf->setWeight(i, weight + gradient[i]*aAlpha
01033
01034
01035 );
01036 }
01037
01038
01039 if (aGround) {
01040 grrf->dirtyAll();
01041 #if 1
01042 trueGrrf->dirtyAll();
01043 #endif
01044 }
01045
01046 rrf->invalidateAll();
01047
01048
01049 cout << "Time: ";
01050 Timer::printTime(cout, timer.time() - begSec); cout << endl;
01051
01052 if (aVerbose) {
01053 cout << rrf;
01054 }
01055
01056 if (timer.time() - lastWriteSec > 5) {
01057 ofstream fout(aoutputFilename);
01058 fout << rrf;
01059 fout.close();
01060 lastWriteSec = timer.time();
01061 }
01062 }
01063 }
01064
01065
01066 ofstream fout(aoutputFilename);
01067 fout << rrf;
01068
01069 if (aVerbose) {
01070 cout << "Final model:\n";
01071 cout << rrf;
01072 }
01073 fout.close();
01074 }
01075
01076
01077 void inferRRF(RRF* rrf, Database* db, const Array<int>& queryPreds)
01078 {
01079
01080 Array<Predicate*> allGroundings;
01081 Array<TruthValue> truePredValues;
01082 const Domain* domain = db->getDomain();
01083
01084
01085 if (1) {
01086 for (int i = 0; i < queryPreds.size(); i++) {
01087
01088
01089 Array<Predicate*> predGroundings;
01090 Predicate::createAllGroundings(queryPreds[i], domain, predGroundings);
01091 allGroundings.append(predGroundings);
01092
01093
01094 for (int j = 0; j < predGroundings.size(); j++) {
01095 truePredValues.append(db->getValue(predGroundings[j]));
01096 }
01097 }
01098 }
01099
01100
01101 GroundRRF* grrf = NULL;
01102 #if 1
01103 GroundRRF* trueGrrf = NULL;
01104 #else
01105 Array<Array<bool> > trueValues;
01106 #endif
01107
01108 if (aGround) {
01109 grrf = new GroundRRF(rrf, db);
01110 #if 1
01111 trueGrrf = new GroundRRF(rrf, db);
01112 #else
01113
01114 for (int q = 0; q < queryPreds.size(); q++) {
01115 int i = queryPreds[q];
01116 trueValues.append(Array<bool>());
01117 for (int j = 0; j < grrf->getNumGroundings(i-1); j++) {
01118 trueValues[q].append(grrf->getPredicateValue(i,j));
01119 }
01120 }
01121 #endif
01122 }
01123
01124
01125 if (aRec) {
01126 runICMInference(rrf, db, allGroundings);
01127 }
01128
01129 if (aGround) {
01130 runICMInference(grrf, queryPreds, db, rrf);
01131 }
01132
01133 Array<double> recInfCounts;
01134 Array<double> groundInfCounts;
01135
01136 getGibbsLogLikelihood2(rrf, db, allGroundings);
01137
01138 #if 0
01139 if (infMethod == INF_GIBBS) {
01140 if (aRec) { runGibbs(rrf, db, allGroundings, recInfCounts); }
01141 if (aGround) { runGibbs(grrf, queryPreds, groundInfCounts, db); }
01142 } else if (infMethod == INF_ICM) {
01143
01144
01145
01146 }
01147
01148
01149 for (int i = 0; i < allGroundings.size(); i++) {
01150 allGroundings[i]->printWithStrVar(cout, domain);
01151 if (aGround) {
01152 cout << " " << groundInfCounts[i];
01153 }
01154 if (aRec) {
01155 cout << " " << recInfCounts[i];
01156 }
01157 cout << endl;
01158 }
01159 #endif
01160 }
01161
01162
01163
01164 int main(int argc, char* argv[])
01165 {
01166 ARGS::parse(argc, argv, &cout);
01167
01168 switch(aInferenceMethodStr[0]) {
01169 case 'i': infMethod = INF_ICM;
01170 break;
01171 case 'g': infMethod = INF_GIBBS;
01172 break;
01173 case 'e': infMethod = INF_EXACT;
01174 break;
01175 case 'p': infMethod = INF_PSEUDO;
01176 break;
01177 default: cout << "ERROR: Unknown inference method "
01178 << aInferenceMethodStr << endl;
01179 exit(-1);
01180 }
01181
01182 if (!aRec && !aGround) {
01183 aGround = true;
01184 }
01185
01186 srand(aSeed);
01187
01189 MLN* mln = new MLN;
01190 Domain* domain = new Domain;
01191 StringHashArray* openWorldPredNames = new StringHashArray();
01192
01193 StringHashArray* queryPredNames = new StringHashArray();
01194 runYYParser(mln, domain, afilename, true,
01195 openWorldPredNames, queryPredNames, false,
01196 false, 0, false,
01197 NULL, true, false);
01198
01199
01200
01201 Array<int> queryPreds;
01202 if (!strcmp(nonEvidPredsStr, "all")) {
01203
01204
01205 for (int i = 1; i < domain->getNumPredicates(); i++) {
01206 queryPreds.append(i);
01207 }
01208 } else {
01209 StringHashArray nonEvidPredNames;
01210 if(!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames)) {
01211 cout << "ERROR: failed to extract non-evidence predicate names.\n";
01212 return -1;
01213 }
01214
01215 for (int i = 0; i < nonEvidPredNames.size(); i++) {
01216 int predId = domain->getPredicateId( nonEvidPredNames[i].c_str() );
01217 if (predId < 0) {
01218 cout << "Ignoring unrecognized query predicate \""
01219 << nonEvidPredNames[i] << "\".\n";
01220 } else {
01221 queryPreds.append(predId);
01222 }
01223 }
01224 }
01225
01226
01227 RRF* rrf;
01228
01229 if (amodelFilename) {
01230
01231 cout << "Loading weights from file.\n";
01232 rrf = new RRF();
01233 ifstream modelIn(amodelFilename);
01234 rrf->load(modelIn, domain);
01235 modelIn.close();
01236
01237 } else {
01238 #if 0
01240 rrf = RRF::generateTwoLevel(domain);
01241 Array<int> typeArity;
01242 typeArity.append(0);
01243 typeArity.append(1);
01244 for (int i = 0; i < aNumFeatures; i++) {
01245 rrf->addCompleteFeature(typeArity, domain, queryPreds,
01246 aPredsPerFeature);
01247 }
01248 #else
01249 cout << "ERROR: no longer supported!\n";
01250 exit(-1);
01251 #endif
01252 }
01253
01254 #if 0
01255 if (atestConstant) {
01256
01257 int testId = domain->getConstantId(atestConstant);
01258 if (testId < 0) {
01259 cout << "Error: unknown constant \"" << atestConstant << "\"\n";
01260 }
01261
01262
01263 Array<Predicate*> groundQueryPreds;
01264 Array<int> req;
01265 req.append(testId);
01266
01267
01268 for (int i = 0; i < queryPreds.size(); i++) {
01269
01270
01271
01272
01273 cout << "req = ";
01274 for (int j = 0; j < req.size(); j++) {
01275 cout << req[j] << " ";
01276 }
01277 cout << endl;
01278
01279
01280
01281 Array<int> grounding;
01282 while (iter.hasNextGrounding()) {
01283
01284
01285 iter.getNextGrounding(grounding);
01286 Predicate* pred = new Predicate(
01287 domain->getPredicateTemplate(queryPreds[i]));
01288 for (int j = 0; j < grounding.size(); j++) {
01289 pred->setTermToConstant(j, grounding[j]);
01290 }
01291
01292
01293 groundQueryPreds.append(pred);
01294 }
01295 }
01296
01297 double lpl = rrf->getLogPseudoLikelihood(domain->getDB(),
01298 groundQueryPreds);
01299
01300
01301 double ll;
01302 if (aGround) {
01303 ll = getGibbsLogLikelihood2(rrf, domain->getDB(),
01304 groundQueryPreds);
01305 } else {
01306 ll = getGibbsLogLikelihood(rrf, domain->getDB(),
01307 groundQueryPreds);
01308 }
01309 cout << "lpl = " << lpl << endl;
01310 cout << "ll = " << ll << endl;
01311
01312 } else
01313 #endif
01314 if (aInfer) {
01315 inferRRF(rrf, domain->getDB(), queryPreds);
01316 } else if (aLPL) {
01317 if (aGround) {
01318 GroundRRF* grrf = new GroundRRF(rrf, domain->getDB());
01319 cout << "lpl = " << grrf->getLogPseudoLikelihood(queryPreds)
01320 << endl;
01321 } else {
01322 cout << "ERROR: pll for non-ground RRF is not yet supported!\n";
01323 }
01324 } else {
01325 if (aoutputFilename == NULL) {
01326 cout << "Warning: you should specify an output model filename "
01327 << "using the -r option.\n";
01328 }
01329 trainRRF(rrf, domain->getDB(), queryPreds);
01330 }
01331
01332 return 0;
01333 }