learnwts.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #include <fstream>
00068 #include <iostream>
00069 #include <sstream>
00070 #include "arguments.h"
00071 #include "inferenceargs.h"
00072 #include "lbfgsb.h"
00073 #include "discriminativelearner.h"
00074 #include "learnwts.h"
00075 #include "maxwalksat.h"
00076 #include "mcsat.h"
00077 #include "gibbssampler.h"
00078 #include "simulatedtempering.h"
00079 
00080   //set to false to disable printing of clauses when they are counted during 
00081   //generative learning
00082 bool PRINT_CLAUSE_DURING_COUNT = true;
00083 
00084 const double DISC_DEFAULT_STD_DEV = 2;
00085 const double GEN_DEFAULT_STD_DEV = 100;
00086 
00087   // Variables for holding inference command line args are in inferenceargs.h
00088 bool discLearn = false;
00089 bool genLearn = false;
00090 char* outMLNFile = NULL;
00091 char* dbFiles = NULL;
00092 char* nonEvidPredsStr = NULL;
00093 bool noAddUnitClauses = false;
00094 bool multipleDatabases = false;
00095 bool initWithLogOdds = false;
00096 bool isQueryEvidence = false;
00097 
00098 bool aPeriodicMLNs = false;
00099 
00100 bool noPrior = false;
00101 double priorMean = 0;
00102 double priorStdDev = -1;
00103 
00104   // Generative learning args
00105 int maxIter = 10000;
00106 double convThresh = 1e-5;
00107 bool noEqualPredWt = false;
00108 
00109   // Discriminative learning args
00110 int numIter = 100;
00111 double maxSec  = 0;
00112 double maxMin  = 0;
00113 double maxHour = 0;
00114 double learningRate = 0.001;
00115 double momentum = 0.0;
00116 bool withEM = false;
00117 char* aInferStr = NULL;
00118 bool noUsePerWeight = false;
00119 bool useNewton = false;
00120 bool useCG = false;
00121 bool useVP = false;
00122 int  discMethod = DiscriminativeLearner::CG;
00123 double cg_lambda = 100;
00124 double cg_max_lambda = DBL_MAX;
00125 bool   cg_noprecond = false;
00126 int amwsMaxSubsequentSteps = -1;
00127 char* ainDBListFile = NULL;
00128 
00129 
00130   // Inference arguments needed for disc. learning defined in inferenceargs.h
00131   // TODO: List the arguments common to learnwts and inference in
00132   // inferenceargs.h. This can't be done with a static array.
00133 ARGS ARGS::Args[] = 
00134 {
00135     // BEGIN: Common arguments
00136   ARGS("i", ARGS::Req, ainMLNFiles, 
00137        "Comma-separated input .mln files. (With the -multipleDatabases "
00138        "option, the second file to the last one are used to contain constants "
00139        "from different databases, and they correspond to the .db files "
00140        "specified with the -t option.)"),
00141 
00142   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00143        "Specified non-evidence atoms (comma-separated with no space) are "
00144        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00145        "appearing here cannot be query atoms and cannot appear in the -o "
00146        "option."),
00147 
00148   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00149        "Specified evidence atoms (comma-separated with no space) are open "
00150        "world, while other evidence atoms are closed-world. "
00151        "Atoms appearing here cannot appear in the -c option."),
00152     // END: Common arguments
00153 
00154     // BEGIN: Common inference arguments
00155   ARGS("m", ARGS::Tog, amapPos, 
00156        "(Embed in -infer argument) "
00157        "Run MAP inference and return only positive query atoms."),
00158 
00159   ARGS("a", ARGS::Tog, amapAll, 
00160        "(Embed in -infer argument) "
00161        "Run MAP inference and show 0/1 results for all query atoms."),
00162 
00163   ARGS("p", ARGS::Tog, agibbsInfer, 
00164        "(Embed in -infer argument) "
00165        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00166        "for all query atoms."),
00167   
00168   ARGS("ms", ARGS::Tog, amcsatInfer,
00169        "(Embed in -infer argument) "
00170        "Run inference using MC-SAT and return probabilities "
00171        "for all query atoms"),
00172 
00173   ARGS("simtp", ARGS::Tog, asimtpInfer,
00174        "(Embed in -infer argument) "
00175        "Run inference using simulated tempering and return probabilities "
00176        "for all query atoms"),
00177 
00178   ARGS("seed", ARGS::Opt, aSeed,
00179        "(Embed in -infer argument) "
00180        "[2350877] Seed used to initialize the randomizer in the inference "
00181        "algorithm. If not set, seed is initialized from a fixed random number."),
00182 
00183   ARGS("lazy", ARGS::Opt, aLazy, 
00184        "(Embed in -infer argument) "
00185        "[false] Run lazy version of inference if this flag is set."),
00186   
00187   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00188        "(Embed in -infer argument) "
00189        "[false] Lazy version of inference will not approximate by deactivating "
00190        "atoms to save memory. This flag is ignored if -lazy is not set."),
00191   
00192   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00193        "(Embed in -infer argument) "
00194        "[-1] Maximum limit in kbytes which should be used for inference. "
00195        "-1 means main memory available on system is used."),
00196     // END: Common inference arguments
00197 
00198     // BEGIN: MaxWalkSat args
00199   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00200        "(Embed in -infer argument) "
00201        "[100000] (MaxWalkSat) The max number of steps taken."),
00202 
00203   ARGS("tries", ARGS::Opt, amwsTries, 
00204        "(Embed in -infer argument) "
00205        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00206 
00207   ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00208        "(Embed in -infer argument) "
00209        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00210        "with weight <= specified weight."),
00211 
00212   ARGS("hard", ARGS::Opt, amwsHard, 
00213        "(Embed in -infer argument) "
00214        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00215        "satisfy a soft one."),
00216   
00217   ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00218        "(Embed in -infer argument) "
00219        "[2] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00220        "2 = TABU, 3 = SAMPLESAT)."),
00221   
00222   ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00223        "(Embed in -infer argument) "
00224        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00225        "atom when using the tabu heuristic in MaxWalkSat." ),
00226 
00227   ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState, 
00228        "(Embed in -infer argument) "
00229        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00230        "(each time a low state is found, the whole state is saved) is used; "
00231        "otherwise, a list of variables flipped since the last low state is "
00232        "kept and the low state is reconstructed. This can be much faster for "
00233        "very large data sets."),  
00234     // END: MaxWalkSat args
00235 
00236     // BEGIN: MCMC args
00237   ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00238        "(Embed in -infer argument) "
00239        "[0] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00240 
00241   ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00242        "(Embed in -infer argument) "
00243        "[0] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00244 
00245   ARGS("minSteps", ARGS::Opt, amcmcMinSteps, 
00246        "(Embed in -infer argument) "
00247        "[-1] (MCMC) Minimum number of MCMC sampling steps."),
00248 
00249   ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps, 
00250        "(Embed in -infer argument) "
00251        "[optimal] (MCMC) Maximum number of MCMC sampling steps."),
00252 
00253   ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00254        "(Embed in -infer argument) "
00255        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00256     // END: MCMC args
00257   
00258     // BEGIN: Simulated tempering args
00259   ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00260        "(Embed in -infer argument) "
00261         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00262 
00263   ARGS("numRuns", ARGS::Opt, asimtpNumST,
00264        "(Embed in -infer argument) "
00265         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00266 
00267   ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00268        "(Embed in -infer argument) "
00269         "[10] (Simulated Tempering) Number of swapping chains"),
00270     // END: Simulated tempering args
00271 
00272     // BEGIN: SampleSat args
00273   ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00274        "(Embed in -infer argument) "
00275        "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00276 
00277   ARGS("saRatio", ARGS::Opt, assSaRatio,
00278        "(Embed in -infer argument) "
00279        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00280        "MC-SAT"),
00281 
00282   ARGS("saTemperature", ARGS::Opt, assSaTemp,
00283        "(Embed in -infer argument) "
00284         "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00285         "SampleSat"),
00286 
00287   ARGS("lateSa", ARGS::Tog, assLateSa,
00288        "(Embed in -infer argument) "
00289        "[false] Run simulated annealing from the start in SampleSat"),
00290     // END: SampleSat args
00291 
00292     // BEGIN: Gibbs sampling args
00293   ARGS("numChains", ARGS::Opt, amcmcNumChains, 
00294        "(Embed in -infer argument) "
00295        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00296        "at least 2)."),
00297 
00298   ARGS("delta", ARGS::Opt, agibbsDelta,
00299        "(Embed in -infer argument) "
00300        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00301        "exceeded is less than this value."),
00302 
00303   ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00304        "(Embed in -infer argument) "
00305        "[0.01] (Gibbs) Fractional error from true probability."),
00306 
00307   ARGS("fracConverged", ARGS::Opt, agibbsFracConverged, 
00308        "(Embed in -infer argument) "
00309        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00310        "have converged."),
00311 
00312   ARGS("walksatType", ARGS::Opt, agibbsWalksatType, 
00313        "(Embed in -infer argument) "
00314        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00315        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00316 
00317   ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00318        "(Embed in -infer argument) "
00319        "[100] Perform convergence test once after this many number of samples "
00320        "per chain."),
00321     // END: Gibbs sampling args
00322 
00323     // BEGIN: Weight learning specific args
00324   ARGS("periodic", ARGS::Tog, aPeriodicMLNs,
00325        "Write out MLNs after 1, 2, 5, 10, 20, 50, etc. iterations"),
00326 
00327   ARGS("infer", ARGS::Opt, aInferStr,
00328        "Specified inference parameters when using discriminative learning. "
00329        "The arguments are to be encapsulated in \"\" and the syntax is "
00330        "identical to the infer command (run infer with no commands to see "
00331        "this). If not specified, 5 steps of MC-SAT with no burn-in is used."),
00332 
00333   ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),
00334 
00335   ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),
00336 
00337   ARGS("o", ARGS::Req, outMLNFile, 
00338        "Output .mln file containing formulas with learned weights."),
00339 
00340   ARGS("t", ARGS::Opt, dbFiles, 
00341        "Comma-separated .db files containing the training database "
00342        "(of true/false ground atoms), including function definitions, "
00343        "e.g. ai.db,graphics.db,languages.db."),
00344   
00345   ARGS("l", ARGS::Opt, ainDBListFile, 
00346        "list of database files used in learning"
00347        ", each line contains a pointer to a database file."),
00348     
00349   ARGS("ne", ARGS::Opt, nonEvidPredsStr, 
00350        "First-order non-evidence predicates (comma-separated with no space),  "
00351        "e.g., cancer,smokes,friends. For discriminative learning, at least "
00352        "one non-evidence predicate must be specified. For generative learning, "
00353        "the specified predicates are included in the (weighted) pseudo-log-"
00354        "likelihood computation; if none are specified, all are included."),
00355     
00356   ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
00357        "If specified, unit clauses are not included in the .mln file; "
00358        "otherwise they are included."),
00359 
00360   ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
00361        "If specified, each .db file belongs to a separate database; "
00362        "otherwise all .db files belong to the same database."),
00363 
00364   ARGS("withEM", ARGS::Tog, withEM,
00365        "If set, EM is used to fill in missing truth values; "
00366        "otherwise missing truth values are set to false."),
00367 
00368   ARGS("dNumIter", ARGS::Opt, numIter, 
00369        "[100] (For discriminative learning only.) "
00370        "Number of iterations to run discriminative learning method."),
00371 
00372   ARGS("dMaxSec", ARGS::Opt, maxSec,
00373        "[-1] Maximum number of seconds to spend learning"),
00374 
00375   ARGS("dMaxMin", ARGS::Opt, maxMin,
00376        "[-1] Maximum number of minutes to spend learning"),
00377 
00378   ARGS("dMaxHour", ARGS::Opt, maxHour,
00379        "[-1] Maximum number of hours to spend learning"),
00380   
00381   ARGS("dLearningRate", ARGS::Opt, learningRate, 
00382        "[0.001] (For discriminative learning only) "
00383        "Learning rate for the gradient descent in disc. learning algorithm."),
00384 
00385   ARGS("dMomentum", ARGS::Opt, momentum, 
00386        "[0.0] (For discriminative learning only) "
00387        "Momentum term for the gradient descent in disc. learning algorithm."),
00388        
00389   ARGS("dNoPW", ARGS::Tog, noUsePerWeight,
00390        "[false] (For voted perceptron only.) "
00391        "Do not use per-weight learning rates, based on the number of true "
00392        "groundings per weight."),
00393   
00394   ARGS("dVP", ARGS::Tog, useVP,
00395        "[false] (For discriminative learning only) "
00396        "Use voted perceptron to learn the weights."),
00397 
00398   ARGS("dNewton", ARGS::Tog, useNewton,
00399        "[false] (For discriminative learning only) "
00400        "Use diagonalized Newton's method to learn the weights."),
00401 
00402   ARGS("dCG", ARGS::Tog, useCG,
00403        "[false] (For discriminative learning only) "
00404        "Use rescaled conjugate gradient to learn the weights."),
00405 
00406   ARGS("cgLambda", ARGS::Opt, cg_lambda,
00407        "[100] (For CG only) (For CG only) Starting value of parameter to limit "
00408        "step size"),
00409 
00410   ARGS("cgMaxLambda", ARGS::Opt, cg_max_lambda,
00411        "[no limit] (For CG only) Maximum value of parameter to limit step size"),
00412 
00413   ARGS("cgNoPrecond", ARGS::Tog, cg_noprecond,
00414        "[false] (For CG only) precondition without the diagonal Hessian"),
00415        
00416   ARGS("queryEvidence", ARGS::Tog, isQueryEvidence, 
00417        "[false] If this flag is set, then all the groundings of query preds not "
00418        "in db are assumed false evidence."),
00419 
00420   ARGS("dInitWithLogOdds", ARGS::Tog, initWithLogOdds,
00421        "[false] (For discriminative learning only.) "
00422        "Initialize clause weights to their log odds instead of zero."),
00423 
00424   ARGS("dMwsMaxSubsequentSteps", ARGS::Opt, amwsMaxSubsequentSteps,
00425        "[Same as mwsMaxSteps] (For discriminative learning only.) The max "
00426        "number of MaxWalkSat steps taken in subsequent iterations (>= 2) of "
00427        "disc. learning. If not specified, mwsMaxSteps is used in each "
00428        "iteration"),
00429   
00430   ARGS("gMaxIter", ARGS::Opt, maxIter, 
00431        "[10000] (For generative learning only.) "
00432        "Max number of iterations to run L-BFGS-B, "
00433        "the optimization algorithm for generative learning."),
00434   
00435   ARGS("gConvThresh", ARGS::Opt, convThresh, 
00436        "[1e-5] (For generative learning only.) "
00437        "Fractional change in pseudo-log-likelihood at which "
00438        "L-BFGS-B terminates."),
00439 
00440   ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt, 
00441        "[false] (For generative learning only.) "
00442        "If specified, the predicates are not weighted equally in the "
00443        "pseudo-log-likelihood computation; otherwise they are."),
00444   
00445   ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),
00446 
00447   ARGS("priorMean", ARGS::Opt, priorMean, 
00448        "[0] Means of Gaussian priors on formula weights. By default, "
00449        "for each formula, it is the weight given in the .mln input file, " 
00450        "or fraction thereof if the formula turns into multiple clauses. "
00451        "This mean applies if no weight is given in the .mln file."),
00452 
00453   ARGS("priorStdDev", ARGS::Opt, priorStdDev, 
00454        "[2 for discriminative learning. 100 for generative learning] "
00455        "Standard deviations of Gaussian priors on clause weights."),
00456 
00457   ARGS()
00458 };
00459 
00460 //bool extractPredNames(...) defined in infer.h
00461 
00462 void loadArray(const char* file, Array<string>& array)
00463 {
00464   ifstream is(file);
00465   array.clear();
00466   string line;
00467   while (getline(is, line))
00468   {
00469     array.append(line);
00470   }
00471 }
00472 
00473 int main(int argc, char* argv[])
00474 {
00475   ARGS::parse(argc,argv,&cout);
00476 
00477   if (!discLearn && !genLearn) 
00478   { 
00479       // If nothing specified, then use disc. learning
00480     discLearn = true;
00481     
00482     //cout << "must specify either -d or -g "
00483     //     <<"(discriminative or generative learning) " << endl; 
00484     //return -1;
00485   }
00486 
00487   Timer timer;
00488   double startSec = timer.time();
00489   double begSec;
00490 
00491   if (priorStdDev < 0)
00492   {
00493     if (discLearn) 
00494     { 
00495       cout << "priorStdDev set to (discriminative learning's) default of " 
00496            << DISC_DEFAULT_STD_DEV << endl;
00497       priorStdDev = DISC_DEFAULT_STD_DEV;
00498     }
00499     else
00500     {
00501       cout << "priorStdDev set to (generative learning's) default of " 
00502            << GEN_DEFAULT_STD_DEV << endl;
00503       priorStdDev = GEN_DEFAULT_STD_DEV;      
00504     }
00505   }
00506 
00507 
00509   if (discLearn && nonEvidPredsStr == NULL)
00510   {
00511     cout << "ERROR: you must specify non-evidence predicates for "
00512          << "discriminative learning" << endl;
00513     return -1;
00514   }
00515 
00516   if (maxIter <= 0)  { cout << "maxIter must be > 0" << endl; return -1; }
00517   if (convThresh <= 0 || convThresh > 1)  
00518   { cout << "convThresh must be > 0 and <= 1" << endl; return -1;  }
00519   if (priorStdDev <= 0) { cout << "priorStdDev must be > 0" << endl; return -1;}
00520 
00521   if (amwsMaxSteps <= 0)
00522   { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00523   
00524     // If max. subsequent steps not specified, use amwsMaxSteps
00525   if (amwsMaxSubsequentSteps <= 0) amwsMaxSubsequentSteps = amwsMaxSteps;
00526 
00527   if (amwsTries <= 0)
00528   { cout << "ERROR: tries must be positive" << endl; return -1; }
00529 
00530   if (aMemLimit <= 0 && aMemLimit != -1)
00531   { cout << "ERROR: limit must be positive (or -1)" << endl; return -1; }
00532 
00533   if (!discLearn && aLazy)
00534   {
00535     cout << "ERROR: lazy can only be used with discriminative learning"
00536          << endl;
00537     return -1;
00538   }
00539 
00540   ofstream out(outMLNFile);
00541   if (!out.good())
00542   {
00543     cout << "ERROR: unable to open " << outMLNFile << endl;
00544     return -1;
00545   }
00546 
00547     // Parse the inference parameters, if given
00548   if (discLearn)
00549   {
00550       // If no method given, then use CG
00551     if (!useVP && !useCG && !useNewton)
00552       useCG = true;
00553       // Per-weight can not be used with SCG or Newton
00554     if ((useCG || useNewton) && !noUsePerWeight)
00555     {
00556       noUsePerWeight = true;
00557     }
00558 
00559       // maxSteps is optimized after domains are built
00560     amcmcMaxSteps = -1;
00561     amcmcBurnMaxSteps = -1;
00562     if (!aInferStr)
00563     {
00564         // Set defaults of inference inside disc. weight learning:
00565         // MC-SAT with no burn-in, 5 steps
00566       amcsatInfer = true;
00567     }
00568       // If inference method given, we need to parse the parameters
00569     else
00570     {
00571       int inferArgc = 0;
00572       char **inferArgv = new char*[200];
00573       for (int i = 0; i < 200; i++)
00574       {
00575         inferArgv[i] = new char[500];
00576       }
00577 
00578         // Have to add program name (which is not used) to start of string
00579       string inferString = "infer ";
00580       inferString.append(aInferStr);
00581       extractArgs(inferString.c_str(), inferArgc, inferArgv);
00582       cout << "extractArgs " << inferArgc << endl;
00583       for (int i = 0; i < inferArgc; i++)
00584       {
00585         cout << i << ": " << inferArgv[i] << endl;
00586       }
00587 
00588       ARGS::parseFromCommandLine(inferArgc, inferArgv);
00589 
00590         // Delete memory allocated for args
00591       for (int i = 0; i < 200; i++)
00592       {
00593         delete[] inferArgv[i];
00594       }
00595       delete[] inferArgv; 
00596     }
00597     
00598     if (!asimtpInfer && !amapPos && !amapAll && !agibbsInfer && !amcsatInfer)
00599     {
00600         // If nothing specified, use MC-SAT
00601       amcsatInfer = true;
00602     }    
00603   }
00604 
00605 
00606   //the second .mln file to the last one in ainMLNFiles _may_ be used 
00607   //to hold constants, so they are held in constFilesArr. They will be
00608   //included into the first .mln file.
00609 
00610     //extract .mln and .db, file names
00611   Array<string> constFilesArr;
00612   extractFileNames(ainMLNFiles, constFilesArr);
00613   assert(constFilesArr.size() >= 1);
00614   string inMLNFile = constFilesArr[0];
00615   constFilesArr.removeItem(0);
00616 
00617   Array<string> dbFilesArr;
00618   if (NULL != dbFiles)
00619   {
00620     extractFileNames(dbFiles, dbFilesArr);
00621   }
00622   else
00623   {
00624     loadArray(ainDBListFile, dbFilesArr);
00625   }
00626 
00627   if (dbFilesArr.size() <= 0)
00628   {cout<<"ERROR: must specify training data with -t option."<<endl; return -1;}
00629  
00630     // if multiple databases, check the number of .db/.func files
00631   if (multipleDatabases) 
00632   {
00633       //if # .mln files containing constants/.func files and .db files are diff
00634     if ((constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size()))
00635     {
00636       cout << "ERROR: when there are multiple databases, if .mln files "
00637            << "containing constants are specified, there must " 
00638            << "be the same number of them as .db files" << endl;
00639       return -1;
00640     }
00641   }
00642 
00643   StringHashArray nonEvidPredNames;
00644   if (nonEvidPredsStr)
00645   {
00646     if (!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
00647     {
00648       cout << "ERROR: failed to extract non-evidence predicate names." << endl;
00649       return -1;
00650     }
00651   }
00652 
00653   StringHashArray owPredNames;
00654   StringHashArray cwPredNames;
00655 
00657 
00658   cout << "Parsing MLN and creating domains..." << endl;
00659   StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
00660   Array<Domain*> domains;
00661   Array<MLN*> mlns;
00662   begSec = timer.time();
00663   bool allPredsExceptQueriesAreCW = true;
00664   if (discLearn)
00665   {
00666       //extract names of open-world evidence predicates
00667     if (aOpenWorldPredsStr)
00668     {
00669       if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
00670         return -1;
00671       assert(owPredNames.size() > 0);
00672     }
00673 
00674       //extract names of closed-world non-evidence predicates
00675     if (aClosedWorldPredsStr)
00676     {
00677       if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
00678         return -1;
00679       assert(cwPredNames.size() > 0);
00680       if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
00681         return -1;
00682     }
00683  
00684     //allPredsExceptQueriesAreCW = owPredNames.empty();
00685     allPredsExceptQueriesAreCW = false;
00686   }
00687     // Parse as if lazy inference is set to true to set evidence atoms in DB
00688     // If lazy is not used, this is removed from DB
00689   createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile, 
00690                        constFilesArr, dbFilesArr, nePredNames,
00691                        !noAddUnitClauses, priorMean, true,
00692                        allPredsExceptQueriesAreCW, &owPredNames, &cwPredNames);
00693   cout << "Parsing MLN and creating domains took "; 
00694   Timer::printTime(cout, timer.time() - begSec); cout << endl;
00695 
00696   /*
00697   cout << "Clause prior means:" << endl;
00698   cout << "_________________________________" << endl;
00699   mlns[0]->printClausePriorMeans(cout, domains[0]);
00700   cout << "_________________________________" << endl;
00701   cout << endl;
00702 
00703   cout << "Formula prior means:" << endl;
00704   cout << "_________________________________" << endl;
00705   mlns[0]->printFormulaPriorMeans(cout);
00706   cout << "_________________________________" << endl;
00707   cout << endl;
00708   */
00709 
00711 
00712     //we need an index translator if clauses do not line up across multiple DBs
00713   IndexTranslator* indexTrans 
00714     = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
00715        new IndexTranslator(&mlns, &domains) : NULL;  
00716 
00717   if (indexTrans) 
00718     cout << endl << "the weights of clauses in the CNFs of existential"
00719          << " formulas will be tied" << endl;
00720 
00721   Array<double> priorMeans, priorStdDevs;
00722   if (!noPrior)
00723   {
00724     if (indexTrans)
00725     {
00726       indexTrans->setPriorMeans(priorMeans);
00727       priorStdDevs.growToSize(priorMeans.size());
00728       for (int i = 0; i < priorMeans.size(); i++)
00729         priorStdDevs[i] = priorStdDev;
00730     }
00731     else
00732     {
00733       const ClauseHashArray* clauses = mlns[0]->getClauses();
00734       int numClauses = clauses->size();
00735       for (int i = 0; i < numClauses; i++)
00736       {
00737         priorMeans.append((*clauses)[i]->getWt());
00738         priorStdDevs.append(priorStdDev);
00739       }
00740     }
00741   }
00742 
00743   int numClausesFormulas;
00744   if (indexTrans)
00745       numClausesFormulas = indexTrans->getNumClausesAndExistFormulas();
00746   else
00747       numClausesFormulas = mlns[0]->getClauses()->size();
00748 
00749 
00751   Array<double> wts;
00752 
00753     // Discriminative learning
00754   if (discLearn) 
00755   {
00756     wts.growToSize(numClausesFormulas + 1);
00757     double* wwts = (double*) wts.getItems();
00758     wwts++;
00759       // Non-evid preds as a string
00760     string nePredsStr = nonEvidPredsStr;
00761 
00762       // Set SampleSat parameters
00763     SampleSatParams* ssparams = new SampleSatParams;
00764     ssparams->lateSa = assLateSa;
00765     ssparams->saRatio = assSaRatio;
00766     ssparams->saTemp = assSaTemp;
00767 
00768       // Set MaxWalksat parameters
00769     MaxWalksatParams* mwsparams = NULL;
00770     mwsparams = new MaxWalksatParams;
00771     mwsparams->ssParams = ssparams;
00772     mwsparams->maxSteps = amwsMaxSteps;
00773     mwsparams->maxTries = amwsTries;
00774     mwsparams->targetCost = amwsTargetWt;
00775     mwsparams->hard = amwsHard;
00776       // numSolutions only applies when used in SampleSat.
00777       // When just MWS, this is set to 1
00778     mwsparams->numSolutions = amwsNumSolutions;
00779     mwsparams->heuristic = amwsHeuristic;
00780     mwsparams->tabuLength = amwsTabuLength;
00781     mwsparams->lazyLowState = amwsLazyLowState;
00782 
00783       // Set MC-SAT parameters
00784     MCSatParams* msparams = new MCSatParams;
00785     msparams->mwsParams = mwsparams;
00786       // MC-SAT needs only one chain
00787     msparams->numChains          = 1;
00788     msparams->burnMinSteps       = amcmcBurnMinSteps;
00789     msparams->burnMaxSteps       = amcmcBurnMaxSteps;
00790     msparams->minSteps           = amcmcMinSteps;
00791     msparams->maxSteps           = amcmcMaxSteps;
00792     msparams->maxSeconds         = amcmcMaxSeconds;
00793 
00794       // Set Gibbs parameters
00795     GibbsParams* gibbsparams = new GibbsParams;
00796     gibbsparams->mwsParams    = mwsparams;
00797     gibbsparams->numChains    = amcmcNumChains;
00798     gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00799     gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00800     gibbsparams->minSteps     = amcmcMinSteps;
00801     gibbsparams->maxSteps     = amcmcMaxSteps;
00802     gibbsparams->maxSeconds   = amcmcMaxSeconds;
00803 
00804     gibbsparams->gamma          = 1 - agibbsDelta;
00805     gibbsparams->epsilonError   = agibbsEpsilonError;
00806     gibbsparams->fracConverged  = agibbsFracConverged;
00807     gibbsparams->walksatType    = agibbsWalksatType;
00808     gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00809   
00810       // Set Sim. Tempering parameters
00811     SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00812     stparams->mwsParams    = mwsparams;
00813     stparams->numChains    = amcmcNumChains;
00814     stparams->burnMinSteps = amcmcBurnMinSteps;
00815     stparams->burnMaxSteps = amcmcBurnMaxSteps;
00816     stparams->minSteps     = amcmcMinSteps;
00817     stparams->maxSteps     = amcmcMaxSteps;
00818     stparams->maxSeconds   = amcmcMaxSeconds;
00819 
00820     stparams->subInterval = asimtpSubInterval;
00821     stparams->numST       = asimtpNumST;
00822     stparams->numSwap     = asimtpNumSwap;
00823 
00824     Array<VariableState*> states;
00825     Array<Inference*> inferences;
00826 
00827     states.growToSize(domains.size());
00828     inferences.growToSize(domains.size());
00829 
00830       // Build the state for inference in each domain
00831     Array<int> allPredGndingsAreNonEvid;
00832     Array<Predicate*> ppreds;
00833     
00834       // Need to set some dummy weight (only in mln0 as clauses are shared)
00835 //      for (int j = 0; j < mln->getNumClauses(); j++)
00836 //        ((Clause*) mln->getClause(j))->setWt(1);
00837     for (int j = 0; j < mlns[0]->getNumClauses(); j++) 
00838     {
00839       Clause* c = (Clause*) mlns[0]->getClause(j);
00840         // If the weight was set to non-zero in the source MLN,
00841         // don't modify it while learning.
00842       if (c->getWt() != 0)
00843         c->lock();
00844       c->setWt(1);
00845     }
00846 
00847     for (int i = 0; i < domains.size(); i++)
00848     {
00849       Domain* domain = domains[i];
00850       MLN* mln = mlns[i];
00851         // Domains have been built: If user doesn't provide number of MC-SAT
00852         // steps, then use 10,000 / (min # of gndings of any clause), but not
00853         // less than 5. This is to ensure 10,000 samples of all clauses
00854       if (amcmcMaxSteps <= 0)
00855       {
00856         int minSize = INT_MAX;
00857         
00858         for (int c = 0; c < mln->getNumClauses(); c++)
00859         {
00860           Clause* clause = (Clause*)mln->getClause(c);
00861           double size = clause->getNumGroundings(domain);
00862           if (size < minSize) minSize = (int)size;
00863         }
00864         int steps = 10000 / minSize;
00865         if (steps < 5) steps = 5;
00866         cout << "Setting number of MCMC steps to " << steps << endl;
00867         amcmcMaxSteps = steps;
00868         msparams->maxSteps = amcmcMaxSteps;
00869         gibbsparams->maxSteps = amcmcMaxSteps;
00870         stparams->maxSteps = amcmcMaxSteps;        
00871       }
00872 
00873         // Remove evidence atoms structure from DBs
00874       if (!aLazy)
00875         domains[i]->getDB()->setLazyFlag(false);
00876     
00877         // Unknown non-ev. preds
00878       GroundPredicateHashArray* unePreds = NULL;
00879 
00880         // Known non-ev. preds
00881       GroundPredicateHashArray* knePreds = NULL;
00882       Array<TruthValue>* knePredValues = NULL;
00883 
00884                 // Make open-world evidence preds into non-evidence
00885       if (!allPredsExceptQueriesAreCW)
00886       {
00887         for (int i = 0; i < owPredNames.size(); i++)
00888         {
00889           nePredsStr.append(",");
00890           nePredsStr.append(owPredNames[i]);
00891           nonEvidPredNames.append(owPredNames[i]);
00892         }
00893       }
00894 
00895       Array<Predicate*> gpreds;
00896       Array<TruthValue> gpredValues;
00897         // Eager version: Build query preds from command line and set known
00898         // non-evidence to unknown for building the states
00899       if (!aLazy)
00900       {
00901         unePreds = new GroundPredicateHashArray;
00902         knePreds = new GroundPredicateHashArray;
00903         knePredValues = new Array<TruthValue>;
00904 
00905         allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);
00906           //defined in infer.h
00907         createComLineQueryPreds(nePredsStr, domain, domain->getDB(), 
00908                                 unePreds, knePreds, 
00909                                 &allPredGndingsAreNonEvid, NULL);
00910 
00911           // Pred values not set to unknown in DB: unePreds contains
00912           // unknown, knePreds contains known non-evidence
00913 
00914           // Set known NE to unknown for building state
00915           // and set blockEvidence to false if this was the true evidence
00916         knePredValues->growToSize(knePreds->size(), FALSE);
00917         for (int predno = 0; predno < knePreds->size(); predno++)
00918         {
00919             // If this was the true evidence in block, then erase this info
00920           int blockIdx = domain->getBlock((*knePreds)[predno]);
00921           if (blockIdx > -1 &&
00922               domain->getDB()->getValue((*knePreds)[predno]) == TRUE)
00923           {
00924             domain->setBlockEvidence(blockIdx, false);
00925           }
00926             // Set value to unknown
00927           (*knePredValues)[predno] =
00928             domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);
00929         }
00930 
00931           // If first order query pred groundings are allowed to be evidence
00932           // - we assume all the predicates not in db to be false
00933           // evidence - need a better way to code this.
00934         if (isQueryEvidence)
00935         {
00936             // Set unknown NE to false
00937           for (int predno = 0; predno < unePreds->size(); predno++)
00938           {
00939             domain->getDB()->setValue((*unePreds)[predno], FALSE);
00940             delete (*unePreds)[predno];
00941           }
00942           unePreds->clear();
00943         }
00944       }
00945       else
00946       {
00947         Array<Predicate*> ppreds;
00948 
00949         domain->getDB()->setPerformingInference(false);
00950 
00951         gpreds.clear();
00952         gpredValues.clear();
00953         for (int predno = 0; predno < nonEvidPredNames.size(); predno++) 
00954         {
00955           ppreds.clear();
00956           int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
00957           Predicate::createAllGroundings(predid, domain, ppreds);
00958           gpreds.append(ppreds);
00959         }
00960         //domain->getDB()->alterTruthValue(&gpreds, UNKNOWN, FALSE, &gpredValues);
00961         domain->getDB()->setValuesToUnknown(&gpreds, &gpredValues);
00962       }
00963       
00964         // Create state for inferred counts using unknown and known (set to
00965         // unknown in the DB) non-evidence preds
00966       cout << endl << "constructing state for domain " << i << "..." << endl;
00967       bool markHardGndClauses = false;
00968       bool trackParentClauseWts = true;
00969 
00970       VariableState*& state = states[i];
00971       state = new VariableState(unePreds, knePreds, knePredValues,
00972                                 &allPredGndingsAreNonEvid, markHardGndClauses,
00973                                 trackParentClauseWts, mln, domain, aLazy);
00974 
00975       Inference*& inference = inferences[i];
00976       bool trackClauseTrueCnts = true;
00977         // Different inference algorithms
00978       if (amapPos || amapAll)
00979       { // MaxWalkSat
00980           // When standalone MWS, numSolutions is always 1
00981           // (maybe there is a better way to do this?)
00982         mwsparams->numSolutions = 1;
00983         inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
00984                                    mwsparams);
00985       }
00986       else if (amcsatInfer)
00987       { // MC-SAT
00988         inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
00989       }
00990       else if (asimtpInfer)
00991       { // Simulated Tempering
00992           // When MWS is used in Sim. Temp., numSolutions is always 1
00993           // (maybe there is a better way to do this?)
00994         mwsparams->numSolutions = 1;
00995         inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
00996                                            stparams);
00997       }
00998       else if (agibbsInfer)
00999       { // Gibbs sampling
01000           // When MWS is used in Gibbs, numSolutions is always 1
01001           // (maybe there is a better way to do this?)
01002         mwsparams->numSolutions = 1;
01003         inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
01004                                      gibbsparams);
01005       }
01006 
01007       if (!aLazy)
01008       {
01009           // Change known NE to original values
01010         domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
01011           // Set unknown NE to false for weight initialization. This seems to
01012           // give poor results when using EM. We need to leave these
01013           // as unknown and do the counts accordingly
01014         for (int predno = 0; predno < unePreds->size(); predno++)
01015         {
01016           domain->getDB()->setValue((*unePreds)[predno], FALSE);
01017         }
01018       }
01019       else
01020       {
01021         domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
01022       
01023         //cout << "the ground predicates are :" << endl;
01024         for (int predno = 0; predno < gpreds.size(); predno++) 
01025         {
01026           //gpreds[predno]->printWithStrVar(cout, domain);
01027           //cout << endl;
01028           delete gpreds[predno];
01029         }
01030 
01031         domain->getDB()->setPerformingInference(true);
01032       }
01033     }
01034     cout << endl << "done constructing variable states" << endl << endl;
01035     
01036     if (useVP)
01037       discMethod = DiscriminativeLearner::SIMPLE;
01038     else if (useNewton)
01039       discMethod = DiscriminativeLearner::DN;
01040     else
01041       discMethod = DiscriminativeLearner::CG;
01042 
01043     DiscriminativeLearner dl(inferences, nonEvidPredNames, indexTrans, aLazy,
01044                              withEM, !noUsePerWeight, discMethod, cg_lambda,
01045                              !cg_noprecond, cg_max_lambda);
01046 
01047     if (!noPrior) 
01048       dl.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
01049                          priorStdDevs.getItems());
01050     else
01051       dl.setMeansStdDevs(-1, NULL, NULL);
01052          
01053     begSec = timer.time();
01054     cout << "learning (discriminative) weights .. " << endl;
01055     double maxTime = maxSec + 60*maxMin + 3600*maxHour;
01056     dl.learnWeights(wwts, wts.size()-1, numIter, maxTime, learningRate, 
01057                     momentum, initWithLogOdds, amwsMaxSubsequentSteps,
01058                     aPeriodicMLNs);
01059     cout << endl << endl << "Done learning discriminative weights. "<< endl;
01060     cout << "Time Taken for learning = ";
01061     Timer::printTime(cout, (timer.time() - begSec)); cout << endl;
01062 
01063     if (mwsparams) delete mwsparams;
01064     if (ssparams) delete ssparams;
01065     if (msparams) delete msparams;
01066     if (gibbsparams) delete gibbsparams;
01067     if (stparams) delete stparams;
01068     for (int i = 0; i < inferences.size(); i++)  delete inferences[i];
01069     for (int i = 0; i < states.size(); i++)  delete states[i];
01070   } 
01071   else
01072   {   
01074 
01075     Array<bool> areNonEvidPreds;
01076     if (nonEvidPredNames.empty())
01077     {
01078       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
01079       for (int i = 0; i < domains[0]->getNumPredicates(); i++)
01080       {
01081           //prevent equal pred from being non-evidence preds
01082         if (domains[0]->getPredicateTemplate(i)->isEqualPred())
01083         {
01084           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
01085           int predId = domains[0]->getPredicateId(pname);
01086           areNonEvidPreds[predId] = false;
01087         }
01088           //prevent internal preds from being non-evidence preds
01089         if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
01090         {
01091           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
01092           int predId = domains[0]->getPredicateId(pname);
01093           areNonEvidPreds[predId] = false;
01094         }
01095       }
01096     } 
01097     else
01098     {
01099       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
01100       for (int i = 0; i < nonEvidPredNames.size(); i++)
01101       {
01102         int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
01103         if (predId < 0)
01104         {
01105           cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined." 
01106                << endl;
01107           exit(-1);
01108         }
01109         areNonEvidPreds[predId] = true;
01110       }
01111     }
01112 
01113     Array<bool>* nePreds = &areNonEvidPreds;
01114     PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,-1,-1,-1);
01115     pll.setIndexTranslator(indexTrans);
01116 
01117     if (!noPrior) 
01118       pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
01119                           priorStdDevs.getItems());
01120     else          
01121       pll.setMeansStdDevs(-1, NULL, NULL);
01122     
01124 
01125     begSec = timer.time();
01126     for (int m = 0; m < mlns.size(); m++)
01127     {
01128       cout << "Computing counts for clauses in domain " << m << "..." << endl;
01129       const ClauseHashArray* clauses = mlns[m]->getClauses();
01130       for (int i = 0; i < clauses->size(); i++)
01131       {
01132         if (PRINT_CLAUSE_DURING_COUNT)
01133         {
01134           cout << "clause " << i << ": ";
01135           (*clauses)[i]->printWithoutWt(cout, domains[m]);
01136           cout << endl; cout.flush();
01137         }
01138         MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
01139         pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m, 
01140                                               NULL, false, NULL);
01141       }
01142     }
01143     pll.compress();
01144     cout <<"Computing counts took ";
01145     Timer::printTime(cout, timer.time() - begSec); cout << endl;
01146     
01148 
01149       // initialize the clause weights
01150     wts.growToSize(numClausesFormulas + 1);
01151     for (int i = 0; i < numClausesFormulas; i++) wts[i+1] = 0;
01152     //wts[i+1] = priorMeans[i];
01153 
01154       // optimize the clause weights
01155     cout << "L-BFGS-B is finding optimal weights......" << endl;
01156     begSec = timer.time();
01157     LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
01158     int iter;
01159     bool error;
01160     double pllValue = lbfgsb.minimize((double*)wts.getItems(), iter, error);
01161     
01162     if (error) cout << "LBFGSB returned with an error!" << endl;
01163     cout << "num iterations        = " << iter << endl;
01164     cout << "time taken            = ";
01165     Timer::printTime(cout, timer.time() - begSec);
01166     cout << endl;
01167     cout << "pseudo-log-likelihood = " << -pllValue << endl;
01168 
01169   } // else using generative learning
01170 
01172   if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
01173   else            assignWtsAndOutputMLN(out, mlns, domains, wts);
01174 
01175   out.close();
01176 
01178   deleteDomains(domains);
01179 
01180   for (int i = 0; i < mlns.size(); i++)
01181   {
01182     if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
01183     {
01184       mlns[i]->setClauses(NULL);
01185       mlns[i]->setMLNClauseInfos(NULL);
01186       mlns[i]->setPredIdToClausesMap(NULL);
01187       mlns[i]->setFormulaAndClausesArray(NULL);
01188       mlns[i]->setExternalClause(NULL);
01189     }
01190     delete mlns[i];
01191   }
01192 
01193   PowerSet::deletePowerSet();
01194   if (indexTrans) delete indexTrans;
01195 
01196   cout << "Total time = "; 
01197   Timer::printTime(cout, timer.time() - startSec); cout << endl;
01198 }

Generated on Sun Jun 7 11:55:14 2009 for Alchemy by  doxygen 1.5.1