learnwts.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #include <fstream>
00067 #include <iostream>
00068 #include <sstream>
00069 #include "arguments.h"
00070 #include "inferenceargs.h"
00071 #include "lbfgsb.h"
00072 #include "votedperceptron.h"
00073 #include "learnwts.h"
00074 #include "maxwalksat.h"
00075 #include "mcsat.h"
00076 #include "gibbssampler.h"
00077 #include "simulatedtempering.h"
00078 
  //set to false to disable printing of clauses when they are counted during 
  //generative learning
bool PRINT_CLAUSE_DURING_COUNT = true;

  //default Gaussian prior standard deviations, applied when -priorStdDev is
  //not given on the command line (priorStdDev below starts at -1 as a
  //"not set" sentinel)
const double DISC_DEFAULT_STD_DEV = 1;    //for discriminative learning
const double GEN_DEFAULT_STD_DEV = 100;   //for generative learning

  // Variables for holding inference command line args are in inferenceargs.h
bool discLearn = false;            // -d: discriminative weight learning
bool genLearn = false;             // -g: generative weight learning
char* outMLNFile = NULL;           // -o: output .mln with learned weights
char* dbFiles = NULL;              // -t: comma-separated training .db files
char* nonEvidPredsStr = NULL;      // -ne: comma-separated non-evidence preds
bool noAddUnitClauses = false;     // -noAddUnitClauses: omit unit clauses
bool multipleDatabases = false;    // -multipleDatabases: one domain per .db
bool initToZero = false;           // -dZeroInit: start weights at 0, not log odds
bool isQueryEvidence = false;      // -queryEvidence: query groundings not in db
                                   // are assumed false evidence

bool noPrior = false;              // -noPrior: no Gaussian priors on weights
double priorMean = 0;              // -priorMean: prior means (see ARGS below)
double priorStdDev = -1;           // -priorStdDev; -1 means "use default above"

  // Generative learning args
int maxIter = 10000;               // -gMaxIter: max L-BFGS-B iterations
double convThresh = 1e-5;          // -gConvThresh: PLL change at termination
bool noEqualPredWt = false;        // -gNoEqualPredWt: unweighted PLL preds

  // Discriminative learning args
int numIter = 200;                 // -dNumIter: voted perceptron iterations
double learningRate = 0.001;       // -dLearningRate: gradient descent rate
double momentum = 0.0;             // -dMomentum: gradient descent momentum
bool rescaleGradient = false;      // -dRescale: divide gradient by true counts
bool withEM = false;               // -withEM: EM fills missing truth values
char* aInferStr = NULL;            // -infer: inference params for disc. learning
int amwsMaxSubsequentSteps = -1;   // -dMwsMaxSubsequentSteps; -1: use mwsMaxSteps
00115 
00116   // Inference arguments needed for disc. learning defined in inferenceargs.h
00117   // TODO: List the arguments common to learnwts and inference in
00118   // inferenceargs.h. This can't be done with a static array.
00119 ARGS ARGS::Args[] = 
00120 {
00121     // BEGIN: Common arguments
00122   ARGS("i", ARGS::Req, ainMLNFiles, 
00123        "Comma-separated input .mln files. (With the -multipleDatabases "
00124        "option, the second file to the last one are used to contain constants "
00125        "from different databases, and they correspond to the .db files "
00126        "specified with the -t option.)"),
00127 
00128   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00129        "Specified non-evidence atoms (comma-separated with no space) are "
00130        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00131        "appearing here cannot be query atoms and cannot appear in the -o "
00132        "option."),
00133 
00134   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00135        "Specified evidence atoms (comma-separated with no space) are open "
00136        "world, while other evidence atoms are closed-world. "
00137        "Atoms appearing here cannot appear in the -c option."),
00138     // END: Common arguments
00139 
00140     // BEGIN: Common inference arguments
00141   ARGS("m", ARGS::Tog, amapPos, 
00142        "Run MAP inference and return only positive query atoms."),
00143 
00144   ARGS("a", ARGS::Tog, amapAll, 
00145        "Run MAP inference and show 0/1 results for all query atoms."),
00146 
00147   ARGS("p", ARGS::Tog, agibbsInfer, 
00148        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00149        "for all query atoms."),
00150   
00151   ARGS("ms", ARGS::Tog, amcsatInfer,
00152        "Run inference using MC-SAT and return probabilities "
00153        "for all query atoms"),
00154 
00155   ARGS("simtp", ARGS::Tog, asimtpInfer,
00156        "Run inference using simulated tempering and return probabilities "
00157        "for all query atoms"),
00158 
00159   ARGS("seed", ARGS::Opt, aSeed,
00160        "[random] Seed used to initialize the randomizer in the inference "
00161        "algorithm. If not set, seed is initialized from the current date and "
00162        "time."),
00163 
00164   ARGS("lazy", ARGS::Opt, aLazy, 
00165        "[false] Run lazy version of inference if this flag is set."),
00166   
00167   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00168        "[false] Lazy version of inference will not approximate by deactivating "
00169        "atoms to save memory. This flag is ignored if -lazy is not set."),
00170   
00171   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00172        "[-1] Maximum limit in kbytes which should be used for inference. "
00173        "-1 means main memory available on system is used."),
00174     // END: Common inference arguments
00175 
00176     // BEGIN: MaxWalkSat args
00177   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00178        "[1000000] (MaxWalkSat) The max number of steps taken."),
00179 
00180   ARGS("tries", ARGS::Opt, amwsTries, 
00181        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00182 
00183   ARGS("targetWt", ARGS::Opt, amwsTargetWt,
00184        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00185        "with weight <= specified weight."),
00186 
00187   ARGS("hard", ARGS::Opt, amwsHard, 
00188        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00189        "satisfy a soft one."),
00190   
00191   ARGS("heuristic", ARGS::Opt, amwsHeuristic,
00192        "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00193        "2 = TABU, 3 = SAMPLESAT)."),
00194   
00195   ARGS("tabuLength", ARGS::Opt, amwsTabuLength,
00196        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00197        "atom when using the tabu heuristic in MaxWalkSat." ),
00198 
00199   ARGS("lazyLowState", ARGS::Opt, amwsLazyLowState, 
00200        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00201        "(each time a low state is found, the whole state is saved) is used; "
00202        "otherwise, a list of variables flipped since the last low state is "
00203        "kept and the low state is reconstructed. This can be much faster for "
00204        "very large data sets."),  
00205     // END: MaxWalkSat args
00206 
00207     // BEGIN: MCMC args
00208   ARGS("burnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00209        "[100] (MCMC) Minimun number of burn in steps (-1: no minimum)."),
00210 
00211   ARGS("burnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00212        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00213 
00214   ARGS("minSteps", ARGS::Opt, amcmcMinSteps, 
00215        "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00216 
00217   ARGS("maxSteps", ARGS::Opt, amcmcMaxSteps, 
00218        "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00219 
00220   ARGS("maxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00221        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00222     // END: MCMC args
00223   
00224     // BEGIN: Simulated tempering args
00225   ARGS("subInterval", ARGS::Opt, asimtpSubInterval,
00226         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00227 
00228   ARGS("numRuns", ARGS::Opt, asimtpNumST,
00229         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00230 
00231   ARGS("numSwap", ARGS::Opt, asimtpNumSwap,
00232         "[10] (Simulated Tempering) Number of swapping chains"),
00233     // END: Simulated tempering args
00234 
00235     // BEGIN: MC-SAT args
00236   ARGS("numStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00237        "[1] (MC-SAT) Number of total steps (mcsat + gibbs) for every mcsat "
00238        "step"),
00239     // END: MC-SAT args
00240 
00241     // BEGIN: SampleSat args
00242   ARGS("numSolutions", ARGS::Opt, amwsNumSolutions,
00243        "[10] (MC-SAT) Return nth SAT solution in SampleSat"),
00244 
00245   ARGS("saRatio", ARGS::Opt, assSaRatio,
00246        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00247        "MC-SAT"),
00248 
00249   ARGS("saTemperature", ARGS::Opt, assSaTemp,
00250         "[10] (MC-SAT) Temperature (/100) for sim. annealing step in "
00251         "SampleSat"),
00252 
00253   ARGS("lateSa", ARGS::Tog, assLateSa,
00254        "[false] Run simulated annealing from the start in SampleSat"),
00255     // END: SampleSat args
00256 
00257     // BEGIN: Gibbs sampling args
00258   ARGS("numChains", ARGS::Opt, amcmcNumChains, 
00259        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00260        "at least 2)."),
00261 
00262   ARGS("delta", ARGS::Opt, agibbsDelta,
00263        "[0.05] (Gibbs) During Gibbs sampling, probabilty that epsilon error is "
00264        "exceeded is less than this value."),
00265 
00266   ARGS("epsilonError", ARGS::Opt, agibbsEpsilonError,
00267        "[0.01] (Gibbs) Fractional error from true probability."),
00268 
00269   ARGS("fracConverged", ARGS::Opt, agibbsFracConverged, 
00270        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00271        "have converged."),
00272 
00273   ARGS("walksatType", ARGS::Opt, agibbsWalksatType, 
00274        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00275        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00276 
00277   ARGS("samplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00278        "[100] Perform convergence test once after this many number of samples "
00279        "per chain."),
00280     // END: Gibbs sampling args
00281 
00282     // BEGIN: Weight learning specific args
00283   ARGS("infer", ARGS::Opt, aInferStr,
00284        "Specified inference parameters when using discriminative learning. "
00285        "The arguments are to be encapsulated in \"\" and the syntax is "
00286        "identical to the infer command (run infer with no commands to see "
00287        "this). If not specified, MaxWalkSat with default parameters is used."),
00288 
00289   ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),
00290 
00291   ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),
00292 
00293   ARGS("o", ARGS::Req, outMLNFile, 
00294        "Output .mln file containing formulas with learned weights."),
00295 
00296   ARGS("t", ARGS::Req, dbFiles, 
00297        "Comma-separated .db files containing the training database "
00298        "(of true/false ground atoms), including function definitions, "
00299        "e.g. ai.db,graphics.db,languages.db."),
00300 
00301   ARGS("ne", ARGS::Opt, nonEvidPredsStr, 
00302        "First-order non-evidence predicates (comma-separated with no space),  "
00303        "e.g., cancer,smokes,friends. For discriminative learning, at least "
00304        "one non-evidence predicate must be specified. For generative learning, "
00305        "the specified predicates are included in the (weighted) pseudo-log-"
00306        "likelihood computation; if none are specified, all are included."),
00307     
00308   ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
00309        "If specified, unit clauses are not included in the .mln file; "
00310        "otherwise they are included."),
00311 
00312   ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
00313        "If specified, each .db file belongs to a separate database; "
00314        "otherwise all .db files belong to the same database."),
00315 
00316   ARGS("withEM", ARGS::Tog, withEM,
00317        "If set, EM is used to fill in missing truth values; "
00318        "otherwise missing truth values are set to false."),
00319 
00320   ARGS("dNumIter", ARGS::Opt, numIter, 
00321        "[200] (For discriminative learning only.) "
00322        "Number of iterations to run voted perceptron."),
00323   
00324   ARGS("dLearningRate", ARGS::Opt, learningRate, 
00325        "[0.001] (For discriminative learning only) "
00326        "Learning rate for the gradient descent in voted perceptron algorithm."),
00327 
00328   ARGS("dMomentum", ARGS::Opt, momentum, 
00329        "[0.0] (For discriminative learning only) "
00330        "Momentum term for the gradient descent in voted perceptron algorithm."),
00331        
00332   ARGS("queryEvidence", ARGS::Tog, isQueryEvidence, 
00333        "If this flag is set, then all the groundings of query preds not in db "
00334        "are assumed false evidence."),
00335        
00336   ARGS("dRescale", ARGS::Tog, rescaleGradient, 
00337        "(For discriminative learning only.) "
00338        "Rescale the gradient by the number of true groundings per weight."),
00339 
00340   ARGS("dZeroInit", ARGS::Tog, initToZero,
00341        "(For discriminative learning only.) "
00342        "Initialize clause weights to zero instead of their log odds."),
00343 
00344   ARGS("dMwsMaxSubsequentSteps", ARGS::Opt, amwsMaxSubsequentSteps,
00345        "[Same as mwsMaxSteps] (For discriminative learning only.) The max "
00346        "number of MaxWalkSat steps taken in subsequent iterations (>= 2) of "
00347        "disc. learning. If not specified, mwsMaxSteps is used in each "
00348        "iteration"),
00349   
00350   ARGS("gMaxIter", ARGS::Opt, maxIter, 
00351        "[10000] (For generative learning only.) "
00352        "Max number of iterations to run L-BFGS-B, "
00353        "the optimization algorithm for generative learning."),
00354   
00355   ARGS("gConvThresh", ARGS::Opt, convThresh, 
00356        "[1e-5] (For generative learning only.) "
00357        "Fractional change in pseudo-log-likelihood at which "
00358        "L-BFGS-B terminates."),
00359 
00360   ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt, 
00361        "(For generative learning only.) "
00362        "If specified, the predicates are not weighted equally in the "
00363        "pseudo-log-likelihood computation; otherwise they are."),
00364   
00365   ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),
00366 
00367   ARGS("priorMean", ARGS::Opt, priorMean, 
00368        "[0] Means of Gaussian priors on formula weights. By default, "
00369        "for each formula, it is the weight given in the .mln input file, " 
00370        "or fraction thereof if the formula turns into multiple clauses. "
00371        "This mean applies if no weight is given in the .mln file."),
00372 
00373   ARGS("priorStdDev", ARGS::Opt, priorStdDev, 
00374        "[1 for discriminative learning. 100 for generative learning] "
00375        "Standard deviations of Gaussian priors on clause weights."),
00376 
00377   ARGS()
00378 };
00379 
00380 //bool extractPredNames(...) defined in infer.h
00381 
00382 int main(int argc, char* argv[])
00383 {
00384   ARGS::parse(argc,argv,&cout);
00385 
00386   if (!discLearn && !genLearn) 
00387   { 
00388     cout << "must specify either -d or -g "
00389          <<"(discriminative or generative learning) " << endl; 
00390     return -1; 
00391   }
00392 
00393   Timer timer;
00394   double startSec = timer.time();
00395   double begSec;
00396 
00397   if (priorStdDev < 0)
00398   {
00399     if (discLearn) 
00400     { 
00401       cout << "priorStdDev set to (discriminative learning's) default of " 
00402            << DISC_DEFAULT_STD_DEV << endl;
00403       priorStdDev = DISC_DEFAULT_STD_DEV;
00404     }
00405     else
00406     {
00407       cout << "priorStdDev set to (generative learning's) default of " 
00408            << GEN_DEFAULT_STD_DEV << endl;
00409       priorStdDev = GEN_DEFAULT_STD_DEV;      
00410     }
00411   }
00412 
00413 
00415   if (discLearn && nonEvidPredsStr == NULL)
00416   {
00417     cout << "ERROR: you must specify non-evidence predicates for "
00418          << "discriminative learning" << endl;
00419     return -1;
00420   }
00421 
00422   if (maxIter <= 0)  { cout << "maxIter must be > 0" << endl; return -1; }
00423   if (convThresh <= 0 || convThresh > 1)  
00424   { cout << "convThresh must be > 0 and <= 1" << endl; return -1;  }
00425   if (priorStdDev <= 0) { cout << "priorStdDev must be > 0" << endl; return -1;}
00426 
00427   if (amwsMaxSteps <= 0)
00428   { cout << "ERROR: maxSteps must be positive" << endl; return -1; }
00429   
00430     // If max. subsequent steps not specified, use amwsMaxSteps
00431   if (amwsMaxSubsequentSteps <= 0) amwsMaxSubsequentSteps = amwsMaxSteps;
00432 
00433   if (amwsTries <= 0)
00434   { cout << "ERROR: tries must be positive" << endl; return -1; }
00435 
00436   if (aMemLimit <= 0 && aMemLimit != -1)
00437   { cout << "ERROR: limit must be positive (or -1)" << endl; return -1; }
00438 
00439   if (!discLearn && aLazy)
00440   {
00441     cout << "ERROR: lazy can only be used with discriminative learning"
00442          << endl;
00443     return -1;
00444   }
00445 
00446   ofstream out(outMLNFile);
00447   if (!out.good())
00448   {
00449     cout << "ERROR: unable to open " << outMLNFile << endl;
00450     return -1;
00451   }
00452 
00453     // Parse the inference parameters, if given
00454   if (discLearn)
00455   {
00456       // If no inference method indicated, then use MAP
00457     if (!aInferStr)
00458     {
00459       amapPos = true;
00460     }
00461       // If inference method given, we need to parse the parameters
00462     else
00463     {
00464       int inferArgc = 0;
00465       char **inferArgv = new char*[200];
00466       for (int i = 0; i < 200; i++)
00467       {
00468         inferArgv[i] = new char[30];
00469       }
00470 
00471       extractArgs(aInferStr, inferArgc, inferArgv);
00472       cout << "extractArgs " << inferArgc << endl;
00473       for (int i = 0; i < inferArgc; i++)
00474       {
00475         cout << i << ": " << inferArgv[i] << endl;
00476       }
00477 
00478       ARGS::parseFromCommandLine(inferArgc, inferArgv);
00479 
00480         // HACK: Argument parser doesn't parse the ARGS::Tog right, so do
00481         // it here by hand
00482       for (int i = 0; i < inferArgc; i++)
00483       {
00484         if (string(inferArgv[i]) == "-m") amapPos = true;
00485         else if (string(inferArgv[i]) == "-a") amapAll = true;
00486         else if (string(inferArgv[i]) == "-p") agibbsInfer = true;
00487         else if (string(inferArgv[i]) == "-ms") amcsatInfer = true;
00488         else if (string(inferArgv[i]) == "-simtp") asimtpInfer = true;
00489       }
00490 
00491         // Delete memory allocated for args
00492       for (int i = 0; i < 200; i++)
00493       {
00494         delete[] inferArgv[i];
00495       }
00496       delete[] inferArgv; 
00497     }
00498   }
00499 
00500 
00501   //the second .mln file to the last one in ainMLNFiles _may_ be used 
00502   //to hold constants, so they are held in constFilesArr. They will be
00503   //included into the first .mln file.
00504 
00505     //extract .mln and .db, file names
00506   Array<string> constFilesArr;
00507   Array<string> dbFilesArr;
00508   extractFileNames(ainMLNFiles, constFilesArr);
00509   assert(constFilesArr.size() >= 1);
00510   string inMLNFile = constFilesArr[0];
00511   constFilesArr.removeItem(0);
00512   extractFileNames(dbFiles, dbFilesArr);
00513 
00514   if (dbFilesArr.size() <= 0)
00515   {cout<<"ERROR: must specify training data with -t option."<<endl; return -1;}
00516  
00517     // if multiple databases, check the number of .db/.func files
00518   if (multipleDatabases) 
00519   {
00520       //if # .mln files containing constants/.func files and .db files are diff
00521     if ((constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size()))
00522     {
00523       cout << "ERROR: when there are multiple databases, if .mln files "
00524            << "containing constants are specified, there must " 
00525            << "be the same number of them as .db files" << endl;
00526       return -1;
00527     }
00528   }
00529 
00530   StringHashArray nonEvidPredNames;
00531   if (nonEvidPredsStr)
00532   {
00533     if(!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
00534     {
00535       cout << "ERROR: failed to extract non-evidence predicate names." << endl;
00536       return -1;
00537     }
00538   }
00539 
00540   StringHashArray owPredNames;
00541   StringHashArray cwPredNames;
00542 
00544 
00545   cout << "Parsing MLN and creating domains..." << endl;
00546   StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
00547   Array<Domain*> domains;
00548   Array<MLN*> mlns;
00549   begSec = timer.time();
00550   bool allPredsExceptQueriesAreCW = true;
00551   if (discLearn)
00552   {
00553       //extract names of open-world evidence predicates
00554     if (aOpenWorldPredsStr)
00555     {
00556       if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
00557         return -1;
00558       assert(owPredNames.size() > 0);
00559     }
00560 
00561       //extract names of closed-world non-evidence predicates
00562     if (aClosedWorldPredsStr)
00563     {
00564       if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
00565         return -1;
00566       assert(cwPredNames.size() > 0);
00567       if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
00568         return -1;
00569     }
00570  
00571     allPredsExceptQueriesAreCW = owPredNames.empty();
00572   }
00573     // Parse as if lazy inference is set to true to set evidence atoms in DB
00574     // If lazy is not used, this is removed from DB
00575   createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile, 
00576                        constFilesArr, dbFilesArr, nePredNames,
00577                        !noAddUnitClauses, priorMean, true,
00578                        allPredsExceptQueriesAreCW, &owPredNames);
00579   cout << "Parsing MLN and creating domains took "; 
00580   Timer::printTime(cout, timer.time() - begSec); cout << endl;
00581 
00582   /*
00583   cout << "Clause prior means:" << endl;
00584   cout << "_________________________________" << endl;
00585   mlns[0]->printClausePriorMeans(cout, domains[0]);
00586   cout << "_________________________________" << endl;
00587   cout << endl;
00588 
00589   cout << "Formula prior means:" << endl;
00590   cout << "_________________________________" << endl;
00591   mlns[0]->printFormulaPriorMeans(cout);
00592   cout << "_________________________________" << endl;
00593   cout << endl;
00594   */
00595 
00597 
00598     //we need an index translator if clauses do not line up across multiple DBs
00599   IndexTranslator* indexTrans 
00600     = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
00601        new IndexTranslator(&mlns, &domains) : NULL;  
00602 
00603   if (indexTrans) 
00604     cout << endl << "the weights of clauses in the CNFs of existential"
00605          << " formulas will be tied" << endl;
00606 
00607 
00608   Array<double> priorMeans, priorStdDevs;
00609   if (!noPrior)
00610   {
00611     if (indexTrans)
00612     {
00613       indexTrans->setPriorMeans(priorMeans);
00614       priorStdDevs.growToSize(priorMeans.size());
00615       for (int i = 0; i < priorMeans.size(); i++)
00616         priorStdDevs[i] = priorStdDev;
00617     }
00618     else
00619     {
00620       const ClauseHashArray* clauses = mlns[0]->getClauses();
00621       int numClauses = clauses->size();
00622       for (int i = 0; i < numClauses; i++)
00623       {
00624         priorMeans.append((*clauses)[i]->getWt());
00625         priorStdDevs.append(priorStdDev);
00626       }
00627     }
00628   }
00629   // HACK -- not sure if this is right... but the old version was broke!
00630   // This may fail when there's an indexTrans.  [Daniel]
00631   //int numClausesFormulas = priorMeans.size();
00632   int numClausesFormulas = mlns[0]->getClauses()->size();
00633 
00634 
00636   Array<double> wts;
00637 
00638     // Discriminative learning
00639   if (discLearn) 
00640   {
00641     wts.growToSize(numClausesFormulas + 1);
00642     double* wwts = (double*) wts.getItems();
00643     wwts++;
00644       // Non-evid preds as a string
00645     string nePredsStr = nonEvidPredsStr;
00646 
00647       // Set SampleSat parameters
00648     SampleSatParams* ssparams = new SampleSatParams;
00649     ssparams->lateSa = assLateSa;
00650     ssparams->saRatio = assSaRatio;
00651     ssparams->saTemp = assSaTemp;
00652 
00653       // Set MaxWalksat parameters
00654     MaxWalksatParams* mwsparams = NULL;
00655     mwsparams = new MaxWalksatParams;
00656     mwsparams->ssParams = ssparams;
00657     mwsparams->maxSteps = amwsMaxSteps;
00658     mwsparams->maxTries = amwsTries;
00659     mwsparams->targetCost = amwsTargetWt;
00660     mwsparams->hard = amwsHard;
00661       // numSolutions only applies when used in SampleSat.
00662       // When just MWS, this is set to 1
00663     mwsparams->numSolutions = amwsNumSolutions;
00664     mwsparams->heuristic = amwsHeuristic;
00665     mwsparams->tabuLength = amwsTabuLength;
00666     mwsparams->lazyLowState = amwsLazyLowState;
00667 
00668       // Set MC-SAT parameters
00669     MCSatParams* msparams = new MCSatParams;
00670     msparams->mwsParams = mwsparams;
00671       // MC-SAT needs only one chain
00672     msparams->numChains          = 1;
00673     msparams->burnMinSteps       = amcmcBurnMinSteps;
00674     msparams->burnMaxSteps       = amcmcBurnMaxSteps;
00675     msparams->minSteps           = amcmcMinSteps;
00676     msparams->maxSteps           = amcmcMaxSteps;
00677     msparams->maxSeconds         = amcmcMaxSeconds;
00678     msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00679 
00680       // Set Gibbs parameters
00681     GibbsParams* gibbsparams = new GibbsParams;
00682     gibbsparams->mwsParams    = mwsparams;
00683     gibbsparams->numChains    = amcmcNumChains;
00684     gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00685     gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00686     gibbsparams->minSteps     = amcmcMinSteps;
00687     gibbsparams->maxSteps     = amcmcMaxSteps;
00688     gibbsparams->maxSeconds   = amcmcMaxSeconds;
00689 
00690     gibbsparams->gamma          = 1 - agibbsDelta;
00691     gibbsparams->epsilonError   = agibbsEpsilonError;
00692     gibbsparams->fracConverged  = agibbsFracConverged;
00693     gibbsparams->walksatType    = agibbsWalksatType;
00694     gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00695   
00696       // Set Sim. Tempering parameters
00697     SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00698     stparams->mwsParams    = mwsparams;
00699     stparams->numChains    = amcmcNumChains;
00700     stparams->burnMinSteps = amcmcBurnMinSteps;
00701     stparams->burnMaxSteps = amcmcBurnMaxSteps;
00702     stparams->minSteps     = amcmcMinSteps;
00703     stparams->maxSteps     = amcmcMaxSteps;
00704     stparams->maxSeconds   = amcmcMaxSeconds;
00705 
00706     stparams->subInterval = asimtpSubInterval;
00707     stparams->numST       = asimtpNumST;
00708     stparams->numSwap     = asimtpNumSwap;
00709 
00710     Array<VariableState*> states;
00711     Array<Inference*> inferences;
00712 
00713     states.growToSize(domains.size());
00714     inferences.growToSize(domains.size());
00715 
00716       // Build the state for inference in each domain
00717     Array<int> allPredGndingsAreNonEvid;
00718     Array<Predicate*> ppreds;
        // NOTE(review): ppreds appears unused in the code visible here --
        // confirm it is referenced elsewhere before removing it.
00719     
00720     for (int i = 0; i < domains.size(); i++)
00721     {
00722       Domain* domain = domains[i];
00723       MLN* mln = mlns[i];
00724 
00725         // Remove evidence atoms structure from DBs
00726       if (!aLazy)
00727         domains[i]->getDB()->setLazyFlag(false);
00728     
00729         // Unknown non-ev. preds
00730       GroundPredicateHashArray* unePreds = NULL;
00731 
00732         // Known non-ev. preds
00733       GroundPredicateHashArray* knePreds = NULL;
          // Original DB truth values of knePreds, saved here so the DB can be
          // restored after the state has been built (see end of this loop).
00734       Array<TruthValue>* knePredValues = NULL;
00735 
00736         // Need to set some dummy weight
          // (the weights are overwritten during learning; state construction
          // only needs them to be set)
00737       for (int j = 0; j < mln->getNumClauses(); j++)
00738         ((Clause*) mln->getClause(j))->setWt(1);
00739 
00740                 // Make open-world evidence preds into non-evidence
          // NOTE(review): this loop's index shadows the outer domain index i,
          // and because the block sits inside the domain loop the open-world
          // names are appended to nePredsStr / nonEvidPredNames once per
          // domain -- with multiple domains the lists accumulate duplicates.
          // It looks like this block should be hoisted above the domain loop;
          // confirm before changing.
00741       if (!allPredsExceptQueriesAreCW)
00742       {
00743         for (int i = 0; i < owPredNames.size(); i++)
00744         {
00745           nePredsStr.append(",");
00746           nePredsStr.append(owPredNames[i]);
00747           nonEvidPredNames.append(owPredNames[i]);
00748         }
00749       }
00750 
00751         // Eager version: Build query preds from command line and set known
00752         // non-evidence to unknown for building the states
00753       if (!aLazy)
00754       {
00755         unePreds = new GroundPredicateHashArray;
00756         knePreds = new GroundPredicateHashArray;
00757         knePredValues = new Array<TruthValue>;
00758 
00759         allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);
00760           //defined in infer.h
00761         createComLineQueryPreds(nePredsStr, domain, domain->getDB(), 
00762                                 unePreds, knePreds, 
00763                                 &allPredGndingsAreNonEvid);
00764 
00765           // Pred values not set to unknown in DB: unePreds contains
00766           // unknown, knePreds contains known non-evidence
00767 
00768           // Set known NE to unknown for building state
          // setValue's return (the previous value) is recorded in
          // knePredValues so the DB can be restored afterwards.
00769         knePredValues->growToSize(knePreds->size(), FALSE);
00770         for (int predno = 0; predno < knePreds->size(); predno++) 
00771           (*knePredValues)[predno] =
00772             domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);
00773 
00774           // If first order query pred groundings are allowed to be evidence
00775           // - we assume all the predicates not in db to be false
00776           // evidence - need a better way code this.
00777         if (isQueryEvidence)
00778             // Set unknown NE to false
00779           for (int predno = 0; predno < unePreds->size(); predno++) 
00780             domain->getDB()->setValue((*unePreds)[predno], FALSE);
00781       }
00782       
00783         // Create state for inferred counts using unknown and known (set to
00784         // unknown in the DB) non-evidence preds
00785       cout << endl << "constructing state for domain " << i << "..." << endl;
00786       bool markHardGndClauses = false;
00787       bool trackParentClauseWts = true;
00788       VariableState*& state = states[i];
        // unePreds/knePreds/knePredValues are handed to the VariableState;
        // NOTE(review): ownership of these heap allocations is not visible
        // here -- presumably VariableState takes them over; confirm there is
        // no leak.
00789       state = new VariableState(unePreds, knePreds, knePredValues,
00790                                 &allPredGndingsAreNonEvid, markHardGndClauses,
00791                                 trackParentClauseWts, mln, domain, aLazy);
00792 
00793       Inference*& inference = inferences[i];
00794       bool trackClauseTrueCnts = true;
00795         // Different inference algorithms
00796       if (amapPos || amapAll)
00797       { // MaxWalkSat
00798           // When standalone MWS, numSolutions is always 1
00799           // (maybe there is a better way to do this?)
00800         mwsparams->numSolutions = 1;
00801         inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
00802                                    mwsparams);
00803       }
00804       else if (amcsatInfer)
00805       { // MC-SAT
00806         inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
00807       }
00808       else if (asimtpInfer)
00809       { // Simulated Tempering
00810           // When MWS is used in Sim. Temp., numSolutions is always 1
00811           // (maybe there is a better way to do this?)
00812         mwsparams->numSolutions = 1;
00813         inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
00814                                            stparams);
00815       }
00816       else if (agibbsInfer)
00817       { // Gibbs sampling
00818           // When MWS is used in Gibbs, numSolutions is always 1
00819           // (maybe there is a better way to do this?)
00820         mwsparams->numSolutions = 1;
00821         inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
00822                                      gibbsparams);
00823       }
        // NOTE(review): no final else -- if none of the four inference flags
        // is set, inferences[i] is never assigned here. Presumably the
        // command-line parsing guarantees one mode is selected; confirm.
00824 
00825       if (!aLazy)
00826       {
00827           // Change known NE to original values
00828         domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00829           // Set unknown NE to false for weight initialization. This seems to
00830           // give poor results when using EM. We need to leave these
00831           // as unknown and do the counts accordingly
00832         for (int predno = 0; predno < unePreds->size(); predno++)
00833         {
00834           domain->getDB()->setValue((*unePreds)[predno], FALSE);
00835         }
00836       }
00837     }
00838     cout << endl << "done constructing variable states" << endl << endl;
00839     
      // Discriminative learning: voted perceptron driven by the per-domain
      // inference objects constructed above.
00840     VotedPerceptron vp(inferences, nonEvidPredNames, indexTrans, aLazy,
00841                        rescaleGradient, withEM);
      // Gaussian prior on the weights unless disabled (-1/NULL => no prior).
00842     if (!noPrior) 
00843       vp.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
00844                          priorStdDevs.getItems());
00845     else
00846       vp.setMeansStdDevs(-1, NULL, NULL);
00847          
00848     begSec = timer.time();
00849     cout << "learning (discriminative) weights .. " << endl;
      // wwts is declared outside this excerpt; the wts.size()-1 count
      // suggests wts is 1-based with index 0 unused (matching the generative
      // branch below) -- TODO confirm against the declaration of wwts.
00850     vp.learnWeights(wwts, wts.size()-1, numIter, learningRate, momentum,
00851                     !initToZero, amwsMaxSubsequentSteps);
00852     cout << endl << endl << "Done learning discriminative weights. "<< endl;
00853     cout << "Time Taken for learning = ";
00854     Timer::printTime(cout, (timer.time() - begSec)); cout << endl;
00855 
      // Release the inference parameter structs and per-domain objects
      // allocated for this branch.
00856     if (mwsparams) delete mwsparams;
00857     if (ssparams) delete ssparams;
00858     if (msparams) delete msparams;
00859     if (gibbsparams) delete gibbsparams;
00860     if (stparams) delete stparams;
00861     for (int i = 0; i < inferences.size(); i++)  delete inferences[i];
00862     for (int i = 0; i < states.size(); i++)  delete states[i];
00863   } 
00864   else
00865   {   
00867 
      // Generative learning: optimize pseudo-log-likelihood with L-BFGS-B.

      // Mark which predicates are treated as non-evidence. Predicate ids are
      // taken from domains[0] only -- this assumes all domains declare the
      // same predicates (TODO confirm).
00868     Array<bool> areNonEvidPreds;
00869     if (nonEvidPredNames.empty())
00870     {
        // No names given: every predicate is non-evidence except the
        // built-in equality and internal predicates.
00871       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
00872       for (int i = 0; i < domains[0]->getNumPredicates(); i++)
00873       {
00874           //prevent equal pred from being non-evidence preds
00875         if (domains[0]->getPredicateTemplate(i)->isEqualPred())
00876         {
00877           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
00878           int predId = domains[0]->getPredicateId(pname);
00879           areNonEvidPreds[predId] = false;
00880         }
00881           //prevent internal preds from being non-evidence preds
00882         if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
00883         {
00884           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
00885           int predId = domains[0]->getPredicateId(pname);
00886           areNonEvidPreds[predId] = false;
00887         }
00888       }
00889     } 
00890     else
00891     {
        // Only the explicitly named predicates are non-evidence; an unknown
        // predicate name is a fatal error.
00892       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
00893       for (int i = 0; i < nonEvidPredNames.size(); i++)
00894       {
00895         int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
00896         if (predId < 0)
00897         {
00898           cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined." 
00899                << endl;
00900           exit(-1);
00901         }
00902         areNonEvidPreds[predId] = true;
00903       }
00904     }
00905 
00906     Array<bool>* nePreds = &areNonEvidPreds;
00907     PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,-1,-1,-1);
00908     pll.setIndexTranslator(indexTrans);
00909 
      // Gaussian prior on the weights unless disabled (-1/NULL => no prior).
00910     if (!noPrior) 
00911       pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
00912                           priorStdDevs.getItems());
00913     else          
00914       pll.setMeansStdDevs(-1, NULL, NULL);
00915     
00917 
      // Compute the per-clause true-grounding counts for every domain; this
      // is the expensive part of PLL learning, so it is timed.
00918     begSec = timer.time();
00919     for (int m = 0; m < mlns.size(); m++)
00920     {
00921       cout << "Computing counts for clauses in domain " << m << "..." << endl;
00922       const ClauseHashArray* clauses = mlns[m]->getClauses();
00923       for (int i = 0; i < clauses->size(); i++)
00924       {
00925         if (PRINT_CLAUSE_DURING_COUNT)
00926         {
00927           cout << "clause " << i << ": ";
00928           (*clauses)[i]->printWithoutWt(cout, domains[m]);
00929           cout << endl; cout.flush();
00930         }
00931         MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
00932         pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m, 
00933                                               NULL, false, NULL);
00934       }
00935     }
00936     pll.compress();
00937     cout <<"Computing counts took ";
00938     Timer::printTime(cout, timer.time() - begSec); cout << endl;
00939     
00941 
00942       // initialize the clause weights
      // wts is 1-based: index 0 is unused and wts[1..numClausesFormulas]
      // holds the weights being optimized.
00943     wts.growToSize(numClausesFormulas + 1);
00944     for (int i = 0; i < numClausesFormulas; i++) wts[i+1] = 0;
00945     //wts[i+1] = priorMeans[i];
00946 
00947       // optimize the clause weights
00948     cout << "L-BFGS-B is finding optimal weights......" << endl;
00949     begSec = timer.time();
00950     LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
00951     int iter;
00952     bool error;
      // minimize() works in place on the weight array; its return value is
      // apparently the minimized negative pseudo-log-likelihood (hence the
      // sign flip when printing below) -- TODO confirm against LBFGSB.
00953     double pllValue = lbfgsb.minimize((double*)wts.getItems(), iter, error);
00954     
00955     if (error) cout << "LBFGSB returned with an error!" << endl;
00956     cout << "num iterations        = " << iter << endl;
00957     cout << "time taken            = ";
00958     Timer::printTime(cout, timer.time() - begSec);
00959     cout << endl;
00960     cout << "pseudo-log-likelihood = " << -pllValue << endl;
00961 
00962   } // else using generative learning
00963 
      // Attach the learned weights to the MLN clauses/formulas and write the
      // result to the output file (with index translation when in use).
00965   if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
00966   else            assignWtsAndOutputMLN(out, mlns, domains, wts);
00967 
00968   out.close();
00969 
      // Final cleanup of the structures shared by both learning branches.
00971   deleteDomains(domains);
00972   for (int i = 0; i < mlns.size(); i++) delete mlns[i];
00973   PowerSet::deletePowerSet();
00974   if (indexTrans) delete indexTrans;
00975 
00976   cout << "Total time = "; 
00977   Timer::printTime(cout, timer.time() - startSec); cout << endl;
00978 }

Generated on Wed Feb 14 15:15:18 2007 for Alchemy by  doxygen 1.5.1