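// learnwts: weight learning for Markov logic networks (MLNs).
// Learns clause weights either generatively (L-BFGS-B optimization of the
// pseudo-log-likelihood) or discriminatively (voted perceptron, with expected
// counts estimated by MAP, MC-SAT, Gibbs, or simulated tempering inference).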
#include <fstream>
#include <iostream>
#include <sstream>
#include "arguments.h"
#include "inferenceargs.h"
#include "lbfgsb.h"
#include "votedperceptron.h"
#include "learnwts.h"
#include "maxwalksat.h"
#include "mcsat.h"
#include "gibbssampler.h"
#include "simulatedtempering.h"

// Whether to print each clause while its counts are being computed.
bool PRINT_CLAUSE_DURING_COUNT = true;

// Default standard deviations of the Gaussian priors on clause weights.
const double DISC_DEFAULT_STD_DEV = 1;
const double GEN_DEFAULT_STD_DEV = 100;

// Variables set from the command line.
bool discLearn = false;
bool genLearn = false;
char* outMLNFile = NULL;
char* dbFiles = NULL;
char* nonEvidPredsStr = NULL;
bool noAddUnitClauses = false;
bool multipleDatabases = false;
bool initToZero = false;
bool isQueryEvidence = false;

bool noPrior = false;
double priorMean = 0;
double priorStdDev = -1;

// Generative learning parameters (L-BFGS-B).
int maxIter = 10000;
double convThresh = 1e-5;
bool noEqualPredWt = false;

// Discriminative learning parameters (voted perceptron).
int numIter = 200;
double learningRate = 0.001;
double momentum = 0.0;
bool rescaleGradient = false;
bool withEM = false;

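// Command-line options. Each ARGS entry gives the flag name, whether it is
// required / optional / a toggle, the variable it sets, and its help text.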
ARGS ARGS::Args[] =
{
  ARGS("i", ARGS::Req, ainMLNFiles,
       "Comma-separated input .mln files. (With the -multipleDatabases "
       "option, the second file to the last one are used to contain constants "
       "from different databases, and they correspond to the .db files "
       "specified with the -t option.)"),

  ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
       "Specified non-evidence atoms (comma-separated with no space) are "
       "closed world; otherwise, all non-evidence atoms are open world. Atoms "
       "appearing here cannot be query atoms and cannot appear in the -ow "
       "option."),

  ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
       "Specified evidence atoms (comma-separated with no space) are open "
       "world, while other evidence atoms are closed world. "
       "Atoms appearing here cannot appear in the -cw option."),

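  // Inference-mode flags: which algorithm estimates counts during
  // discriminative learning. If none is given, MAP inference (-m) is used.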
  ARGS("m", ARGS::Tog, amapPos,
       "Run MAP inference and return only positive query atoms."),

  ARGS("a", ARGS::Tog, amapAll,
       "Run MAP inference and show 0/1 results for all query atoms."),

  ARGS("p", ARGS::Tog, agibbsInfer,
       "Run inference using MCMC (Gibbs sampling) and return probabilities "
       "for all query atoms."),

  ARGS("ms", ARGS::Tog, amcsatInfer,
       "Run inference using MC-SAT and return probabilities "
       "for all query atoms."),

  ARGS("simtp", ARGS::Tog, asimtpInfer,
       "Run inference using simulated tempering and return probabilities "
       "for all query atoms."),

  ARGS("seed", ARGS::Opt, aSeed,
       "[random] Seed used to initialize the randomizer in the inference "
       "algorithm. If not set, the seed is initialized from the current date "
       "and time."),

  ARGS("lazy", ARGS::Opt, aLazy,
       "[false] Run the lazy version of inference if this flag is set."),

  ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox,
       "[false] The lazy version of inference will not approximate by "
       "deactivating atoms to save memory. This flag is ignored if -lazy is "
       "not set."),

  ARGS("memLimit", ARGS::Opt, aMemLimit,
       "[-1] Maximum limit in kbytes which should be used for inference. "
       "-1 means all the main memory available on the system may be used."),

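  // MaxWalkSat parameters (MAP inference and SampleSat).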
  ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
       "[1000000] (MaxWalkSat) The maximum number of steps taken."),

  ARGS("mwsTries", ARGS::Opt, amwsTries,
       "[1] (MaxWalkSat) The maximum number of attempts to find a solution."),

  ARGS("mwsTargetWt", ARGS::Opt, amwsTargetWt,
       "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
       "with weight <= specified weight."),

  ARGS("mwsHard", ARGS::Opt, amwsHard,
       "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order "
       "to satisfy a soft one."),

  ARGS("mwsHeuristic", ARGS::Opt, amwsHeuristic,
       "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
       "2 = TABU, 3 = SAMPLESAT)."),

  ARGS("mwsTabuLength", ARGS::Opt, amwsTabuLength,
       "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
       "atom when using the tabu heuristic in MaxWalkSat."),

  ARGS("mwsLazyLowState", ARGS::Opt, amwsLazyLowState,
       "[false] (MaxWalkSat) If false, the naive way of saving low states "
       "(each time a low state is found, the whole state is saved) is used; "
       "otherwise, a list of variables flipped since the last low state is "
       "kept and the low state is reconstructed. This can be much faster for "
       "very large data sets."),

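  // Parameters shared by all MCMC inference algorithms.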
  ARGS("mcmcBurnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
       "[100] (MCMC) Minimum number of burn-in steps (-1: no minimum)."),

  ARGS("mcmcBurnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
       "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),

  ARGS("mcmcMinSteps", ARGS::Opt, amcmcMinSteps,
       "[-1] (MCMC) Minimum number of Gibbs sampling steps."),

  ARGS("mcmcMaxSteps", ARGS::Opt, amcmcMaxSteps,
       "[1000] (MCMC) Maximum number of Gibbs sampling steps."),

  ARGS("mcmcMaxSeconds", ARGS::Opt, amcmcMaxSeconds,
       "[-1] (MCMC) Maximum number of seconds to run MCMC (-1: no maximum)."),

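  // Simulated tempering parameters.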
  ARGS("simtpSubInterval", ARGS::Opt, asimtpSubInterval,
       "[2] (Simulated Tempering) Selection interval between swap attempts."),

  ARGS("simtpNumRuns", ARGS::Opt, asimtpNumST,
       "[3] (Simulated Tempering) Number of simulated tempering runs."),

  ARGS("simtpNumSwap", ARGS::Opt, asimtpNumSwap,
       "[10] (Simulated Tempering) Number of swapping chains."),

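  // MC-SAT parameters.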
  ARGS("mcsatNumStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
       "[1] (MC-SAT) Number of total steps (MC-SAT & Gibbs) for every MC-SAT "
       "step."),

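  // SampleSat parameters (used within MC-SAT).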
  ARGS("mcsatNumSolutions", ARGS::Opt, amwsNumSolutions,
       "[10] Return nth SAT solution in SampleSat."),

  ARGS("mcsatSaRatio", ARGS::Opt, assSaRatio,
       "[50] (MC-SAT) Ratio of simulated annealing steps mixed with WalkSAT "
       "in MC-SAT."),

  ARGS("mcsatSaTemperature", ARGS::Opt, assSaTemp,
       "[10] Temperature (/100) for the simulated annealing step in "
       "SampleSat."),

  ARGS("mcsatLateSa", ARGS::Tog, assLateSa,
       "[false] Run simulated annealing from the start in SampleSat."),

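  // Gibbs sampling parameters.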
  ARGS("gibbsNumChains", ARGS::Opt, amcmcNumChains,
       "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
       "at least 2)."),

  ARGS("gibbsDelta", ARGS::Opt, agibbsDelta,
       "[0.05] (Gibbs) During Gibbs sampling, the probability that the "
       "epsilon error is exceeded is less than this value."),

  ARGS("gibbsEpsilonError", ARGS::Opt, agibbsEpsilonError,
       "[0.01] (Gibbs) Fractional error from the true probability."),

  ARGS("gibbsFracConverged", ARGS::Opt, agibbsFracConverged,
       "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
       "have converged."),

  ARGS("gibbsWalksatType", ARGS::Opt, agibbsWalksatType,
       "[1] (Gibbs) Use MaxWalkSat to initialize ground atoms' truth values "
       "in Gibbs sampling (1: use MaxWalkSat, 0: random initialization)."),

  ARGS("gibbsSamplesPerTest", ARGS::Opt, agibbsSamplesPerTest,
       "[100] Perform the convergence test once after this many samples per "
       "chain."),

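  // Weight learning arguments.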
00280 ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),
00281
00282 ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),
00283
00284 ARGS("o", ARGS::Req, outMLNFile,
00285 "Output .mln file containing formulas with learned weights."),
00286
00287 ARGS("t", ARGS::Req, dbFiles,
00288 "Comma-separated .db files containing the training database "
00289 "(of true/false ground atoms), including function definitions, "
00290 "e.g. ai.db,graphics.db,languages.db."),
00291
00292 ARGS("ne", ARGS::Opt, nonEvidPredsStr,
00293 "First-order non-evidence predicates (comma-separated with no space), "
00294 "e.g., cancer,smokes,friends. For discriminative learning, at least "
00295 "one non-evidence predicate must be specified. For generative learning, "
00296 "the specified predicates are included in the (weighted) pseudo-log-"
00297 "likelihood computation; if none are specified, all are included."),
00298
00299 ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
00300 "If specified, unit clauses are not included in the .mln file; "
00301 "otherwise they are included."),
00302
00303 ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
00304 "If specified, each .db file belongs to a separate database; "
00305 "otherwise all .db files belong to the same database."),
00306
00307 ARGS("withEM", ARGS::Tog, withEM,
00308 "If set, EM is used to fill in missing truth values; "
00309 "otherwise missing truth values are set to false."),
00310
00311 ARGS("dNumIter", ARGS::Opt, numIter,
00312 "[200] (For discriminative learning only.) "
00313 "Number of iterations to run voted perceptron."),
00314
00315 ARGS("dLearningRate", ARGS::Opt, learningRate,
00316 "[0.001] (For discriminative learning only) "
00317 "Learning rate for the gradient descent in voted perceptron algorithm."),
00318
00319 ARGS("dMomentum", ARGS::Opt, momentum,
00320 "[0.0] (For discriminative learning only) "
00321 "Momentum term for the gradient descent in voted perceptron algorithm."),
00322
00323 ARGS("queryEvidence", ARGS::Tog, isQueryEvidence,
00324 "If this flag is set, then all the groundings of query preds not in db "
00325 "are assumed false evidence."),
00326
00327 ARGS("dRescale", ARGS::Tog, rescaleGradient,
00328 "(For discriminative learning only.) "
00329 "Rescale the gradient by the number of true groundings per weight."),
00330
00331 ARGS("dZeroInit", ARGS::Tog, initToZero,
00332 "(For discriminative learning only.) "
00333 "Initialize clause weights to zero instead of their log odds."),
00334
00335 ARGS("gMaxIter", ARGS::Opt, maxIter,
00336 "[10000] (For generative learning only.) "
00337 "Max number of iterations to run L-BFGS-B, "
00338 "the optimization algorithm for generative learning."),
00339
00340 ARGS("gConvThresh", ARGS::Opt, convThresh,
00341 "[1e-5] (For generative learning only.) "
00342 "Fractional change in pseudo-log-likelihood at which "
00343 "L-BFGS-B terminates."),
00344
00345 ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt,
00346 "(For generative learning only.) "
00347 "If specified, the predicates are not weighted equally in the "
00348 "pseudo-log-likelihood computation; otherwise they are."),
00349
00350 ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),
00351
00352 ARGS("priorMean", ARGS::Opt, priorMean,
00353 "[0] Means of Gaussian priors on formula weights. By default, "
00354 "for each formula, it is the weight given in the .mln input file, "
00355 "or fraction thereof if the formula turns into multiple clauses. "
00356 "This mean applies if no weight is given in the .mln file."),
00357
00358 ARGS("priorStdDev", ARGS::Opt, priorStdDev,
00359 "[1 for discriminative learning. 100 for generative learning] "
00360 "Standard deviations of Gaussian priors on clause weights."),
00361
00362 ARGS()
00363 };


int main(int argc, char* argv[])
{
  ARGS::parse(argc, argv, &cout);

  if (!discLearn && !genLearn)
  {
    cout << "ERROR: must specify either -d or -g "
         << "(discriminative or generative learning)" << endl;
    return -1;
  }

  Timer timer;
  double startSec = timer.time();
  double begSec;

  if (priorStdDev < 0)
  {
    if (discLearn)
    {
      cout << "priorStdDev set to (discriminative learning's) default of "
           << DISC_DEFAULT_STD_DEV << endl;
      priorStdDev = DISC_DEFAULT_STD_DEV;
    }
    else
    {
      cout << "priorStdDev set to (generative learning's) default of "
           << GEN_DEFAULT_STD_DEV << endl;
      priorStdDev = GEN_DEFAULT_STD_DEV;
    }
  }

  if (discLearn && nonEvidPredsStr == NULL)
  {
    cout << "ERROR: you must specify non-evidence predicates for "
         << "discriminative learning" << endl;
    return -1;
  }

  if (maxIter <= 0) { cout << "maxIter must be > 0" << endl; return -1; }
  if (convThresh <= 0 || convThresh > 1)
  { cout << "convThresh must be > 0 and <= 1" << endl; return -1; }
  if (priorStdDev <= 0)
  { cout << "priorStdDev must be > 0" << endl; return -1; }

  if (amwsMaxSteps <= 0)
  { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }

  if (amwsTries <= 0)
  { cout << "ERROR: mwsTries must be positive" << endl; return -1; }

  if (aMemLimit <= 0 && aMemLimit != -1)
  { cout << "ERROR: memLimit must be positive (or -1)" << endl; return -1; }

  if (!discLearn && aLazy)
  {
    cout << "ERROR: lazy can only be used with discriminative learning"
         << endl;
    return -1;
  }

  if (!discLearn && withEM)
  {
    cout << "ERROR: EM can only be used with discriminative learning" << endl;
    return -1;
  }

  ofstream out(outMLNFile);
  if (!out.good())
  {
    cout << "ERROR: unable to open " << outMLNFile << endl;
    return -1;
  }

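  // If no inference method was selected for discriminative learning,
  // default to MAP inference (MaxWalkSat) returning positive atoms.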
  if (discLearn && !asimtpInfer && !amapPos && !amapAll && !agibbsInfer &&
      !amcsatInfer)
  {
    amapPos = true;
  }

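  // Extract the comma-separated lists of .mln and .db files. The first .mln
  // file is the input MLN; any remaining ones hold database constants.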
  Array<string> constFilesArr;
  Array<string> dbFilesArr;
  extractFileNames(ainMLNFiles, constFilesArr);
  assert(constFilesArr.size() >= 1);
  string inMLNFile = constFilesArr[0];
  constFilesArr.removeItem(0);
  extractFileNames(dbFiles, dbFilesArr);

  if (dbFilesArr.size() <= 0)
  {
    cout << "ERROR: must specify training data with -t option." << endl;
    return -1;
  }

  if (multipleDatabases)
  {
    if (constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size())
    {
      cout << "ERROR: when there are multiple databases, if .mln files "
           << "containing constants are specified, there must "
           << "be the same number of them as .db files" << endl;
      return -1;
    }
  }

  StringHashArray nonEvidPredNames;
  if (nonEvidPredsStr)
  {
    if (!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
    {
      cout << "ERROR: failed to extract non-evidence predicate names." << endl;
      return -1;
    }
  }

  StringHashArray owPredNames;
  StringHashArray cwPredNames;

  cout << "Parsing MLN and creating domains..." << endl;
  StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
  Array<Domain*> domains;
  Array<MLN*> mlns;
  begSec = timer.time();
  bool allPredsExceptQueriesAreCW = true;
  if (discLearn)
  {
    // For discriminative learning, the open-world and closed-world predicate
    // lists determine which evidence atoms stay open world.
    if (aOpenWorldPredsStr)
    {
      if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames))
        return -1;
      assert(owPredNames.size() > 0);
    }

    if (aClosedWorldPredsStr)
    {
      if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames))
        return -1;
      assert(cwPredNames.size() > 0);
      if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
        return -1;
    }

    allPredsExceptQueriesAreCW = owPredNames.empty();
  }

  createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile,
                       constFilesArr, dbFilesArr, nePredNames,
                       !noAddUnitClauses, priorMean, true,
                       allPredsExceptQueriesAreCW, &owPredNames);
  cout << "Parsing MLN and creating domains took ";
  Timer::printTime(cout, timer.time() - begSec); cout << endl;

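  // When formulas expand into several CNF clauses (e.g., existentially
  // quantified formulas), an index translator ties the weights of the
  // clauses derived from the same formula.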
  IndexTranslator* indexTrans
    = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
      new IndexTranslator(&mlns, &domains) : NULL;

  if (indexTrans)
    cout << endl << "the weights of clauses in the CNFs of existential "
         << "formulas will be tied" << endl;

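  // Set up the Gaussian prior means and standard deviations, one per clause
  // (or one per formula when an index translator is used).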
  Array<double> priorMeans, priorStdDevs;
  if (!noPrior)
  {
    if (indexTrans)
    {
      indexTrans->setPriorMeans(priorMeans);
      priorStdDevs.growToSize(priorMeans.size());
      for (int i = 0; i < priorMeans.size(); i++)
        priorStdDevs[i] = priorStdDev;
    }
    else
    {
      const ClauseHashArray* clauses = mlns[0]->getClauses();
      int numClauses = clauses->size();
      for (int i = 0; i < numClauses; i++)
      {
        priorMeans.append((*clauses)[i]->getWt());
        priorStdDevs.append(priorStdDev);
      }
    }
  }

  int numClausesFormulas = mlns[0]->getClauses()->size();

  Array<double> wts;

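  // Discriminative learning: voted perceptron, with expected counts
  // estimated by the chosen inference algorithm (MAP by default).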
  if (discLearn)
  {
    // Weights are indexed from 1 (wts[0] is unused), so skip the first slot.
    wts.growToSize(numClausesFormulas + 1);
    double* wwts = (double*) wts.getItems();
    wwts++;

    string nePredsStr = nonEvidPredsStr;

    // Set up the parameters of the inference algorithms.
    SampleSatParams* ssparams = new SampleSatParams;
    ssparams->lateSa = assLateSa;
    ssparams->saRatio = assSaRatio;
    ssparams->saTemp = assSaTemp;

    MaxWalksatParams* mwsparams = new MaxWalksatParams;
    mwsparams->ssParams = ssparams;
    mwsparams->maxSteps = amwsMaxSteps;
    mwsparams->maxTries = amwsTries;
    mwsparams->targetCost = amwsTargetWt;
    mwsparams->hard = amwsHard;
    mwsparams->numSolutions = amwsNumSolutions;
    mwsparams->heuristic = amwsHeuristic;
    mwsparams->tabuLength = amwsTabuLength;
    mwsparams->lazyLowState = amwsLazyLowState;

    MCSatParams* msparams = new MCSatParams;
    msparams->mwsParams = mwsparams;
    msparams->numChains = 1;
    msparams->burnMinSteps = amcmcBurnMinSteps;
    msparams->burnMaxSteps = amcmcBurnMaxSteps;
    msparams->minSteps = amcmcMinSteps;
    msparams->maxSteps = amcmcMaxSteps;
    msparams->maxSeconds = amcmcMaxSeconds;
    msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;

    GibbsParams* gibbsparams = new GibbsParams;
    gibbsparams->mwsParams = mwsparams;
    gibbsparams->numChains = amcmcNumChains;
    gibbsparams->burnMinSteps = amcmcBurnMinSteps;
    gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
    gibbsparams->minSteps = amcmcMinSteps;
    gibbsparams->maxSteps = amcmcMaxSteps;
    gibbsparams->maxSeconds = amcmcMaxSeconds;
    gibbsparams->gamma = 1 - agibbsDelta;
    gibbsparams->epsilonError = agibbsEpsilonError;
    gibbsparams->fracConverged = agibbsFracConverged;
    gibbsparams->walksatType = agibbsWalksatType;
    gibbsparams->samplesPerTest = agibbsSamplesPerTest;

    SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
    stparams->mwsParams = mwsparams;
    stparams->numChains = amcmcNumChains;
    stparams->burnMinSteps = amcmcBurnMinSteps;
    stparams->burnMaxSteps = amcmcBurnMaxSteps;
    stparams->minSteps = amcmcMinSteps;
    stparams->maxSteps = amcmcMaxSteps;
    stparams->maxSeconds = amcmcMaxSeconds;
    stparams->subInterval = asimtpSubInterval;
    stparams->numST = asimtpNumST;
    stparams->numSwap = asimtpNumSwap;

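    // One variable state and one inference object are built per domain.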
    Array<VariableState*> states;
    Array<Inference*> inferences;

    states.growToSize(domains.size());
    inferences.growToSize(domains.size());

    Array<int> allPredGndingsAreNonEvid;
    Array<Predicate*> ppreds;

    for (int i = 0; i < domains.size(); i++)
    {
      Domain* domain = domains[i];
      MLN* mln = mlns[i];

      if (!aLazy)
        domains[i]->getDB()->setLazyFlag(false);

      // Non-evidence predicates with unknown values.
      GroundPredicateHashArray* unePreds = NULL;

      // Non-evidence predicates with known values, and those values.
      GroundPredicateHashArray* knePreds = NULL;
      Array<TruthValue>* knePredValues = NULL;

      // Initialize all clause weights to 1.
      for (int j = 0; j < mln->getNumClauses(); j++)
        ((Clause*) mln->getClause(j))->setWt(1);

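      // Evidence predicates with any unknown groundings in the database are
      // promoted to non-evidence so that their values get inferred.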
      int numPreds = domain->getNumPredicates();
      for (int j = 0; j < numPreds; j++)
      {
        const char* predName = domain->getPredicateName(j);

        if (nonEvidPredNames.contains(predName) ||
            domain->getPredicateTemplate(j)->isEqualPredicateTemplate())
          continue;

        bool unknownPred = false;
        ppreds.clear();
        Predicate::createAllGroundings(j, domain, ppreds);
        for (int k = 0; k < ppreds.size(); k++)
        {
          TruthValue tv = domain->getDB()->getValue(ppreds[k]);
          if (tv == UNKNOWN)
            unknownPred = true;
          delete ppreds[k];
        }
        if (unknownPred)
        {
          nePredsStr.append(",");
          nePredsStr.append(predName);
          nonEvidPredNames.append(predName);
        }
      }

      if (!aLazy)
      {
        unePreds = new GroundPredicateHashArray;
        knePreds = new GroundPredicateHashArray;
        knePredValues = new Array<TruthValue>;

        allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);

        // Collect the ground non-evidence predicates, split into those with
        // unknown values (unePreds) and those with known values (knePreds).
        createComLineQueryPreds(nePredsStr, domain, domain->getDB(),
                                unePreds, knePreds,
                                &allPredGndingsAreNonEvid);

        // Record the known values, then set them to unknown so that they are
        // inferred during learning (setValue returns the previous value).
        knePredValues->growToSize(knePreds->size(), FALSE);
        for (int predno = 0; predno < knePreds->size(); predno++)
          (*knePredValues)[predno] =
            domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);

        // If query predicates are treated as evidence, their unknown
        // groundings are set to false.
        if (isQueryEvidence)
          for (int predno = 0; predno < unePreds->size(); predno++)
            domain->getDB()->setValue((*unePreds)[predno], FALSE);
      }

      cout << endl << "constructing state for domain " << i << "..." << endl;
      bool markHardGndClauses = false;
      bool trackParentClauseWts = true;
      VariableState*& state = states[i];
      state = new VariableState(unePreds, knePreds, knePredValues,
                                &allPredGndingsAreNonEvid, markHardGndClauses,
                                trackParentClauseWts, mln, domain, aLazy);

      Inference*& inference = inferences[i];
      bool trackClauseTrueCnts = true;

      if (amapPos || amapAll)
      {
        // MAP inference: only one solution is needed.
        mwsparams->numSolutions = 1;
        inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
                                   mwsparams);
      }
      else if (amcsatInfer)
      {
        inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
      }
      else if (asimtpInfer)
      {
        // MaxWalkSat is only used to initialize the state: one solution.
        mwsparams->numSolutions = 1;
        inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
                                           stparams);
      }
      else if (agibbsInfer)
      {
        // MaxWalkSat is only used to initialize the state: one solution.
        mwsparams->numSolutions = 1;
        inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
                                     gibbsparams);
      }

      if (!aLazy)
      {
        // Restore the known non-evidence predicates to their original values
        // and set the unknown ones to false.
        domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);

        for (int predno = 0; predno < unePreds->size(); predno++)
        {
          domain->getDB()->setValue((*unePreds)[predno], FALSE);
        }
      }
    }
    cout << endl << "done constructing variable states" << endl << endl;

    VotedPerceptron vp(inferences, nonEvidPredNames, indexTrans, aLazy,
                       rescaleGradient, withEM);
    if (!noPrior)
      vp.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                         priorStdDevs.getItems());
    else
      vp.setMeansStdDevs(-1, NULL, NULL);

    begSec = timer.time();
    cout << "learning (discriminative) weights..." << endl;
    vp.learnWeights(wwts, wts.size() - 1, numIter, learningRate, momentum,
                    !initToZero);
    cout << endl << endl << "Done learning discriminative weights." << endl;
    cout << "Time taken for learning = ";
    Timer::printTime(cout, (timer.time() - begSec)); cout << endl;

    if (mwsparams) delete mwsparams;
    if (ssparams) delete ssparams;
    if (msparams) delete msparams;
    if (gibbsparams) delete gibbsparams;
    if (stparams) delete stparams;
    for (int i = 0; i < inferences.size(); i++) delete inferences[i];
    for (int i = 0; i < states.size(); i++) delete states[i];
  }
  else
  {
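    // Generative learning: maximize the pseudo-log-likelihood with L-BFGS-B.
    // First determine which predicates take part in the computation.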
    Array<bool> areNonEvidPreds;
    if (nonEvidPredNames.empty())
    {
      // No non-evidence predicates were specified: include all predicates
      // except '=' and internal predicates.
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
      for (int i = 0; i < domains[0]->getNumPredicates(); i++)
      {
        if (domains[0]->getPredicateTemplate(i)->isEqualPred())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }

        if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
        {
          const char* pname = domains[0]->getPredicateTemplate(i)->getName();
          int predId = domains[0]->getPredicateId(pname);
          areNonEvidPreds[predId] = false;
        }
      }
    }
    else
    {
      areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
      for (int i = 0; i < nonEvidPredNames.size(); i++)
      {
        int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
        if (predId < 0)
        {
          cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined."
               << endl;
          exit(-1);
        }
        areNonEvidPreds[predId] = true;
      }
    }

    Array<bool>* nePreds = &areNonEvidPreds;
    PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,
                            -1, -1, -1);
    pll.setIndexTranslator(indexTrans);

    if (!noPrior)
      pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
                          priorStdDevs.getItems());
    else
      pll.setMeansStdDevs(-1, NULL, NULL);

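    // Compute per-clause counts in each domain; the pseudo-log-likelihood
    // is evaluated from these counts.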
    begSec = timer.time();
    for (int m = 0; m < mlns.size(); m++)
    {
      cout << "Computing counts for clauses in domain " << m << "..." << endl;
      const ClauseHashArray* clauses = mlns[m]->getClauses();
      for (int i = 0; i < clauses->size(); i++)
      {
        if (PRINT_CLAUSE_DURING_COUNT)
        {
          cout << "clause " << i << ": ";
          (*clauses)[i]->printWithoutWt(cout, domains[m]);
          cout << endl; cout.flush();
        }
        MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
        pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m,
                                              NULL, false, NULL);
      }
    }
    pll.compress();
    cout << "Computing counts took ";
    Timer::printTime(cout, timer.time() - begSec); cout << endl;

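    // Learn the weights. wts[0] is unused: the L-BFGS-B implementation
    // indexes weights starting from 1.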
    wts.growToSize(numClausesFormulas + 1);
    for (int i = 0; i < numClausesFormulas; i++) wts[i + 1] = 0;

    cout << "L-BFGS-B is finding optimal weights..." << endl;
    begSec = timer.time();
    LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
    int iter;
    bool error;
    double pllValue = lbfgsb.minimize((double*) wts.getItems(), iter, error);

    if (error) cout << "LBFGSB returned with an error!" << endl;
    cout << "num iterations = " << iter << endl;
    cout << "time taken = ";
    Timer::printTime(cout, timer.time() - begSec);
    cout << endl;
    cout << "pseudo-log-likelihood = " << -pllValue << endl;
  }

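  // Write the learned weights back into the MLN(s) and output the result.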
  if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
  else assignWtsAndOutputMLN(out, mlns, domains, wts);

  out.close();

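  // Clean up.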
  deleteDomains(domains);
  for (int i = 0; i < mlns.size(); i++) delete mlns[i];
  PowerSet::deletePowerSet();
  if (indexTrans) delete indexTrans;

  cout << "Total time = ";
  Timer::printTime(cout, timer.time() - startSec); cout << endl;
}