00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #include <fstream>
00067 #include <iostream>
00068 #include "arguments.h"
00069 #include "fol.h"
00070 #include "learnwts.h"
00071
00072 #include "structlearn.h"
00073
00074
// ---------------------------------------------------------------------------
// Storage for command-line options.  Each variable below is bound to a flag
// in ARGS::Args[] and range-checked in checkParams().
// ---------------------------------------------------------------------------

// Required/basic I/O options (-i, -o, -t, -ne, -multipleDatabases).
char* inMLNFiles = NULL;
char* outMLNFile = NULL;
char* dbFiles = NULL;
char* nonEvidPredsStr = NULL;
bool multipleDatabases = false;

// Beam-search and clause-construction limits.
int beamSize = 5;
double minWt = 0.01;
double penalty = 0.01;
int maxVars = 6;
int maxNumPredicates = 6;
int cacheSize = 500;   // megabytes of clause/count cache (0 disables caching)

// Sampling of clause groundings (flag is -delta; the variable is spelled
// 'ddelta' — presumably to avoid a name clash; confirm against headers).
bool noSampleClauses = false;
double ddelta = 0.05;
double epsilon = 0.2;
int minClauseSamples = -1;   // -1: no minimum
int maxClauseSamples = -1;   // -1: no maximum

// Sampling of ground atoms for the (weighted) pseudo-log-likelihood.
bool noSampleGndPreds = false;
double fraction = 0.8;
int minGndPredSamples = -1;  // -1: no minimum
int maxGndPredSamples = -1;  // -1: no maximum

// Gaussian prior on clause weights.
bool noPrior = false;
double priorMean = 0;
double priorStdDev = 100;

// L-BFGS-B limits: "tight" (lb*) for final scoring, "loose" for quickly
// evaluating candidate clauses.
int lbMaxIter = 10000;
double lbConvThresh = 1e-5;
int looseMaxIter = 10;
double looseConvThresh = 1e-3;

// Miscellaneous structure-learning controls.
int numEstBestClauses = 10;
bool noWtPredsEqually = false;
bool startFromEmptyMLN = false;
bool tryAllFlips = false;
int bestGainUnchangedLimit = 2;

// Structural gradient descent / EM.  NOTE(review): no ARGS entries binding
// these flags are visible in this chunk; they are still validated in
// checkParams() — confirm where (or whether) they are parsed.
bool structGradDescent = false;
bool withEM = false;
00116
// Table binding each command-line flag to its storage variable above.
// ARGS::Req = required flag with a value, ARGS::Opt = optional flag with a
// value, ARGS::Tog = boolean toggle (presence flips the default).  The help
// strings show defaults in brackets.  The trailing ARGS() terminates the
// table for ARGS::parse().
ARGS ARGS::Args[] =
{
  ARGS("i", ARGS::Req, inMLNFiles,
       "Comma-separated input .mln files. (With the -multipleDatabases "
       "option, the second file to the last one are used to contain constants "
       "from different domains, and they correspond to the .db files specified "
       "with the -t option.)"),

  ARGS("o", ARGS::Req, outMLNFile,
       "Output .mln file containing learned formulas and weights."),

  ARGS("t", ARGS::Req, dbFiles,
       "Comma-separated .db files containing the training database "
       "(of true/false ground atoms), including function definitions, "
       "e.g. ai.db,graphics.db,languages.db."),

  ARGS("ne", ARGS::Opt, nonEvidPredsStr,
       "[all predicates] Non-evidence predicates "
       "(comma-separated with no space), e.g., cancer,smokes,friends."),

  ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
       "If specified, each .db file belongs to a separate domain; "
       "otherwise all .db files belong to the same domain."),

  ARGS("beamSize", ARGS::Opt, beamSize, "[5] Size of beam in beam search."),

  ARGS("minWt", ARGS::Opt, minWt,
       "[0.01] Candidate clauses are discarded if "
       "their absolute weights fall below this."),

  ARGS("penalty", ARGS::Opt, penalty,
       "[0.01] Each difference between the current and previous version of a "
       "candidate clause penalizes the (weighted) pseudo-log-likelihood "
       "by this amount."),

  ARGS("maxVars", ARGS::Opt, maxVars,
       "[6] Maximum number of variables in learned clauses."),

  ARGS("maxNumPredicates", ARGS::Opt, maxNumPredicates,
       "[6] Maximum number of predicates in learned clauses."),

  ARGS("cacheSize", ARGS::Opt, cacheSize,
       "[500] Size in megabytes of the cache that is used to store the clauses "
       "(and their counts) that are created during structure learning."),

  ARGS("noSampleClauses", ARGS::Tog, noSampleClauses,
       "If specified, compute a clause's number of true groundings exactly, "
       "and do not estimate it by sampling its groundings. If not specified, "
       "estimate the number by sampling."),

  // Note: flag "delta" is stored in the variable 'ddelta'.
  ARGS("delta", ARGS::Opt, ddelta,
       "[0.05] (Used only if sampling clauses.) "
       "The probability that an estimate a clause's number of true groundings "
       "is off by more than epsilon error is less than this value. "
       "Used to determine the number of samples of the clause's groundings "
       "to draw."),

  ARGS("epsilon", ARGS::Opt, epsilon,
       "[0.2] (Used only if sampling clauses.) "
       "Fractional error from a clause's actual number of true groundings. "
       "Used to determine the number of samples of the clause's groundings "
       "to draw."),

  ARGS("minClauseSamples", ARGS::Opt, minClauseSamples,
       "[-1] (Used only if sampling clauses.) "
       "Minimum number of samples of a clause's groundings to draw. "
       "(-1: no minimum)"),

  ARGS("maxClauseSamples", ARGS::Opt, maxClauseSamples,
       "[-1] (Used only if sampling clauses.) "
       "Maximum number of samples of a clause's groundings to draw. "
       "(-1: no maximum)"),

  // Note: the *Atom* flags are stored in the *GndPred* variables.
  ARGS("noSampleAtoms", ARGS::Tog, noSampleGndPreds,
       "If specified, do not estimate the (weighted) pseudo-log-likelihood by "
       "sampling ground atoms; otherwise, estimate the value by sampling."),

  ARGS("fractAtoms", ARGS::Opt, fraction,
       "[0.8] (Used only if sampling ground atoms.) "
       "Fraction of each predicate's ground atoms to draw."),

  ARGS("minAtomSamples", ARGS::Opt, minGndPredSamples,
       "[-1] (Used only if sampling ground atoms.) "
       "Minimum number of each predicate's ground atoms to draw. "
       "(-1: no minimum)"),

  ARGS("maxAtomSamples", ARGS::Opt, maxGndPredSamples,
       "[-1] (Used only if sampling ground atoms.) "
       "Maximum number of each predicate's ground atoms to draw. "
       "(-1: no maximum)"),

  ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),

  ARGS("priorMean", ARGS::Opt, priorMean,
       "[0] Means of Gaussian priors on formula weights. By default, "
       "for each formula, it is the weight given in the .mln input file, "
       "or fraction thereof if the formula turns into multiple clauses. "
       "This mean applies if no weight is given in the .mln file."),

  ARGS("priorStdDev", ARGS::Opt, priorStdDev,
       "[100] Standard deviations of Gaussian priors on clause weights."),

  // Note: the tight* flags are stored in the lb* (L-BFGS-B) variables.
  ARGS("tightMaxIter", ARGS::Opt, lbMaxIter,
       "[10000] Max number of iterations to run L-BFGS-B, "
       "the algorithm used to optimize the (weighted) pseudo-log-likelihood."),

  ARGS("tightConvThresh", ARGS::Opt, lbConvThresh,
       "[1e-5] Fractional change in (weighted) pseudo-log-likelihood at which "
       "L-BFGS-B terminates."),

  ARGS("looseMaxIter", ARGS::Opt, looseMaxIter,
       "[10] Max number of iterations to run L-BFGS-B "
       "when evaluating candidate clauses."),

  ARGS("looseConvThresh", ARGS::Opt, looseConvThresh,
       "[1e-3] Fractional change in (weighted) pseudo-log-likelihood at which "
       "L-BFGS-B terminates when evaluating candidate clauses."),

  ARGS("numClausesReEval", ARGS::Opt, numEstBestClauses,
       "[10] Keep this number of candidate clauses with the highest estimated "
       "scores, and re-evaluate their scores precisely."),

  ARGS("noWtPredsEqually", ARGS::Tog, noWtPredsEqually,
       "If specified, each predicate is not weighted equally. This means that "
       "high-arity predicates contribute more to the pseudo-log-likelihood "
       "than low-arity ones. If not specified, each predicate is given equal "
       "weight in the weighted pseudo-log-likelihood."),

  ARGS("startFromEmptyMLN", ARGS::Tog, startFromEmptyMLN,
       "If specified, start structure learning from an empty MLN. "
       "If the input .mln contains formulas, they will be added to the "
       "candidate clauses created in the first step of beam search. "
       "If not specified, begin structure learning from the input .mln file."),

  ARGS("tryAllFlips", ARGS::Tog, tryAllFlips,
       "If specified, the structure learning algorithm tries to flip "
       "the predicate signs of the formulas in the input .mln file "
       "in all possible ways"),

  ARGS("bestGainUnchangedLimit", ARGS::Opt, bestGainUnchangedLimit,
       "[2] Beam search stops when the best clause found does not change "
       "in this number of iterations."),

  // Sentinel entry: marks the end of the table for ARGS::parse().
  ARGS()
};
00271
00272
00273 bool checkParams()
00274 {
00275 bool ok = true;
00276 if (beamSize<=0) {cout<<"ERROR: beamSize must be positive"<<endl; ok =false;}
00277
00278 if (minWt<0) { cout << "ERROR: minWt must be non-negative" << endl;ok =false;}
00279
00280 if (penalty<0) { cout <<"ERROR: penalty must be non-negative"<<endl;ok=false;}
00281
00282 if (maxVars<=0) { cout << "ERROR: maxVar must be positive" << endl;ok =false;}
00283
00284 if (maxNumPredicates <= 0)
00285 {
00286 cout << "ERROR: maxNumPredicates must be positive" << endl; ok = false;
00287 }
00288
00289 if (cacheSize < 0)
00290 { cout << "ERROR: cacheSize must be non-negative" << endl; ok = false; }
00291
00292 if (ddelta <= 0 || ddelta > 1)
00293 { cout << "ERROR: gamma must be between 0 and 1" << endl; ok = false; }
00294
00295 if (epsilon <= 0 || epsilon >= 1)
00296 { cout << "ERROR: epsilon must be between 0 and 1" << endl; ok = false; }
00297
00298 if (fraction < 0 || fraction > 1)
00299 { cout << "ERROR: fraction must be between 0 and 1" << endl; ok = false;}
00300
00301 if (priorMean < 0)
00302 { cout << "ERROR: priorMean must be non-negative" << endl; ok = false; }
00303
00304 if (priorStdDev <= 0)
00305 { cout << "ERROR: priorStdDev must be positive" << endl; ok = false; }
00306
00307 if (lbMaxIter <= 0)
00308 { cout << "ERROR: tightMaxIter must be positive" << endl; ok = false; }
00309
00310 if (lbConvThresh <= 0 || lbConvThresh >= 1)
00311 { cout << "ERROR: tightConvThresh must be between 0 and 1" << endl; ok=false;}
00312
00313 if (looseMaxIter <= 0)
00314 { cout << "ERROR: looseMaxIter must be positive" << endl; ok = false; }
00315
00316 if (looseConvThresh <= 0 || looseConvThresh >= 1)
00317 { cout << "ERROR: looseConvThresh must be between 0 and 1" << endl; ok=false;}
00318
00319 if (numEstBestClauses <= 0)
00320 { cout << "ERROR: numClausesReEval must be positive" << endl; ok = false; }
00321
00322 if (bestGainUnchangedLimit <= 0)
00323 { cout << "ERROR: bestGainUnchangedLimit must be positive" << endl; ok=false;}
00324
00325 if (!structGradDescent && withEM)
00326 { cout << "ERROR: EM can only be used with structural gradient descent" << endl; ok=false; }
00327
00328 if (structGradDescent && nonEvidPredsStr == NULL)
00329 {
00330 cout << "ERROR: you must specify non-evidence predicates for "
00331 << "structural gradient descent" << endl;
00332 ok = false;
00333 }
00334
00335 return ok;
00336 }
00337
00338
00339
00340
00341
00342
00343
00344
// Entry point: parse flags, build the domains/MLNs from the input files,
// run structure learning, write the learned MLN to outMLNFile, and clean up.
int main(int argc, char* argv[])
{
  ARGS::parse(argc,argv,&cout);
  Timer timer;
  double begSec, startSec = timer.time();

  // Precompute per-object fixed sizes.  NOTE(review): inferred from the
  // names to feed the cache-size (megabyte) accounting — confirm against
  // the computeFixedSizeB() implementations.
  Term::computeFixedSizeB();
  Predicate::computeFixedSizeB();
  Clause::computeFixedSizeB();
  AuxClauseData::computeFixedSizeB();

  if (!checkParams()) return -1;

  // The -i argument lists the main .mln file first; any remaining entries
  // are per-domain constant files used with -multipleDatabases.
  Array<string> constFilesArr, dbFilesArr;
  extractFileNames(inMLNFiles, constFilesArr);
  assert(constFilesArr.size() >= 1);
  string inMLNFile = constFilesArr[0];
  constFilesArr.removeItem(0);
  extractFileNames(dbFiles, dbFilesArr);

  if (dbFilesArr.size() <= 0)
  { cout << "ERROR: must specify training data with -t flag."<<endl; return -1;}

  if (multipleDatabases)
  {
    // If constant files are given at all, there must be exactly one per .db.
    if ( (constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size()))
    {
      cout << "ERROR: when there are multiple databases, if .mln files "
           << "containing constants are specified, there must "
           << "be the same number of them as .db files" << endl;
      return -1;
    }
  }

  // Parse the -ne comma-separated list into an array of predicate names.
  StringHashArray tmpNEPredNames;
  Array<string> nonEvidPredNames;
  if (nonEvidPredsStr)
  {
    if(!extractPredNames(nonEvidPredsStr, NULL, tmpNEPredNames))
    {
      cout << "ERROR: failed to extract non-evidence predicate names." << endl;
      return -1;
    }
    for (int i = 0; i < tmpNEPredNames.size(); i++)
      nonEvidPredNames.append(tmpNEPredNames[i]);
  }

  Array<Domain*> domains;
  Array<MLN*> mlns;
  // For structural gradient descent, the non-evidence predicates are also
  // treated as query predicates.  NOTE(review): queryPredNames is never
  // deleted in this function — possible leak unless createDomainsAndMLNs
  // takes ownership; confirm before adding a delete.
  StringHashArray* queryPredNames = NULL;
  if (structGradDescent)
  {
    queryPredNames = new StringHashArray();
    for (int i = 0; i < nonEvidPredNames.size(); i++)
      queryPredNames->append(nonEvidPredNames[i]);
  }

  bool addUnitClauses = false;

  bool allPredsExceptQueriesAreCW = true;
  bool mwsLazy = true;
  // With EM the query predicates may be unobserved, so the closed-world
  // assumption on the remaining predicates is dropped.
  if (structGradDescent && withEM) allPredsExceptQueriesAreCW = false;
  begSec = timer.time();
  cout << "Parsing MLN and creating domains..." << endl;
  createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile,
                       constFilesArr, dbFilesArr, queryPredNames,
                       addUnitClauses, priorMean, mwsLazy,
                       allPredsExceptQueriesAreCW, NULL, NULL);
  cout << "Parsing MLN and creating domains took ";
  Timer::printTime(cout, timer.time()-begSec); cout << endl << endl;

  // If -ne was omitted, default to every (non-equality) predicate of the
  // first domain as non-evidence.
  if (nonEvidPredNames.size() == 0)
    domains[0]->getNonEqualPredicateNames(nonEvidPredNames);
  bool cacheClauses = (cacheSize > 0);
  bool reEvalBestCandidatesWithTightParams = true;

  // Run beam-search structure learning with all the parameters gathered
  // above.  Note the negated toggles: e.g. -noSampleClauses false means
  // sampling IS performed.
  StructLearn sl(&mlns, startFromEmptyMLN, outMLNFile, &domains,
                 &nonEvidPredNames, maxVars, maxNumPredicates, cacheClauses,
                 cacheSize, tryAllFlips,
                 !noSampleClauses, ddelta, epsilon,
                 minClauseSamples, maxClauseSamples,
                 !noPrior, priorMean, priorStdDev,
                 !noWtPredsEqually,
                 lbMaxIter, lbConvThresh, looseMaxIter, looseConvThresh,
                 beamSize, bestGainUnchangedLimit, numEstBestClauses,
                 minWt, penalty,
                 !noSampleGndPreds,fraction,minGndPredSamples,maxGndPredSamples,
                 reEvalBestCandidatesWithTightParams, structGradDescent,
                 withEM);
  sl.run();

  // Cleanup of the domains, MLNs and the shared power-set singleton.
  deleteDomains(domains);
  for (int i = 0; i < mlns.size(); i++) delete mlns[i];
  PowerSet::deletePowerSet();

  cout << "Total time taken = ";
  Timer::printTime(cout, timer.time()-startSec); cout << endl;
}