learnwts.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #include <fstream>
00067 #include <iostream>
00068 #include <sstream>
00069 #include "arguments.h"
00070 #include "inferenceargs.h"
00071 #include "lbfgsb.h"
00072 #include "votedperceptron.h"
00073 #include "learnwts.h"
00074 #include "maxwalksat.h"
00075 #include "mcsat.h"
00076 #include "gibbssampler.h"
00077 #include "simulatedtempering.h"
00078 
00079   // Set to false to disable printing of clauses as they are counted during
00080   // generative learning. (Not read in the visible part of this file; presumably consumed elsewhere.)
00081 bool PRINT_CLAUSE_DURING_COUNT = true;
00082 
00083 const double DISC_DEFAULT_STD_DEV = 1;   // priorStdDev fallback in main() when -priorStdDev is unset (< 0) and -d is given
00084 const double GEN_DEFAULT_STD_DEV = 100;  // priorStdDev fallback in main() when -priorStdDev is unset (< 0) and -d is not given
00085 
00086   // Variables for holding inference command line args are in inferenceargs.h.
00087 bool discLearn = false;          // -d: discriminative weight learning
00088 bool genLearn = false;           // -g: generative weight learning
00089 char* outMLNFile = NULL;         // -o: output .mln file for the learned weights (required)
00090 char* dbFiles = NULL;            // -t: comma-separated training .db files (required)
00091 char* nonEvidPredsStr = NULL;    // -ne: comma-separated non-evidence predicate names
00092 bool noAddUnitClauses = false;   // -noAddUnitClauses: do not add unit clauses to the MLN
00093 bool multipleDatabases = false;  // -multipleDatabases: each .db file is a separate database
00094 bool initToZero = false;         // -dZeroInit: start clause weights at zero instead of log odds
00095 bool isQueryEvidence = false;    // -queryEvidence: query groundings missing from the db are false evidence
00096 
00097 bool noPrior = false;            // -noPrior: no Gaussian priors on formula weights
00098 double priorMean = 0;            // -priorMean: prior mean used when the .mln gives no weight
00099 double priorStdDev = -1;         // -priorStdDev: negative means "unset"; main() then picks a per-mode default
00100 
00101   // Generative learning args
00102 int maxIter = 10000;             // -gMaxIter: max L-BFGS-B iterations
00103 double convThresh = 1e-5;        // -gConvThresh: fractional pseudo-log-likelihood change that stops L-BFGS-B
00104 bool noEqualPredWt = false;      // -gNoEqualPredWt: do not weight predicates equally in the pseudo-log-likelihood
00105 
00106   // Discriminative learning args
00107 int numIter = 200;               // -dNumIter: voted perceptron iterations
00108 double learningRate = 0.001;     // -dLearningRate: gradient descent learning rate
00109 double momentum = 0.0;           // -dMomentum: gradient descent momentum term
00110 bool rescaleGradient = false;    // -dRescale: rescale gradient by number of true groundings per weight
00111 bool withEM = false;             // -withEM: use EM to fill in missing truth values
00112 
00113 
00114   // Inference arguments needed for disc. learning defined in inferenceargs.h
00115   // TODO: List the arguments common to learnwts and inference in
00116   // inferenceargs.h. This can't be done with a static array.
00117 ARGS ARGS::Args[] =  // Command-line option table; each entry is ARGS(flag, Req/Opt/Tog, bound variable, help text)
00118 {
00119     // BEGIN: Common arguments
00120   ARGS("i", ARGS::Req, ainMLNFiles, 
00121        "Comma-separated input .mln files. (With the -multipleDatabases "
00122        "option, the second file to the last one are used to contain constants "
00123        "from different databases, and they correspond to the .db files "
00124        "specified with the -t option.)"),
00125 
00126   ARGS("cw", ARGS::Opt, aClosedWorldPredsStr,
00127        "Specified non-evidence atoms (comma-separated with no space) are "
00128        "closed world, otherwise, all non-evidence atoms are open world. Atoms "
00129        "appearing here cannot be query atoms and cannot appear in the -ow "
00130        "option."),
00131 
00132   ARGS("ow", ARGS::Opt, aOpenWorldPredsStr,
00133        "Specified evidence atoms (comma-separated with no space) are open "
00134        "world, while other evidence atoms are closed-world. "
00135        "Atoms appearing here cannot appear in the -cw option."),
00136     // END: Common arguments
00137 
00138     // BEGIN: Common inference arguments
00139   ARGS("m", ARGS::Tog, amapPos, 
00140        "Run MAP inference and return only positive query atoms."),
00141 
00142   ARGS("a", ARGS::Tog, amapAll, 
00143        "Run MAP inference and show 0/1 results for all query atoms."),
00144 
00145   ARGS("p", ARGS::Tog, agibbsInfer, 
00146        "Run inference using MCMC (Gibbs sampling) and return probabilities "
00147        "for all query atoms."),
00148 
00149   ARGS("ms", ARGS::Tog, amcsatInfer,
00150        "Run inference using MC-SAT and return probabilities "
00151        "for all query atoms"),
00152 
00153   ARGS("simtp", ARGS::Tog, asimtpInfer,
00154        "Run inference using simulated tempering and return probabilities "
00155        "for all query atoms"),
00156 
00157   ARGS("seed", ARGS::Opt, aSeed,
00158        "[random] Seed used to initialize the randomizer in the inference "
00159        "algorithm. If not set, seed is initialized from the current date and "
00160        "time."),
00161 
00162   ARGS("lazy", ARGS::Opt, aLazy, 
00163        "[false] Run lazy version of inference if this flag is set."),
00164 
00165   ARGS("lazyNoApprox", ARGS::Opt, aLazyNoApprox, 
00166        "[false] Lazy version of inference will not approximate by deactivating "
00167        "atoms to save memory. This flag is ignored if -lazy is not set."),
00168 
00169   ARGS("memLimit", ARGS::Opt, aMemLimit, 
00170        "[-1] Maximum limit in kbytes which should be used for inference. "
00171        "-1 means main memory available on system is used."),
00172     // END: Common inference arguments
00173 
00174     // BEGIN: MaxWalkSat args
00175   ARGS("mwsMaxSteps", ARGS::Opt, amwsMaxSteps,
00176        "[1000000] (MaxWalkSat) The max number of steps taken."),
00177 
00178   ARGS("mwsTries", ARGS::Opt, amwsTries, 
00179        "[1] (MaxWalkSat) The max number of attempts taken to find a solution."),
00180 
00181   ARGS("mwsTargetWt", ARGS::Opt, amwsTargetWt,
00182        "[the best possible] (MaxWalkSat) MaxWalkSat tries to find a solution "
00183        "with weight <= specified weight."),
00184 
00185   ARGS("mwsHard", ARGS::Opt, amwsHard, 
00186        "[false] (MaxWalkSat) MaxWalkSat never breaks a hard clause in order to "
00187        "satisfy a soft one."),
00188 
00189   ARGS("mwsHeuristic", ARGS::Opt, amwsHeuristic,
00190        "[1] (MaxWalkSat) Heuristic used in MaxWalkSat (0 = RANDOM, 1 = BEST, "
00191        "2 = TABU, 3 = SAMPLESAT)."),
00192 
00193   ARGS("mwsTabuLength", ARGS::Opt, amwsTabuLength,
00194        "[5] (MaxWalkSat) Minimum number of flips between flipping the same "
00195        "atom when using the tabu heuristic in MaxWalkSat." ),
00196 
00197   ARGS("mwsLazyLowState", ARGS::Opt, amwsLazyLowState, 
00198        "[false] (MaxWalkSat) If false, the naive way of saving low states "
00199        "(each time a low state is found, the whole state is saved) is used; "
00200        "otherwise, a list of variables flipped since the last low state is "
00201        "kept and the low state is reconstructed. This can be much faster for "
00202        "very large data sets."),  
00203     // END: MaxWalkSat args
00204 
00205     // BEGIN: MCMC args
00206   ARGS("mcmcBurnMinSteps", ARGS::Opt, amcmcBurnMinSteps,
00207        "[100] (MCMC) Minimum number of burn-in steps (-1: no minimum)."),
00208 
00209   ARGS("mcmcBurnMaxSteps", ARGS::Opt, amcmcBurnMaxSteps,
00210        "[100] (MCMC) Maximum number of burn-in steps (-1: no maximum)."),
00211 
00212   ARGS("mcmcMinSteps", ARGS::Opt, amcmcMinSteps, 
00213        "[-1] (MCMC) Minimum number of Gibbs sampling steps."),
00214 
00215   ARGS("mcmcMaxSteps", ARGS::Opt, amcmcMaxSteps, 
00216        "[1000] (MCMC) Maximum number of Gibbs sampling steps."),
00217 
00218   ARGS("mcmcMaxSeconds", ARGS::Opt, amcmcMaxSeconds, 
00219        "[-1] (MCMC) Max number of seconds to run MCMC (-1: no maximum)."),
00220     // END: MCMC args
00221 
00222     // BEGIN: Simulated tempering args
00223   ARGS("simtpSubInterval", ARGS::Opt, asimtpSubInterval,
00224         "[2] (Simulated Tempering) Selection interval between swap attempts"),
00225 
00226   ARGS("simtpNumRuns", ARGS::Opt, asimtpNumST,
00227         "[3] (Simulated Tempering) Number of simulated tempering runs"),
00228 
00229   ARGS("simtpNumSwap", ARGS::Opt, asimtpNumSwap,
00230         "[10] (Simulated Tempering) Number of swapping chains"),
00231     // END: Simulated tempering args
00232 
00233     // BEGIN: MC-SAT args
00234   ARGS("mcsatNumStepsEveryMCSat", ARGS::Opt, amcsatNumStepsEveryMCSat,
00235        "[1] (MC-SAT) Number of total steps (mcsat & gibbs) for every mcsat "
00236        "step"),
00237     // END: MC-SAT args
00238 
00239     // BEGIN: SampleSat args
00240   ARGS("mcsatNumSolutions", ARGS::Opt, amwsNumSolutions,
00241        "[10] Return nth SAT solution in SampleSat"),
00242 
00243   ARGS("mcsatSaRatio", ARGS::Opt, assSaRatio,
00244        "[50] (MC-SAT) Ratio of sim. annealing steps mixed with WalkSAT in "
00245        "MC-SAT"),
00246 
00247   ARGS("mcsatSaTemperature", ARGS::Opt, assSaTemp,
00248         "[10] Temperature (/100) for sim. annealing step in SampleSat"),
00249 
00250   ARGS("mcsatLateSa", ARGS::Tog, assLateSa,
00251        "[false] Run simulated annealing from the start in SampleSat"),
00252     // END: SampleSat args
00253 
00254     // BEGIN: Gibbs sampling args
00255   ARGS("gibbsNumChains", ARGS::Opt, amcmcNumChains, 
00256        "[10] (Gibbs) Number of MCMC chains for Gibbs sampling (there must be "
00257        "at least 2)."),
00258 
00259   ARGS("gibbsDelta", ARGS::Opt, agibbsDelta,
00260        "[0.05] (Gibbs) During Gibbs sampling, probability that epsilon error is "
00261        "exceeded is less than this value."),
00262 
00263   ARGS("gibbsEpsilonError", ARGS::Opt, agibbsEpsilonError,
00264        "[0.01] (Gibbs) Fractional error from true probability."),
00265 
00266   ARGS("gibbsFracConverged", ARGS::Opt, agibbsFracConverged, 
00267        "[0.95] (Gibbs) Fraction of ground atoms with probabilities that "
00268        "have converged."),
00269 
00270   ARGS("gibbsWalksatType", ARGS::Opt, agibbsWalksatType, 
00271        "[1] (Gibbs) Use Max Walksat to initialize ground atoms' truth values "
00272        "in Gibbs sampling (1: use Max Walksat, 0: random initialization)."),
00273 
00274   ARGS("gibbsSamplesPerTest", ARGS::Opt, agibbsSamplesPerTest, 
00275        "[100] Perform convergence test once after this many number of samples "
00276        "per chain."),
00277     // END: Gibbs sampling args
00278 
00279     // BEGIN: Weight learning specific args
00280   ARGS("d", ARGS::Tog, discLearn, "Discriminative weight learning."),
00281 
00282   ARGS("g", ARGS::Tog, genLearn, "Generative weight learning."),
00283 
00284   ARGS("o", ARGS::Req, outMLNFile, 
00285        "Output .mln file containing formulas with learned weights."),
00286 
00287   ARGS("t", ARGS::Req, dbFiles, 
00288        "Comma-separated .db files containing the training database "
00289        "(of true/false ground atoms), including function definitions, "
00290        "e.g. ai.db,graphics.db,languages.db."),
00291 
00292   ARGS("ne", ARGS::Opt, nonEvidPredsStr, 
00293        "First-order non-evidence predicates (comma-separated with no space),  "
00294        "e.g., cancer,smokes,friends. For discriminative learning, at least "
00295        "one non-evidence predicate must be specified. For generative learning, "
00296        "the specified predicates are included in the (weighted) pseudo-log-"
00297        "likelihood computation; if none are specified, all are included."),
00298 
00299   ARGS("noAddUnitClauses", ARGS::Tog, noAddUnitClauses,
00300        "If specified, unit clauses are not included in the .mln file; "
00301        "otherwise they are included."),
00302 
00303   ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
00304        "If specified, each .db file belongs to a separate database; "
00305        "otherwise all .db files belong to the same database."),
00306 
00307   ARGS("withEM", ARGS::Tog, withEM,
00308        "If set, EM is used to fill in missing truth values; "
00309        "otherwise missing truth values are set to false."),
00310 
00311   ARGS("dNumIter", ARGS::Opt, numIter, 
00312        "[200] (For discriminative learning only.) "
00313        "Number of iterations to run voted perceptron."),
00314 
00315   ARGS("dLearningRate", ARGS::Opt, learningRate, 
00316        "[0.001] (For discriminative learning only) "
00317        "Learning rate for the gradient descent in voted perceptron algorithm."),
00318 
00319   ARGS("dMomentum", ARGS::Opt, momentum, 
00320        "[0.0] (For discriminative learning only) "
00321        "Momentum term for the gradient descent in voted perceptron algorithm."),
00322 
00323   ARGS("queryEvidence", ARGS::Tog, isQueryEvidence, 
00324        "If this flag is set, then all the groundings of query preds not in db "
00325        "are assumed false evidence."),
00326 
00327   ARGS("dRescale", ARGS::Tog, rescaleGradient, 
00328        "(For discriminative learning only.) "
00329        "Rescale the gradient by the number of true groundings per weight."),
00330 
00331   ARGS("dZeroInit", ARGS::Tog, initToZero,
00332        "(For discriminative learning only.) "
00333        "Initialize clause weights to zero instead of their log odds."),
00334 
00335   ARGS("gMaxIter", ARGS::Opt, maxIter, 
00336        "[10000] (For generative learning only.) "
00337        "Max number of iterations to run L-BFGS-B, "
00338        "the optimization algorithm for generative learning."),
00339 
00340   ARGS("gConvThresh", ARGS::Opt, convThresh, 
00341        "[1e-5] (For generative learning only.) "
00342        "Fractional change in pseudo-log-likelihood at which "
00343        "L-BFGS-B terminates."),
00344 
00345   ARGS("gNoEqualPredWt", ARGS::Opt, noEqualPredWt,  // NOTE(review): help says "If specified"; confirm Opt (not Tog) is intended
00346        "(For generative learning only.) "
00347        "If specified, the predicates are not weighted equally in the "
00348        "pseudo-log-likelihood computation; otherwise they are."),
00349 
00350   ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),
00351 
00352   ARGS("priorMean", ARGS::Opt, priorMean, 
00353        "[0] Means of Gaussian priors on formula weights. By default, "
00354        "for each formula, it is the weight given in the .mln input file, " 
00355        "or fraction thereof if the formula turns into multiple clauses. "
00356        "This mean applies if no weight is given in the .mln file."),
00357 
00358   ARGS("priorStdDev", ARGS::Opt, priorStdDev, 
00359        "[1 for discriminative learning. 100 for generative learning] "
00360        "Standard deviations of Gaussian priors on clause weights."),
00361 
00362   ARGS()  // empty entry closes the table (presumably the end marker ARGS::parse looks for -- confirm in arguments.h)
00363 };
00364 
00365 //bool extractPredNames(...) defined in infer.h
00366 
00367 int main(int argc, char* argv[])
00368 {
00369   ARGS::parse(argc,argv,&cout);
00370 
00371   if (!discLearn && !genLearn) 
00372   { 
00373     cout << "must specify either -d or -g "
00374          <<"(discriminative or generative learning) " << endl; 
00375     return -1; 
00376   }
00377 
00378   Timer timer;
00379   double startSec = timer.time();
00380   double begSec;
00381 
00382   if (priorStdDev < 0)
00383   {
00384     if (discLearn) 
00385     { 
00386       cout << "priorStdDev set to (discriminative learning's) default of " 
00387            << DISC_DEFAULT_STD_DEV << endl;
00388       priorStdDev = DISC_DEFAULT_STD_DEV;
00389     }
00390     else
00391     {
00392       cout << "priorStdDev set to (generative learning's) default of " 
00393            << GEN_DEFAULT_STD_DEV << endl;
00394       priorStdDev = GEN_DEFAULT_STD_DEV;      
00395     }
00396   }
00397 
00398 
00400   if (discLearn && nonEvidPredsStr == NULL)
00401   {
00402     cout << "ERROR: you must specify non-evidence predicates for "
00403          << "discriminative learning" << endl;
00404     return -1;
00405   }
00406 
00407   if (maxIter <= 0)  { cout << "maxIter must be > 0" << endl; return -1; }
00408   if (convThresh <= 0 || convThresh > 1)  
00409   { cout << "convThresh must be > 0 and <= 1" << endl; return -1;  }
00410   if (priorStdDev <= 0) { cout << "priorStdDev must be > 0" << endl; return -1;}
00411 
00412   if (amwsMaxSteps <= 0)
00413   { cout << "ERROR: mwsMaxSteps must be positive" << endl; return -1; }
00414 
00415   if (amwsTries <= 0)
00416   { cout << "ERROR: mwsTries must be positive" << endl; return -1; }
00417 
00418   if (aMemLimit <= 0 && aMemLimit != -1)
00419   { cout << "ERROR: mwsLimit must be positive (or -1)" << endl; return -1; }
00420 
00421   if (!discLearn && aLazy)
00422   {
00423     cout << "ERROR: lazy can only be used with discriminative learning"
00424          << endl;
00425     return -1;
00426   }
00427 
00428   if (!discLearn && withEM)
00429   { 
00430     cout << "ERROR: EM can only be used with discriminative learning" << endl;
00431     return -1; 
00432   }
00433 
00434   ofstream out(outMLNFile);
00435   if (!out.good())
00436   {
00437     cout << "ERROR: unable to open " << outMLNFile << endl;
00438     return -1;
00439   }
00440 
00441     // If no inference method indicated, then use MAP
00442   if (discLearn && !asimtpInfer && !amapPos && !amapAll && !agibbsInfer &&
00443       !amcsatInfer)
00444   {
00445     amapPos = true;
00446   }
00447 
00448 
00449   //the second .mln file to the last one in ainMLNFiles _may_ be used 
00450   //to hold constants, so they are held in constFilesArr. They will be
00451   //included into the first .mln file.
00452 
00453     //extract .mln and .db, file names
00454   Array<string> constFilesArr;
00455   Array<string> dbFilesArr;
00456   extractFileNames(ainMLNFiles, constFilesArr);
00457   assert(constFilesArr.size() >= 1);
00458   string inMLNFile = constFilesArr[0];
00459   constFilesArr.removeItem(0);
00460   extractFileNames(dbFiles, dbFilesArr);
00461 
00462   if (dbFilesArr.size() <= 0)
00463   {cout<<"ERROR: must specify training data with -t option."<<endl; return -1;}
00464  
00465     // if multiple databases, check the number of .db/.func files
00466   if (multipleDatabases) 
00467   {
00468       //if # .mln files containing constants/.func files and .db files are diff
00469     if ((constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size()))
00470     {
00471       cout << "ERROR: when there are multiple databases, if .mln files "
00472            << "containing constants are specified, there must " 
00473            << "be the same number of them as .db files" << endl;
00474       return -1;
00475     }
00476   }
00477 
00478   StringHashArray nonEvidPredNames;
00479   if (nonEvidPredsStr)
00480   {
00481     if(!extractPredNames(nonEvidPredsStr, NULL, nonEvidPredNames))
00482     {
00483       cout << "ERROR: failed to extract non-evidence predicate names." << endl;
00484       return -1;
00485     }
00486   }
00487 
00488   StringHashArray owPredNames;
00489   StringHashArray cwPredNames;
00490 
00492 
00493   cout << "Parsing MLN and creating domains..." << endl;
00494   StringHashArray* nePredNames = (discLearn) ? &nonEvidPredNames : NULL;
00495   Array<Domain*> domains;
00496   Array<MLN*> mlns;
00497   begSec = timer.time();
00498   bool allPredsExceptQueriesAreCW = true;
00499   if (discLearn)
00500   {
00501       //extract names of open-world evidence predicates
00502     if (aOpenWorldPredsStr)
00503     {
00504       if (!extractPredNames(string(aOpenWorldPredsStr), NULL, owPredNames)) 
00505         return -1;
00506       assert(owPredNames.size() > 0);
00507     }
00508 
00509       //extract names of closed-world non-evidence predicates
00510     if (aClosedWorldPredsStr)
00511     {
00512       if (!extractPredNames(string(aClosedWorldPredsStr), NULL, cwPredNames)) 
00513         return -1;
00514       assert(cwPredNames.size() > 0);
00515       if (!checkQueryPredsNotInClosedWorldPreds(nonEvidPredNames, cwPredNames))
00516         return -1;
00517     }
00518  
00519     allPredsExceptQueriesAreCW = owPredNames.empty();
00520   }
00521     // Parse as if lazy inference is set to true to set evidence atoms in DB
00522     // If lazy is not used, this is removed from DB
00523   createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile, 
00524                        constFilesArr, dbFilesArr, nePredNames,
00525                        !noAddUnitClauses, priorMean, true,
00526                        allPredsExceptQueriesAreCW, &owPredNames);
00527   cout << "Parsing MLN and creating domains took "; 
00528   Timer::printTime(cout, timer.time() - begSec); cout << endl;
00529 
00530   /*
00531   cout << "Clause prior means:" << endl;
00532   cout << "_________________________________" << endl;
00533   mlns[0]->printClausePriorMeans(cout, domains[0]);
00534   cout << "_________________________________" << endl;
00535   cout << endl;
00536 
00537   cout << "Formula prior means:" << endl;
00538   cout << "_________________________________" << endl;
00539   mlns[0]->printFormulaPriorMeans(cout);
00540   cout << "_________________________________" << endl;
00541   cout << endl;
00542   */
00543 
00545 
00546     //we need an index translator if clauses do not line up across multiple DBs
00547   IndexTranslator* indexTrans 
00548     = (IndexTranslator::needIndexTranslator(mlns, domains)) ?
00549        new IndexTranslator(&mlns, &domains) : NULL;  
00550 
00551   if (indexTrans) 
00552     cout << endl << "the weights of clauses in the CNFs of existential"
00553          << " formulas will be tied" << endl;
00554 
00555 
00556   Array<double> priorMeans, priorStdDevs;
00557   if (!noPrior)
00558   {
00559     if (indexTrans)
00560     {
00561       indexTrans->setPriorMeans(priorMeans);
00562       priorStdDevs.growToSize(priorMeans.size());
00563       for (int i = 0; i < priorMeans.size(); i++)
00564         priorStdDevs[i] = priorStdDev;
00565     }
00566     else
00567     {
00568       const ClauseHashArray* clauses = mlns[0]->getClauses();
00569       int numClauses = clauses->size();
00570       for (int i = 0; i < numClauses; i++)
00571       {
00572         priorMeans.append((*clauses)[i]->getWt());
00573         priorStdDevs.append(priorStdDev);
00574       }
00575     }
00576   }
00577   // HACK -- not sure if this is right... but the old version was broke!
00578   // This may fail when there's an indexTrans.  [Daniel]
00579   //int numClausesFormulas = priorMeans.size();
00580   int numClausesFormulas = mlns[0]->getClauses()->size();
00581 
00582 
00584   Array<double> wts;
00585 
00586     // Discriminative learning
00587   if (discLearn) 
00588   {
00589     wts.growToSize(numClausesFormulas + 1);
00590     double* wwts = (double*) wts.getItems();
00591     wwts++;
00592       // Non-evid preds as a string
00593     string nePredsStr = nonEvidPredsStr;
00594 
00595       // Set SampleSat parameters
00596     SampleSatParams* ssparams = new SampleSatParams;
00597     ssparams->lateSa = assLateSa;
00598     ssparams->saRatio = assSaRatio;
00599     ssparams->saTemp = assSaTemp;
00600 
00601       // Set MaxWalksat parameters
00602     MaxWalksatParams* mwsparams = NULL;
00603     mwsparams = new MaxWalksatParams;
00604     mwsparams->ssParams = ssparams;
00605     mwsparams->maxSteps = amwsMaxSteps;
00606     mwsparams->maxTries = amwsTries;
00607     mwsparams->targetCost = amwsTargetWt;
00608     mwsparams->hard = amwsHard;
00609       // numSolutions only applies when used in SampleSat.
00610       // When just MWS, this is set to 1
00611     mwsparams->numSolutions = amwsNumSolutions;
00612     mwsparams->heuristic = amwsHeuristic;
00613     mwsparams->tabuLength = amwsTabuLength;
00614     mwsparams->lazyLowState = amwsLazyLowState;
00615 
00616       // Set MC-SAT parameters
00617     MCSatParams* msparams = new MCSatParams;
00618     msparams->mwsParams = mwsparams;
00619       // MC-SAT needs only one chain
00620     msparams->numChains          = 1;
00621     msparams->burnMinSteps       = amcmcBurnMinSteps;
00622     msparams->burnMaxSteps       = amcmcBurnMaxSteps;
00623     msparams->minSteps           = amcmcMinSteps;
00624     msparams->maxSteps           = amcmcMaxSteps;
00625     msparams->maxSeconds         = amcmcMaxSeconds;
00626     msparams->numStepsEveryMCSat = amcsatNumStepsEveryMCSat;
00627 
00628       // Set Gibbs parameters
00629     GibbsParams* gibbsparams = new GibbsParams;
00630     gibbsparams->mwsParams    = mwsparams;
00631     gibbsparams->numChains    = amcmcNumChains;
00632     gibbsparams->burnMinSteps = amcmcBurnMinSteps;
00633     gibbsparams->burnMaxSteps = amcmcBurnMaxSteps;
00634     gibbsparams->minSteps     = amcmcMinSteps;
00635     gibbsparams->maxSteps     = amcmcMaxSteps;
00636     gibbsparams->maxSeconds   = amcmcMaxSeconds;
00637 
00638     gibbsparams->gamma          = 1 - agibbsDelta;
00639     gibbsparams->epsilonError   = agibbsEpsilonError;
00640     gibbsparams->fracConverged  = agibbsFracConverged;
00641     gibbsparams->walksatType    = agibbsWalksatType;
00642     gibbsparams->samplesPerTest = agibbsSamplesPerTest;
00643   
00644       // Set Sim. Tempering parameters
00645     SimulatedTemperingParams* stparams = new SimulatedTemperingParams;
00646     stparams->mwsParams    = mwsparams;
00647     stparams->numChains    = amcmcNumChains;
00648     stparams->burnMinSteps = amcmcBurnMinSteps;
00649     stparams->burnMaxSteps = amcmcBurnMaxSteps;
00650     stparams->minSteps     = amcmcMinSteps;
00651     stparams->maxSteps     = amcmcMaxSteps;
00652     stparams->maxSeconds   = amcmcMaxSeconds;
00653 
00654     stparams->subInterval = asimtpSubInterval;
00655     stparams->numST       = asimtpNumST;
00656     stparams->numSwap     = asimtpNumSwap;
00657 
00658     Array<VariableState*> states;
00659     Array<Inference*> inferences;
00660 
00661     states.growToSize(domains.size());
00662     inferences.growToSize(domains.size());
00663 
00664       // Build the state for inference in each domain
00665     Array<int> allPredGndingsAreNonEvid;
00666     Array<Predicate*> ppreds;
00667     
00668     for (int i = 0; i < domains.size(); i++)
00669     {
00670       Domain* domain = domains[i];
00671       MLN* mln = mlns[i];
00672 
00673         // Remove evidence atoms structure from DBs
00674       if (!aLazy)
00675         domains[i]->getDB()->setLazyFlag(false);
00676     
00677         // Unknown non-ev. preds
00678       GroundPredicateHashArray* unePreds = NULL;
00679 
00680         // Known non-ev. preds
00681       GroundPredicateHashArray* knePreds = NULL;
00682       Array<TruthValue>* knePredValues = NULL;
00683 
00684         // Need to set some dummy weight
00685       for (int j = 0; j < mln->getNumClauses(); j++)
00686         ((Clause*) mln->getClause(j))->setWt(1);
00687 
00688         // Find all evidence predicates with unknown value
00689         // and make these non-evidence
00690       int numPreds = domain->getNumPredicates();
00691       for (int j = 0; j < numPreds; j++)
00692       {
00693         const char* predName = domain->getPredicateName(j);
00694           // Only look at evidence (but not '=' predicate)
00695         if (nonEvidPredNames.contains(predName) ||
00696             domain->getPredicateTemplate(j)->isEqualPredicateTemplate())
00697           continue;
00698           
00699         bool unknownPred = false;
00700         ppreds.clear();
00701         Predicate::createAllGroundings(j, domain, ppreds);
00702         for (int k = 0; k < ppreds.size(); k++)
00703         {
00704           TruthValue tv = domain->getDB()->getValue(ppreds[k]);
00705           if (tv == UNKNOWN)
00706             unknownPred = true;
00707           delete ppreds[k];
00708         }
00709         if (unknownPred)
00710         {
00711           nePredsStr.append(",");
00712           nePredsStr.append(predName);
00713           nonEvidPredNames.append(predName);
00714         }
00715       }
00716 
00717         // Eager version: Build query preds from command line and set known
00718         // non-evidence to unknown for building the states
00719       if (!aLazy)
00720       {
00721         unePreds = new GroundPredicateHashArray;
00722         knePreds = new GroundPredicateHashArray;
00723         knePredValues = new Array<TruthValue>;
00724 
00725         allPredGndingsAreNonEvid.growToSize(domain->getNumPredicates(), false);
00726           //defined in infer.h
00727         createComLineQueryPreds(nePredsStr, domain, domain->getDB(), 
00728                                 unePreds, knePreds, 
00729                                 &allPredGndingsAreNonEvid);
00730 
00731           // Pred values not set to unknown in DB: unePreds contains
00732           // unknown, knePreds contains known non-evidence
00733 
00734           // Set known NE to unknown for building state
00735         knePredValues->growToSize(knePreds->size(), FALSE);
00736         for (int predno = 0; predno < knePreds->size(); predno++) 
00737           (*knePredValues)[predno] =
00738             domain->getDB()->setValue((*knePreds)[predno], UNKNOWN);
00739 
00740           // If first order query pred groundings are allowed to be evidence
00741           // - we assume all the predicates not in db to be false
00742           // evidence - need a better way code this.
00743         if (isQueryEvidence)
00744             // Set unknown NE to false
00745           for (int predno = 0; predno < unePreds->size(); predno++) 
00746             domain->getDB()->setValue((*unePreds)[predno], FALSE);
00747       }
00748       
00749         // Create state for inferred counts using unknown and known (set to
00750         // unknown in the DB) non-evidence preds
00751       cout << endl << "constructing state for domain " << i << "..." << endl;
00752       bool markHardGndClauses = false;
00753       bool trackParentClauseWts = true;
00754       VariableState*& state = states[i];
00755       state = new VariableState(unePreds, knePreds, knePredValues,
00756                                 &allPredGndingsAreNonEvid, markHardGndClauses,
00757                                 trackParentClauseWts, mln, domain, aLazy);
00758 
00759       Inference*& inference = inferences[i];
00760       bool trackClauseTrueCnts = true;
00761         // Different inference algorithms
00762       if (amapPos || amapAll)
00763       { // MaxWalkSat
00764           // When standalone MWS, numSolutions is always 1
00765           // (maybe there is a better way to do this?)
00766         mwsparams->numSolutions = 1;
00767         inference = new MaxWalkSat(state, aSeed, trackClauseTrueCnts,
00768                                    mwsparams);
00769       }
00770       else if (amcsatInfer)
00771       { // MC-SAT
00772         inference = new MCSAT(state, aSeed, trackClauseTrueCnts, msparams);
00773       }
00774       else if (asimtpInfer)
00775       { // Simulated Tempering
00776           // When MWS is used in Sim. Temp., numSolutions is always 1
00777           // (maybe there is a better way to do this?)
00778         mwsparams->numSolutions = 1;
00779         inference = new SimulatedTempering(state, aSeed, trackClauseTrueCnts,
00780                                            stparams);
00781       }
00782       else if (agibbsInfer)
00783       { // Gibbs sampling
00784           // When MWS is used in Gibbs, numSolutions is always 1
00785           // (maybe there is a better way to do this?)
00786         mwsparams->numSolutions = 1;
00787         inference = new GibbsSampler(state, aSeed, trackClauseTrueCnts,
00788                                      gibbsparams);
00789       }
00790 
00791       if (!aLazy)
00792       {
00793           // Change known NE to original values
00794         domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00795           // Set unknown NE to false for weight initialization. This seems to
00796           // give poor results when using EM. We need to leave these
00797           // as unknown and do the counts accordingly
00798         for (int predno = 0; predno < unePreds->size(); predno++)
00799         {
00800           domain->getDB()->setValue((*unePreds)[predno], FALSE);
00801         }
00802       }
00803     }
00804     cout << endl << "done constructing variable states" << endl << endl;
00805     
00806     VotedPerceptron vp(inferences, nonEvidPredNames, indexTrans, aLazy,
00807                        rescaleGradient, withEM);
00808     if (!noPrior) 
00809       vp.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
00810                          priorStdDevs.getItems());
00811     else
00812       vp.setMeansStdDevs(-1, NULL, NULL);
00813          
00814     begSec = timer.time();
00815     cout << "learning (discriminative) weights .. " << endl;
00816     vp.learnWeights(wwts, wts.size()-1, numIter, learningRate, momentum,
00817                     !initToZero);
00818     cout << endl << endl << "Done learning discriminative weights. "<< endl;
00819     cout << "Time Taken for learning = ";
00820     Timer::printTime(cout, (timer.time() - begSec)); cout << endl;
00821 
00822     if (mwsparams) delete mwsparams;
00823     if (ssparams) delete ssparams;
00824     if (msparams) delete msparams;
00825     if (gibbsparams) delete gibbsparams;
00826     if (stparams) delete stparams;
00827     for (int i = 0; i < inferences.size(); i++)  delete inferences[i];
00828     for (int i = 0; i < states.size(); i++)  delete states[i];
00829   } 
00830   else
00831   {   
00833 
00834     Array<bool> areNonEvidPreds;
00835     if (nonEvidPredNames.empty())
00836     {
00837       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), true);
00838       for (int i = 0; i < domains[0]->getNumPredicates(); i++)
00839       {
00840           //prevent equal pred from being non-evidence preds
00841         if (domains[0]->getPredicateTemplate(i)->isEqualPred())
00842         {
00843           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
00844           int predId = domains[0]->getPredicateId(pname);
00845           areNonEvidPreds[predId] = false;
00846         }
00847           //prevent internal preds from being non-evidence preds
00848         if (domains[0]->getPredicateTemplate(i)->isInternalPredicateTemplate())
00849         {
00850           const char* pname = domains[0]->getPredicateTemplate(i)->getName();
00851           int predId = domains[0]->getPredicateId(pname);
00852           areNonEvidPreds[predId] = false;
00853         }
00854       }
00855     } 
00856     else
00857     {
00858       areNonEvidPreds.growToSize(domains[0]->getNumPredicates(), false);
00859       for (int i = 0; i < nonEvidPredNames.size(); i++)
00860       {
00861         int predId = domains[0]->getPredicateId(nonEvidPredNames[i].c_str());
00862         if (predId < 0)
00863         {
00864           cout << "ERROR: Predicate " << nonEvidPredNames[i] << " undefined." 
00865                << endl;
00866           exit(-1);
00867         }
00868         areNonEvidPreds[predId] = true;
00869       }
00870     }
00871 
00872     Array<bool>* nePreds = &areNonEvidPreds;
00873     PseudoLogLikelihood pll(nePreds, &domains, !noEqualPredWt, false,-1,-1,-1);
00874     pll.setIndexTranslator(indexTrans);
00875 
00876     if (!noPrior) 
00877       pll.setMeansStdDevs(numClausesFormulas, priorMeans.getItems(),
00878                           priorStdDevs.getItems());
00879     else          
00880       pll.setMeansStdDevs(-1, NULL, NULL);
00881     
00883 
00884     begSec = timer.time();
00885     for (int m = 0; m < mlns.size(); m++)
00886     {
00887       cout << "Computing counts for clauses in domain " << m << "..." << endl;
00888       const ClauseHashArray* clauses = mlns[m]->getClauses();
00889       for (int i = 0; i < clauses->size(); i++)
00890       {
00891         if (PRINT_CLAUSE_DURING_COUNT)
00892         {
00893           cout << "clause " << i << ": ";
00894           (*clauses)[i]->printWithoutWt(cout, domains[m]);
00895           cout << endl; cout.flush();
00896         }
00897         MLNClauseInfo* ci = (MLNClauseInfo*) mlns[m]->getMLNClauseInfo(i);
00898         pll.computeCountsForNewAppendedClause((*clauses)[i], &(ci->index), m, 
00899                                               NULL, false, NULL);
00900       }
00901     }
00902     pll.compress();
00903     cout <<"Computing counts took ";
00904     Timer::printTime(cout, timer.time() - begSec); cout << endl;
00905     
00907 
00908       // initialize the clause weights
00909     wts.growToSize(numClausesFormulas + 1);
00910     for (int i = 0; i < numClausesFormulas; i++) wts[i+1] = 0;
00911     //wts[i+1] = priorMeans[i];
00912 
00913       // optimize the clause weights
00914     cout << "L-BFGS-B is finding optimal weights......" << endl;
00915     begSec = timer.time();
00916     LBFGSB lbfgsb(maxIter, convThresh, &pll, numClausesFormulas);
00917     int iter;
00918     bool error;
00919     double pllValue = lbfgsb.minimize((double*)wts.getItems(), iter, error);
00920     
00921     if (error) cout << "LBFGSB returned with an error!" << endl;
00922     cout << "num iterations        = " << iter << endl;
00923     cout << "time taken            = ";
00924     Timer::printTime(cout, timer.time() - begSec);
00925     cout << endl;
00926     cout << "pseudo-log-likelihood = " << -pllValue << endl;
00927 
00928   } // else using generative learning
00929 
00931   if (indexTrans) assignWtsAndOutputMLN(out, mlns, domains, wts, indexTrans);
00932   else            assignWtsAndOutputMLN(out, mlns, domains, wts);
00933 
00934   out.close();
00935 
00937   deleteDomains(domains);
00938   for (int i = 0; i < mlns.size(); i++) delete mlns[i];
00939   PowerSet::deletePowerSet();
00940   if (indexTrans) delete indexTrans;
00941 
00942   cout << "Total time = "; 
00943   Timer::printTime(cout, timer.time() - startSec); cout << endl;
00944 }

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by doxygen 1.5.1