#include <fstream>
#include <iostream>
#include <sstream>
#include <cfloat>   // DBL_MAX
#include <climits>  // INT_MAX
#include "arguments.h"
#include "inferenceargs.h"
#include "lbfgsb.h"
#include "discriminativelearner.h"
#include "learnwts.h"
#include "maxwalksat.h"
#include "mcsat.h"
#include "gibbssampler.h"
#include "simulatedtempering.h"
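
// learnwts: weight learning for Markov logic networks.
// Learns clause/formula weights either discriminatively (voted perceptron,
// diagonal Newton, or rescaled conjugate gradient, using MAP / MC-SAT /
// Gibbs / simulated-tempering inference to compute expected counts) or
// generatively (L-BFGS-B on the pseudo-log-likelihood).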

bool PRINT_CLAUSE_DURING_COUNT = true;

// Default standard deviations of the Gaussian priors on weights.
const double DISC_DEFAULT_STD_DEV = 2;
const double GEN_DEFAULT_STD_DEV = 100;

// Learning mode and input/output options.
bool discLearn = false;
bool genLearn = false;
char* outMLNFile = NULL;
char* dbFiles = NULL;
char* nonEvidPredsStr = NULL;
bool noAddUnitClauses = false;
bool multipleDatabases = false;
bool initWithLogOdds = false;
bool isQueryEvidence = false;

bool aPeriodicMLNs = false;

// Gaussian prior on the weights.
bool noPrior = false;
double priorMean = 0;
double priorStdDev = -1;  // negative: use the mode-specific default

// Generative learning (L-BFGS-B) parameters.
int maxIter = 10000;
double convThresh = 1e-5;
bool noEqualPredWt = false;

// Discriminative learning parameters.
int numIter = 100;
double maxSec = 0;
double maxMin = 0;
double maxHour = 0;
double learningRate = 0.001;
double momentum = 0.0;
bool withEM = false;
char* aInferStr = NULL;
bool noUsePerWeight = false;
bool useNewton = false;
bool useCG = false;
bool useVP = false;
int discMethod = DiscriminativeLearner::CG;
double cg_lambda = 100;
double cg_max_lambda = DBL_MAX;
bool cg_noprecond = false;
int amwsMaxSubsequentSteps = -1;
char* ainDBListFile = NULL;

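// Command-line argument table for the ARGS parser (see arguments.h).
// ARGS::Req marks a required argument, ARGS::Opt an optional one, and
// ARGS::Tog a boolean toggle; the bracketed value at the start of a
// description is the default.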
ARGS ARGS::Args[] =
{
  // Input options.
  ARGS("i", ARGS::Req, ainMLNFiles,
       "Comma-separated input .mln files. (With the -multipleDatabases "
       "option, the second through last files contain constants from the "
       "different databases, and they correspond to the .db files "
       "specified with the -t option.)"),

  ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
       "Specified non-evidence atoms (comma-separated with no space) are "
       "closed world; otherwise, all non-evidence atoms are open world. Atoms "
       "appearing here cannot be query atoms and cannot appear in the -ow "
       "option."),

  ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
       "Specified evidence atoms (comma-separated with no space) are open "
       "world, while other evidence atoms are closed world. "
       "Atoms appearing here cannot appear in the -cw option."),

  // General inference options (embedded in the -infer argument).
  ARGS("m", ARGS::Tog, amapPos,
       "(Embed in -infer argument) "
       "Run MAP inference and return only positive query atoms."),

  ARGS("a", ARGS::Tog, amapAll,
       "(Embed in -infer argument) "
       "Run MAP inference and show 0/1 results for all query atoms."),

  ARGS("p", ARGS::Tog, agibbsInfer,
       "(Embed in -infer argument) "
       "Run inference using MCMC (Gibbs sampling) and return probabilities "
       "for all query atoms."),

  ARGS("ms", ARGS::Tog, amcsatInfer,
       "(Embed in -infer argument) "
       "Run inference using MC-SAT and return probabilities "
       "for all query atoms."),

  ARGS("simtp", ARGS::Tog, asimtpInfer,
       "(Embed in -infer argument) "
       "Run inference using simulated tempering and return probabilities "
       "for all query atoms."),

  ARGS("seed", ARGS::Opt, aSeed,
       "(Embed in -infer argument) "
       "[2350877] Seed used to initialize the randomizer in the inference "
       "algorithm. If not set, the seed is initialized from a fixed random "
       "number."),

  ARGS("lazy", ARGS::Opt, aLazy,
       "(Embed in -infer argument) "
       "[false] Run lazy version of inference if this flag is set."),

  ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
       "(Embed in -infer argument) "
       "[false] Lazy version of inference will not approximate by deactivating "
       "atoms to save memory. This flag is ignored if -lazy is not set."),

  ARGS("memLimit", ARGS::Opt, aMemLimit,
       "(Embed in -infer argument) "
       "[-1] Maximum limit in kbytes which should be used for inference. "
       "-1 means all main memory available on the system is used."),

  // MaxWalkSat options.
  ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
       "(Embed in -infer argument) "
       "[100000] (MaxWalkSat) The max number of steps taken."),

  ARGS("tries", ARGS::Opt, amwsTries,
       "(Embed in -infer argument) "
       "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),

  ARGS("targetWt", ARGS::Opt, amwsTargetWt,
       "(Embed in -infer argument) "
       "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
       "with weight <= specified weight."),

  ARGS("hard", ARGS::Opt, amwsHard,
       "(Embed in -infer argument) "
       "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
       "satisfy a soft one."),

  ARGS("heuristic", ARGS::Opt, amwsHeuristic,
       "(Embed in -infer argument) "
       "[2] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
       "2 = TABU, 3 = SAMPLESAT)."),

  ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
       "(Embed in -infer argument) "
       "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
       "atom when using the tabu heuristic in MaxWalkSat."),

  ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState,
       "(Embed in -infer argument) "
       "[false] (MaxWalkSat) If false, the naive way of saving low states "
       "(each time a low state is found, the whole state is saved) is used; "
       "otherwise, a list of variables flipped since the last low state is "
       "kept and the low state is reconstructed. This can be much faster for "
       "very large data sets."),

  // MCMC options.
  ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
       "(Embed in -infer argument) "
       "[0] (MCMC) Minimum number of burn-in steps (-1: no minimum)."),

  ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
       "(Embed in -infer argument) "
       "[0] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),

  ARGS("minSteps", ARGS::Opt, amcmcMinSteps,
       "(Embed in -infer argument) "
       "[-1] (MCMC) Minimum number of MCMC sampling steps."),

  ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps,
       "(Embed in -infer argument) "
       "[chosen automatically from the domain size] (MCMC) Maximum number "
       "of MCMC sampling steps."),

  ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds,
       "(Embed in -infer argument) "
       "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),

  // Simulated tempering options.
  ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
       "(Embed in -infer argument) "
       "[2] (Simulated Tempering) Selection interval between swap attempts."),

  ARGS("numRuns", ARGS::Opt, asimtpNumST,
       "(Embed in -infer argument) "
       "[3] (Simulated Tempering) Number of simulated tempering runs."),

  ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
       "(Embed in -infer argument) "
       "[10] (Simulated Tempering) Number of swapping chains."),

  // MC-SAT / SampleSat options.
  ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
       "(Embed in -infer argument) "
       "[10] (MC-SAT) Return nth SAT solution in SampleSat."),

  ARGS("saRatio", ARGS::Opt, assSaRatio,
       "(Embed in -infer argument) "
       "[50] (MC-SAT) Ratio of simulated annealing steps mixed with WalkSAT "
       "in MC-SAT."),

  ARGS("saTemperature", ARGS::Opt, assSaTemp,
       "(Embed in -infer argument) "
       "[10] (MC-SAT) Temperature (/100) for the simulated annealing step in "
       "SampleSat."),

  ARGS("lateSa", ARGS::Tog, assLateSa,
       "(Embed in -infer argument) "
       "[false] Run simulated annealing from the start in SampleSat."),

  // Gibbs sampling options.
  ARGS("numChains", ARGS::Opt, amcmcNumChains,
       "(Embed in -infer argument) "
       "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
       "at least 2)."),

  ARGS("delta", ARGS::Opt, agibbsDelta,
       "(Embed in -infer argument) "
       "[0.05] (Gibbs) During Gibbs sampling, probability that epsilon error "
       "is exceeded is less than this value."),

  ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
       "(Embed in -infer argument) "
       "[0.01] (Gibbs) Fractional error from true probability."),

  ARGS("fracConverged", ARGS::Opt, agibbsFracConverged,
       "(Embed in -infer argument) "
       "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
       "have converged."),

  ARGS("walksatType", ARGS::Opt, agibbsWalksatType,
       "(Embed in -infer argument) "
       "[1] (Gibbs) Use MaxWalkSat to initialize ground atoms' truth values "
       "in Gibbs sampling (1: use MaxWalkSat, 0: random initialization)."),

  ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
       "(Embed in -infer argument) "
       "[100] Perform a convergence test after this many samples per "
       "chain."),

  // Weight learning options.
  ARGS("periodic", ARGS::Tog, aPeriodicMLNs,
       "Write out MLNs after 1, 2, 5, 10, 20, 50, etc. iterations."),

  ARGS("infer", ARGS::Opt, aInferStr,
       "Specifies the inference parameters when using discriminative "
       "learning. The arguments are to be encapsulated in \"\" and the "
       "syntax is identical to the infer command (run infer with no "
       "arguments to see this). If not specified, 5 steps of MC-SAT with no "
       "burn-in is used."),

  ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),

  ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),

  ARGS("o", ARGS::Req, outMLNFile,
       "Output .mln file containing formulas with learned weights."),

  ARGS("t", ARGS::Opt, dbFiles,
       "Comma-separated .db files containing the training database "
       "(of true/false ground atoms), including function definitions, "
       "e.g. ai.db,graphics.db,languages.db."),

  ARGS("l", ARGS::Opt, ainDBListFile,
       "List of database files used in learning; each line of this file "
       "contains the path to one database file."),

  ARGS("ne", ARGS::Opt, nonEvidPredsStr,
       "First-order non-evidence predicates (comma-separated with no space), "
       "e.g., cancer,smokes,friends. For discriminative learning, at least "
       "one non-evidence predicate must be specified. For generative "
       "learning, the specified predicates are included in the (weighted) "
       "pseudo-log-likelihood computation; if none are specified, all are "
       "included."),

  ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
       "If specified, unit clauses are not included in the .mln file; "
       "otherwise they are included."),

  ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
       "If specified, each .db file belongs to a separate database; "
       "otherwise all .db files belong to the same database."),

  ARGS("withEM", ARGS::Tog, withEM,
       "If set, EM is used to fill in missing truth values; "
       "otherwise missing truth values are set to false."),

  ARGS("dNumIter", ARGS::Opt, numIter,
       "[100] (For discriminative learning only.) "
       "Number of iterations to run the discriminative learning method."),

  ARGS("dMaxSec", ARGS::Opt, maxSec,
       "[0] Maximum number of seconds to spend learning."),

  ARGS("dMaxMin", ARGS::Opt, maxMin,
       "[0] Maximum number of minutes to spend learning."),

  ARGS("dMaxHour", ARGS::Opt, maxHour,
       "[0] Maximum number of hours to spend learning."),

  ARGS("dLearningRate", ARGS::Opt, learningRate,
       "[0.001] (For discriminative learning only) "
       "Learning rate for the gradient descent in the disc. learning "
       "algorithm."),

  ARGS("dMomentum", ARGS::Opt, momentum,
       "[0.0] (For discriminative learning only) "
       "Momentum term for the gradient descent in the disc. learning "
       "algorithm."),

  ARGS("dNoPW", ARGS::Tog, noUsePerWeight,
       "[false] (For voted perceptron only.) "
       "Do not use per-weight learning rates, based on the number of true "
       "groundings per weight."),

  ARGS("dVP", ARGS::Tog, useVP,
       "[false] (For discriminative learning only) "
       "Use voted perceptron to learn the weights."),

  ARGS("dNewton", ARGS::Tog, useNewton,
       "[false] (For discriminative learning only) "
       "Use diagonalized Newton's method to learn the weights."),

  ARGS("dCG", ARGS::Tog, useCG,
       "[false] (For discriminative learning only) "
       "Use rescaled conjugate gradient to learn the weights."),

  ARGS("cgLambda", ARGS::Opt, cg_lambda,
       "[100] (For CG only) Starting value of the parameter to limit the "
       "step size."),

  ARGS("cgMaxLambda", ARGS::Opt, cg_max_lambda,
       "[no limit] (For CG only) Maximum value of the parameter to limit the "
       "step size."),

  ARGS("cgNoPrecond", ARGS::Tog, cg_noprecond,
       "[false] (For CG only) Do not precondition with the diagonal "
       "Hessian."),

  ARGS("queryEvidence", ARGS::Tog, isQueryEvidence,
       "[false] If this flag is set, then all groundings of the query preds "
       "not in the database are assumed to be false evidence."),

  ARGS("dInitWithLogOdds", ARGS::Tog, initWithLogOdds,
       "[false] (For discriminative learning only.) "
       "Initialize clause weights to their log odds instead of zero."),

  ARGS("dMwsMaxSubsequentSteps", ARGS::Opt, amwsMaxSubsequentSteps,
       "[Same as mwsMaxSteps] (For discriminative learning only.) The max "
       "number of MaxWalkSat steps taken in subsequent iterations (>= 2) of "
       "disc. learning. If not specified, mwsMaxSteps is used in each "
       "iteration."),

  ARGS("gMaxIter", ARGS::Opt, maxIter,
       "[10000] (For generative learning only.) "
       "Max number of iterations to run L-BFGS-B, "
       "the optimization algorithm for generative learning."),

  ARGS("gConvThresh", ARGS::Opt, convThresh,
       "[1e-5] (For generative learning only.) "
       "Fractional change in pseudo-log-likelihood at which "
       "L-BFGS-B terminates."),

  ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt,
       "[false] (For generative learning only.) "
       "If specified, the predicates are not weighted equally in the "
       "pseudo-log-likelihood computation; otherwise they are."),

  ARGS("noPrior", ARGS::Tog, noPrior,
       "No Gaussian priors on formula weights."),

  ARGS("priorMean", ARGS::Opt, priorMean,
       "[0] Means of Gaussian priors on formula weights. By default, "
       "for each formula, it is the weight given in the .mln input file, "
       "or a fraction thereof if the formula turns into multiple clauses. "
       "This mean applies if no weight is given in the .mln file."),

  ARGS("priorStdDev", ARGS::Opt, priorStdDev,
       "[2 for discriminative learning, 100 for generative learning] "
       "Standard deviations of Gaussian priors on clause weights."),

  // End-of-arguments sentinel.
  ARGS()
};
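
// Example invocations (file and predicate names are illustrative only):
//   learnwts -d -i smoking.mln -t smoking-train.db -ne Cancer -o out.mln
//   learnwts -g -i smoking.mln -t smoking-train.db -o out.mln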

// Reads file names (used for the -l option), one per line, from 'file'
// into 'array'.
void loadArray(const char* file, Array<string>& array)
{
  ifstream is(file);
  array.clear();
  string line;
  while (getline(is, line))
  {
    array.append(line);
  }
}

int main(int argc, char* argv[])
{
  ARGS::parse(argc, argv, &cout);

  if (!discLearn && !genLearn)
  {
    // Neither -d nor -g was given: default to discriminative learning.
    discLearn = true;
  }

  Timer timer;
  double startSec = timer.time();
  double begSec;

  if (priorStdDev < 0)
  {
    if (discLearn)
    {
      cout << "priorStdDev set to (discriminative learning's) default of "
           << DISC_DEFAULT_STD_DEV << endl;
      priorStdDev = DISC_DEFAULT_STD_DEV;
    }
    else
    {
      cout << "priorStdDev set to (generative learning's) default of "
           << GEN_DEFAULT_STD_DEV << endl;
      priorStdDev = GEN_DEFAULT_STD_DEV;
    }
  }

  // Basic sanity checks on the command-line parameters.
  if (discLearn && nonEvidPredsStr == NULL)
  {
    cout << "ERROR: you must specify non-evidence predicates for "
         << "discriminative learning" << endl;
    return -1;
  }

  if (maxIter <= 0) { cout << "maxIter must be > 0" << endl; return -1; }
  if (convThresh <= 0 || convThresh > 1)
  { cout << "convThresh must be > 0 and <= 1" << endl; return -1; }
  if (priorStdDev <= 0) { cout << "priorStdDev must be > 0" << endl; return -1;}

  if (amwsMaxSteps <= 0)
  { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }

  // If not given, the step limit for subsequent iterations defaults to
  // mwsMaxSteps.
  if (amwsMaxSubsequentSteps <= 0) amwsMaxSubsequentSteps = amwsMaxSteps;

  if (amwsTries <= 0)
  { cout << "ERROR: tries must be positive" << endl; return -1; }

  if (aMemLimit <= 0 && aMemLimit != -1)
  { cout << "ERROR: memLimit must be positive (or -1)" << endl; return -1; }

  if (!discLearn && aLazy)
  {
    cout << "ERROR: lazy can only be used with discriminative learning"
         << endl;
    return -1;
  }

  ofstream out(outMLNFile);
  if (!out.good())
  {
    cout << "ERROR: unable to open " << outMLNFile << endl;
    return -1;
  }

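  // For discriminative learning, configure the embedded inference procedure:
  // parse any -infer string and fall back to MC-SAT if no method is chosen.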
  if (discLearn)
  {
    // Rescaled conjugate gradient is the default discriminative method.
    if (!useVP && !useCG && !useNewton)
      useCG = true;

    // Per-weight learning rates apply only to voted perceptron.
    if ((useCG || useNewton) && !noUsePerWeight)
    {
      noUsePerWeight = true;
    }

    // Defaults for the embedded inference; may be overridden by -infer.
    amcmcMaxSteps = -1;
    amcmcBurnMaxSteps = -1;
    if (!aInferStr)
    {
      // No -infer string given: default to MC-SAT.
      amcsatInfer = true;
    }
    else
    {
      // Tokenize the -infer string and parse it like an infer command line.
      int inferArgc = 0;
      char **inferArgv = new char*[200];
      for (int i = 0; i < 200; i++)
      {
        inferArgv[i] = new char[500];
      }

      string inferString = "infer ";
      inferString.append(aInferStr);
      extractArgs(inferString.c_str(), inferArgc, inferArgv);
      cout << "extractArgs " << inferArgc << endl;
      for (int i = 0; i < inferArgc; i++)
      {
        cout << i << ": " << inferArgv[i] << endl;
      }

      ARGS::parseFromCommandLine(inferArgc, inferArgv);

      // Delete the memory allocated for the tokenized arguments.
      for (int i = 0; i < 200; i++)
      {
        delete[] inferArgv[i];
      }
      delete[] inferArgv;
    }

    // If the -infer string did not choose a method, default to MC-SAT.
    if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
    {
      amcsatInfer = true;
    }
  }

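  // Split the -i argument: the first .mln file contains the formulas; any
  // remaining files supply constants for the individual databases.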
  Array<string> constFilesArr;
  extractFileNames(ainMLNFiles, constFilesArr);
  assert(constFilesArr.size() >= 1);
  string inMLNFile = constFilesArr[0];
  constFilesArr.removeItem(0);

  // Training databases come either from -t or from the -l list file.
  Array<string> dbFilesArr;
  if (NULL != dbFiles)
  {
    extractFileNames(dbFiles, dbFilesArr);
  }
  else
  {
    loadArray(ainDBListFile, dbFilesArr);
  }

  if (dbFilesArr.size() <= 0)
  {
    cout << "ERROR: must specify training data with the -t or -l option."
         << endl;
    return -1;
  }

  if (multipleDatabases)
  {
    // Per-database constant files, if given, must match the .db files
    // one-to-one.
    if (constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size())
    {
      cout << "ERROR: when there are multiple databases, if .mln files "
           << "containing constants are specified, there must "
           << "be the same number of them as .db files" << endl;
      return -1;
    }
  }

  StringHashArray nonEvidPredNames;
  if (nonEvidPredsStr)
  {
    if (!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
    {
      cout << "ERROR: failed to extract non-evidence predicate names." << endl;
      return -1;
    }
  }

  StringHashArray owPredNames;
  StringHashArray cwPredNames;

  cout << "Parsing MLN and creating domains..." << endl;
  StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
  Array<Domain*> domains;
  Array<MLN*> mlns;
  begSec = timer.time();
  bool allPredsExceptQueriesAreCW = true;
  if (discLearn)
  {
    // Extract the open-world evidence predicates (-ow).
    if (aOpenWorldPredsStr)
    {
      if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
        return -1;
      assert(owPredNames.size() > 0);
    }

    // Extract the closed-world non-evidence predicates (-cw).
    if (aClosedWorldPredsStr)
    {
      if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
        return -1;
      assert(cwPredNames.size() > 0);
      if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
        return -1;
    }

    allPredsExceptQueriesAreCW = false;
  }

  createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile,
                       constFilesArr, dbFilesArr, nePredNames,
                       !noAddUnitClauses, priorMean, true,
                       allPredsExceptQueriesAreCW, &owPredNames, &cwPredNames);
  cout << "Parsing MLN and creating domains took ";
  Timer::printTime(cout, timer.time() - begSec); cout << endl;

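  // An index translator is needed when the CNF of an existentially
  // quantified formula expands into several clauses that must share one
  // tied weight; it maps clause indices to weight indices.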
  IndexTranslator* indexTrans
    = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
      new IndexTranslator(&mlns, &domains) : NULL;

  if (indexTrans)
    cout << endl << "the weights of clauses in the CNFs of existential"
         << " formulas will be tied" << endl;

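  // Gaussian priors on the weights: one mean per clause/formula (taken from
  // the weights in the input .mln) and a common standard deviation.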
  Array<double> priorMeans, priorStdDevs;
  if (!noPrior)
  {
    if (indexTrans)
    {
      indexTrans->setPriorMeans(priorMeans);
      priorStdDevs.growToSize(priorMeans.size());
      for (int i = 0; i < priorMeans.size(); i++)
        priorStdDevs[i] = priorStdDev;
    }
    else
    {
      const ClauseHashArray* clauses = mlns[0]->getClauses();
      int numClauses = clauses->size();
      for (int i = 0; i < numClauses; i++)
      {
        priorMeans.append((*clauses)[i]->getWt());
        priorStdDevs.append(priorStdDev);
      }
    }
  }

  // Number of weights to learn: one per clause, or one per clause/
  // existential formula when weights are tied.
  int numClausesFormulas;
  if (indexTrans)
    numClausesFormulas = indexTrans->getNumClausesAndExistFormulas();
  else
    numClausesFormulas = mlns[0]->getClauses()->size();

  // Learned weights; slot 0 is unused so that weights can be addressed with
  // 1-based indices.
  Array<double> wts;

  if (discLearn)
  {
    wts.growToSize(numClausesFormulas + 1);
    // wwts points at wts[1], the first real weight.
    double* wwts = (double*) wts.getItems();
    wwts++;

    string nePredsStr = nonEvidPredsStr;

    // Set SampleSat parameters.
    SampleSatParams* ssparams = new SampleSatParams;
    ssparams->lateSa = assLateSa;
    ssparams->saRatio = assSaRatio;
    ssparams->saTemp = assSaTemp;

    // Set MaxWalkSat parameters.
    MaxWalksatParams* mwsparams = new MaxWalksatParams;
    mwsparams->ssParams = ssparams;
    mwsparams->maxSteps = amwsMaxSteps;
    mwsparams->maxTries = amwsTries;
    mwsparams->targetCost = amwsTargetWt;
    mwsparams->hard = amwsHard;
    mwsparams->numSolutions = amwsNumSolutions;
    mwsparams->heuristic = amwsHeuristic;
    mwsparams->tabuLength = amwsTabuLength;
    mwsparams->lazyLowState = amwsLazyLowState;

    // Set MC-SAT parameters; MC-SAT always uses one chain.
    MCSatParams* msparams = new MCSatParams;
    msparams->mwsParams = mwsparams;
    msparams->numChains = 1;
    msparams->burnMinSteps = amcmcBurnMinSteps;
    msparams->burnMaxSteps = amcmcBurnMaxSteps;
    msparams->minSteps = amcmcMinSteps;
    msparams->maxSteps = amcmcMaxSteps;
    msparams->maxSeconds = amcmcMaxSeconds;

    // Set Gibbs sampling parameters.
    GibbsParams* gibbsparams = new GibbsParams;
    gibbsparams->mwsParams = mwsparams;
    gibbsparams->numChains = amcmcNumChains;
    gibbsparams->burnMinSteps = amcmcBurnMinSteps;
    gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
    gibbsparams->minSteps = amcmcMinSteps;
    gibbsparams->maxSteps = amcmcMaxSteps;
    gibbsparams->maxSeconds = amcmcMaxSeconds;
    gibbsparams->gamma = 1 - agibbsDelta;
    gibbsparams->epsilonError = agibbsEpsilonError;
    gibbsparams->fracConverged = agibbsFracConverged;
    gibbsparams->walksatType = agibbsWalksatType;
    gibbsparams->samplesPerTest = agibbsSamplesPerTest;

    // Set simulated tempering parameters.
    SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
    stparams->mwsParams = mwsparams;
    stparams->numChains = amcmcNumChains;
    stparams->burnMinSteps = amcmcBurnMinSteps;
    stparams->burnMaxSteps = amcmcBurnMaxSteps;
    stparams->minSteps = amcmcMinSteps;
    stparams->maxSteps = amcmcMaxSteps;
    stparams->maxSeconds = amcmcMaxSeconds;
    stparams->subInterval = asimtpSubInterval;
    stparams->numST = asimtpNumST;
    stparams->numSwap = asimtpNumSwap;

    // One variable state and one inference object per domain.
    Array<VariableState*> states;
    Array<Inference*> inferences;
    states.growToSize(domains.size());
    inferences.growToSize(domains.size());

    Array<int> allPredGndingsAreNonEvid;
    Array<Predicate*> ppreds;

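    // Lock the clauses that came with a non-zero weight in the input .mln,
    // then reset every clause weight to 1 before the states are built.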
    for (int j = 0; j < mlns[0]->getNumClauses(); j++)
    {
      Clause* c = (Clause*) mlns[0]->getClause(j);
      if (c->getWt() != 0)
        c->lock();
      c->setWt(1);
    }

    // Build a variable state and an inference procedure for each domain.
    for (int i = 0; i < domains.size(); i++)
    {
      Domain* domain = domains[i];
      MLN* mln = mlns[i];

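      // If the user did not fix the number of MCMC steps, choose it
      // adaptively: inversely proportional to the smallest number of
      // groundings of any clause in the domain, but at least 5.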
      if (amcmcMaxSteps <= 0)
      {
        int minSize = INT_MAX;
        for (int c = 0; c < mln->getNumClauses(); c++)
        {
          Clause* clause = (Clause*) mln->getClause(c);
          double size = clause->getNumGroundings(domain);
          if (size < minSize) minSize = (int) size;
        }
        int steps = 10000 / minSize;
        if (steps < 5) steps = 5;
        cout << "Setting number of MCMC steps to " << steps << endl;
        amcmcMaxSteps = steps;
        msparams->maxSteps = amcmcMaxSteps;
        gibbsparams->maxSteps = amcmcMaxSteps;
        stparams->maxSteps = amcmcMaxSteps;
      }

      // Lazy inference is off unless requested.
      if (!aLazy)
        domains[i]->getDB()->setLazyFlag(false);

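      // Non-evidence ground predicates with unknown truth values (unePreds)
      // vs. those whose values are known from the database (knePreds; the
      // original values are saved in knePredValues and restored later).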
      GroundPredicateHashArray* unePreds = NULL;
      GroundPredicateHashArray* knePreds = NULL;
      Array<TruthValue>* knePredValues = NULL;

      // Open-world evidence predicates are also treated as non-evidence.
      if (!allPredsExceptQueriesAreCW)
      {
        for (int i = 0; i < owPredNames.size(); i++)
        {
          nePredsStr.append(",");
          nePredsStr.append(owPredNames[i]);
          nonEvidPredNames.append(owPredNames[i]);
        }
      }

      Array<Predicate*> gpreds;
      Array<TruthValue> gpredValues;

      if (!aLazy)
      {
        unePreds = new GroundPredicateHashArray;
        knePreds = new GroundPredicateHashArray;
        knePredValues = new Array<TruthValue>;

        allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);
        // Fill the non-evidence arrays from the command-line predicates.
        createComLineQueryPreds(nePredsStr, domain, domain->getDB(),
                                unePreds, knePreds,
                                &allPredGndingsAreNonEvid, NULL);

        // Save the known values, then set them to UNKNOWN for inference.
        knePredValues->growToSize(knePreds->size(), FALSE);
        for (int predno = 0; predno < knePreds->size(); predno++)
        {
          // A block containing a true non-evidence atom loses its
          // evidence status.
          int blockIdx = domain->getBlock((*knePreds)[predno]);
          if (blockIdx > -1 &&
              domain->getDB()->getValue((*knePreds)[predno]) == TRUE)
          {
            domain->setBlockEvidence(blockIdx, false);
          }

          (*knePredValues)[predno] =
            domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);
        }

        // With -queryEvidence, unknown query groundings become false
        // evidence.
        if (isQueryEvidence)
        {
          for (int predno = 0; predno < unePreds->size(); predno++)
          {
            domain->getDB()->setValue((*unePreds)[predno], FALSE);
            delete (*unePreds)[predno];
          }
          unePreds->clear();
        }
      }
      else
      {
        Array<Predicate*> ppreds;

        domain->getDB()->setPerformingInference(false);

        // For lazy inference, ground all non-evidence predicates and set
        // their values to unknown.
        gpreds.clear();
        gpredValues.clear();
        for (int predno = 0; predno < nonEvidPredNames.size(); predno++)
        {
          ppreds.clear();
          int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
          Predicate::createAllGroundings(predid, domain, ppreds);
          gpreds.append(ppreds);
        }

        domain->getDB()->setValuesToUnknown(&gpreds, &gpredValues);
      }


      cout << endl << "constructing state for domain " << i << "..." << endl;
      bool markHardGndClauses = false;
      bool trackParentClauseWts = true;

      VariableState*& state = states[i];
      state = new VariableState(unePreds, knePreds, knePredValues,
                                &allPredGndingsAreNonEvid, markHardGndClauses,
                                trackParentClauseWts, mln, domain, aLazy);

      Inference*& inference = inferences[i];
      bool trackClauseTrueCnts = true;

      if (amapPos || amapAll)
      {
        // MAP inference via MaxWalkSat; only one solution is needed.
        mwsparams->numSolutions = 1;
        inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
                                   mwsparams);
      }
      else if (amcsatInfer)
      {
        inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
      }
      else if (asimtpInfer)
      {
        // Only one MaxWalkSat solution is needed here.
        mwsparams->numSolutions = 1;
        inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
                                           stparams);
      }
      else if (agibbsInfer)
      {
        // Only one MaxWalkSat solution is needed here.
        mwsparams->numSolutions = 1;
        inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
                                     gibbsparams);
      }

      if (!aLazy)
      {
        // Restore the saved known values, and set the unknown non-evidence
        // predicates to false in the database.
        domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
        for (int predno = 0; predno < unePreds->size(); predno++)
        {
          domain->getDB()->setValue((*unePreds)[predno], FALSE);
        }
      }
      else
      {
        domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);

        for (int predno = 0; predno < gpreds.size(); predno++)
        {
          delete gpreds[predno];
        }

        domain->getDB()->setPerformingInference(true);
      }
    }
    cout << endl << "done constructing variable states" << endl << endl;

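    // Select the discriminative method: voted perceptron, diagonal Newton,
    // or (the default) rescaled conjugate gradient.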
    if (useVP)
      discMethod = DiscriminativeLearner::SIMPLE;
    else if (useNewton)
      discMethod = DiscriminativeLearner::DN;
    else
      discMethod = DiscriminativeLearner::CG;

    DiscriminativeLearner dl(inferences, nonEvidPredNames, indexTrans, aLazy,
                             withEM, !noUsePerWeight, discMethod, cg_lambda,
                             !cg_noprecond, cg_max_lambda);

    if (!noPrior)
      dl.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                         priorStdDevs.getItems());
    else
      dl.setMeansStdDevs(-1, NULL, NULL);

    begSec = timer.time();
    cout << "learning (discriminative) weights .. " << endl;
    double maxTime = maxSec + 60*maxMin + 3600*maxHour;
    dl.learnWeights(wwts, wts.size() - 1, numIter, maxTime, learningRate,
                    momentum, initWithLogOdds, amwsMaxSubsequentSteps,
                    aPeriodicMLNs);
    cout << endl << endl << "Done learning discriminative weights. " << endl;
    cout << "Time taken for learning = ";
    Timer::printTime(cout, (timer.time() - begSec)); cout << endl;

    if (mwsparams) delete mwsparams;
    if (ssparams) delete ssparams;
    if (msparams) delete msparams;
    if (gibbsparams) delete gibbsparams;
    if (stparams) delete stparams;
    for (int i = 0; i < inferences.size(); i++) delete inferences[i];
    for (int i = 0; i < states.size(); i++) delete states[i];
  }
  else
  {

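    // Generative learning: maximize the (weighted) pseudo-log-likelihood
    // with L-BFGS-B. First decide which predicates enter the computation.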
    Array<bool> areNonEvidPreds;
    if (nonEvidPredNames.empty())
    {
      // No -ne given: include all predicates except the equality and
      // internal predicates.
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
      for (int i = 0; i < domains[0]->getNumPredicates(); i++)
      {
        // Exclude the equality predicate.
        if (domains[0]->getPredicateTemplate(i)->isEqualPred())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }

        // Exclude internal predicates.
        if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }
      }
    }
    else
    {
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
      for (int i = 0; i < nonEvidPredNames.size(); i++)
      {
        int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
        if (predId < 0)
        {
          cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined."
               << endl;
          exit(-1);
        }
        areNonEvidPreds[predId] = true;
      }
    }

    Array<bool>* nePreds = &areNonEvidPreds;
    PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,
                            -1, -1, -1);
    pll.setIndexTranslator(indexTrans);

    if (!noPrior)
      pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                          priorStdDevs.getItems());
    else
      pll.setMeansStdDevs(-1, NULL, NULL);

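    // Precompute, for each domain, the clause counts used by the
    // pseudo-log-likelihood objective.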
    begSec = timer.time();
    for (int m = 0; m < mlns.size(); m++)
    {
      cout << "Computing counts for clauses in domain " << m << "..." << endl;
      const ClauseHashArray* clauses = mlns[m]->getClauses();
      for (int i = 0; i < clauses->size(); i++)
      {
        if (PRINT_CLAUSE_DURING_COUNT)
        {
          cout << "clause " << i << ": ";
          (*clauses)[i]->printWithoutWt(cout, domains[m]);
          cout << endl; cout.flush();
        }
        MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
        pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m,
                                              NULL, false, NULL);
      }
    }
    pll.compress();
    cout << "Computing counts took ";
    Timer::printTime(cout, timer.time() - begSec); cout << endl;

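    // Start all weights at 0. Slot 0 of wts is unused: the weight vector is
    // addressed with 1-based indices (see the wts[i+1] initialization).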
    wts.growToSize(numClausesFormulas + 1);
    for (int i = 0; i < numClausesFormulas; i++) wts[i+1] = 0;

    cout << "L-BFGS-B is finding optimal weights......" << endl;
    begSec = timer.time();
    LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
    int iter;
    bool error;
    double pllValue = lbfgsb.minimize((double*)wts.getItems(), iter, error);

    if (error) cout << "LBFGSB returned with an error!" << endl;
    cout << "num iterations = " << iter << endl;
    cout << "time taken = ";
    Timer::printTime(cout, timer.time() - begSec);
    cout << endl;
    cout << "pseudo-log-likelihood = " << -pllValue << endl;
  }

  // Write the learned weights into the MLN and output it.
  if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
  else assignWtsAndOutputMLN(out, mlns, domains, wts);

  out.close();

  // Clean up.
  deleteDomains(domains);

  for (int i = 0; i < mlns.size(); i++)
  {
    if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
    {
      // Shared data structures are owned by mlns[0]; null them out here so
      // they are not freed twice.
      mlns[i]->setClauses(NULL);
      mlns[i]->setMLNClauseInfos(NULL);
      mlns[i]->setPredIdToClausesMap(NULL);
      mlns[i]->setFormulaAndClausesArray(NULL);
      mlns[i]->setExternalClause(NULL);
    }
    delete mlns[i];
  }

  PowerSet::deletePowerSet();
  if (indexTrans) delete indexTrans;

  cout << "Total time = ";
  Timer::printTime(cout, timer.time() - startSec); cout << endl;
}