learnstruct.cpp

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, and Daniel Lowd.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, and Daniel Lowd. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, and Daniel Lowd in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #include <fstream>
00067 #include <iostream>
00068 #include "arguments.h"
00069 #include "fol.h"
00070 #include "learnwts.h"
00071 //#include "infer.h"
00072 #include "structlearn.h"
00073 
00074 
// Program settings. ARGS::parse fills these in from the options declared in
// ARGS::Args below; the initial values here match the bracketed defaults
// shown in the option help text. structGradDescent and withEM currently have
// no command-line option (their ARGS entries are commented out).

  // --- input / output files ---
char* inMLNFiles      = NULL;
char* outMLNFile      = NULL;
char* dbFiles         = NULL;
char* nonEvidPredsStr = NULL;
bool  multipleDatabases = false;

  // --- beam search / clause shape ---
int    beamSize         = 5;
double minWt            = 0.01;
double penalty          = 0.01;
int    maxVars          = 6;
int    maxNumPredicates = 6;
int    cacheSize        = 500;   // megabytes; 0 disables the clause cache

  // --- sampling of clause groundings ---
bool   noSampleClauses  = false;
double ddelta           = 0.05;
double epsilon          = 0.2;
int    minClauseSamples = -1;    // -1: no minimum
int    maxClauseSamples = -1;    // -1: no maximum

  // --- sampling of ground atoms ---
bool   noSampleGndPreds  = false;
double fraction          = 0.8;
int    minGndPredSamples = -1;   // -1: no minimum
int    maxGndPredSamples = -1;   // -1: no maximum

  // --- Gaussian prior on weights ---
bool   noPrior     = false;
double priorMean   = 0.0;
double priorStdDev = 100.0;

  // --- L-BFGS-B settings (tight = final fit, loose = candidate scoring) ---
int    lbMaxIter       = 10000;
double lbConvThresh    = 0.00001;
int    looseMaxIter    = 10;
double looseConvThresh = 0.001;

  // --- search control ---
int  numEstBestClauses      = 10;
bool noWtPredsEqually       = false;
bool startFromEmptyMLN      = false;
bool tryAllFlips            = false;
int  bestGainUnchangedLimit = 2;

  // --- alternative algorithms (no command-line option at present) ---
bool structGradDescent = false;
bool withEM            = false;
00113 
00114 bool structGradDescent = false;
00115 bool withEM = false;
00116 
00117 ARGS ARGS::Args[] = 
00118 {
00119   ARGS("i", ARGS::Req, inMLNFiles, 
00120        "Comma-separated input .mln files. (With the -multipleDatabases "
00121        "option, the second file to the last one are used to contain constants "
00122        "from different domains, and they correspond to the .db files specified "
00123        "with the -t option.)"),
00124 
00125   ARGS("o", ARGS::Req, outMLNFile, 
00126        "Output .mln file containing learned formulas and weights."),
00127 
00128   ARGS("t", ARGS::Req, dbFiles, 
00129        "Comma-separated .db files containing the training database "
00130        "(of true/false ground atoms), including function definitions, "
00131        "e.g. ai.db,graphics.db,languages.db."),
00132   
00133   ARGS("ne", ARGS::Opt, nonEvidPredsStr, 
00134        "[all predicates] Non-evidence predicates "
00135        "(comma-separated with no space), e.g., cancer,smokes,friends."),
00136     
00137   ARGS("multipleDatabases", ARGS::Tog, multipleDatabases,
00138        "If specified, each .db file belongs to a separate domain; "
00139        "otherwise all .db files belong to the same domain."),
00140 
00141   ARGS("beamSize", ARGS::Opt, beamSize, "[5] Size of beam in beam search."),
00142 
00143   ARGS("minWt", ARGS::Opt, minWt, 
00144        "[0.01] Candidate clauses are discarded if "
00145        "their absolute weights fall below this."),
00146 
00147   ARGS("penalty", ARGS::Opt, penalty, 
00148        "[0.01] Each difference between the current and previous version of a "
00149        "candidate clause penalizes the (weighted) pseudo-log-likelihood "
00150        "by this amount."),
00151 
00152   ARGS("maxVars", ARGS::Opt, maxVars, 
00153        "[6] Maximum number of variables in learned clauses."),
00154   
00155   ARGS("maxNumPredicates", ARGS::Opt, maxNumPredicates, 
00156        "[6] Maximum number of predicates in learned clauses."),
00157   
00158   ARGS("cacheSize", ARGS::Opt, cacheSize, 
00159        "[500] Size in megabytes of the cache that is used to store the clauses "
00160        "(and their counts) that are created during structure learning."),
00161 
00162   ARGS("noSampleClauses", ARGS::Tog, noSampleClauses, 
00163        "If specified, compute a clause's number of true groundings exactly, "
00164        "and do not estimate it by sampling its groundings. If not specified, "
00165        "estimate the number by sampling."),
00166 
00167   ARGS("delta", ARGS::Opt, ddelta, 
00168        "[0.05] (Used only if sampling clauses.) "
00169        "The probability that an estimate a clause's number of true groundings "
00170        "is off by more than epsilon error is less than this value. "
00171        "Used to determine the number of samples of the clause's groundings "
00172        "to draw."),
00173 
00174   ARGS("epsilon", ARGS::Opt, epsilon,
00175        "[0.2] (Used only if sampling clauses.) "
00176        "Fractional error from a clause's actual number of true groundings. "
00177        "Used to determine the number of samples of the clause's groundings "
00178        "to draw."),
00179 
00180   ARGS("minClauseSamples", ARGS::Opt, minClauseSamples,
00181        "[-1] (Used only if sampling clauses.) "
00182        "Minimum number of samples of a clause's groundings to draw. "
00183        "(-1: no minimum)"),
00184 
00185   ARGS("maxClauseSamples", ARGS::Opt, maxClauseSamples,
00186        "[-1] (Used only if sampling clauses.) "
00187        "Maximum number of samples of a clause's groundings to draw. "
00188        "(-1: no maximum)"),
00189 
00190   ARGS("noSampleAtoms", ARGS::Tog, noSampleGndPreds, 
00191        "If specified, do not estimate the (weighted) pseudo-log-likelihood by "
00192        "sampling ground atoms; otherwise, estimate the value by sampling."),
00193 
00194   ARGS("fractAtoms", ARGS::Opt, fraction,
00195        "[0.8] (Used only if sampling ground atoms.) "
00196        "Fraction of each predicate's ground atoms to draw."),
00197 
00198   ARGS("minAtomSamples", ARGS::Opt, minGndPredSamples,
00199        "[-1] (Used only if sampling ground atoms.) "
00200        "Minimum number of each predicate's ground atoms to draw. "
00201        "(-1: no minimum)"),
00202 
00203   ARGS("maxAtomSamples", ARGS::Opt, maxGndPredSamples,
00204        "[-1] (Used only if sampling ground atoms.) "
00205        "Maximum number of each predicate's ground atoms to draw. "
00206        "(-1: no maximum)"),
00207 
00208   ARGS("noPrior", ARGS::Tog, noPrior, "No Gaussian priors on formula weights."),
00209 
00210   ARGS("priorMean", ARGS::Opt, priorMean, 
00211        "[0] Means of Gaussian priors on formula weights. By default, "
00212        "for each formula, it is the weight given in the .mln input file, " 
00213        "or fraction thereof if the formula turns into multiple clauses. "
00214        "This mean applies if no weight is given in the .mln file."),
00215 
00216   ARGS("priorStdDev", ARGS::Opt, priorStdDev, 
00217        "[100] Standard deviations of Gaussian priors on clause weights."),
00218 
00219   ARGS("tightMaxIter", ARGS::Opt, lbMaxIter, 
00220        "[10000] Max number of iterations to run L-BFGS-B, "
00221        "the algorithm used to optimize the (weighted) pseudo-log-likelihood."),
00222 
00223   ARGS("tightConvThresh", ARGS::Opt, lbConvThresh, 
00224        "[1e-5] Fractional change in (weighted) pseudo-log-likelihood at which "
00225        "L-BFGS-B terminates."),
00226 
00227   ARGS("looseMaxIter", ARGS::Opt, looseMaxIter, 
00228        "[10] Max number of iterations to run L-BFGS-B "
00229        "when evaluating candidate clauses."),
00230 
00231   ARGS("looseConvThresh", ARGS::Opt, looseConvThresh, 
00232        "[1e-3] Fractional change in (weighted) pseudo-log-likelihood at which "
00233        "L-BFGS-B terminates when evaluating candidate clauses."),
00234 
00235   ARGS("numClausesReEval", ARGS::Opt, numEstBestClauses, 
00236        "[10] Keep this number of candidate clauses with the highest estimated "
00237        "scores, and re-evaluate their scores precisely."),
00238 
00239   ARGS("noWtPredsEqually", ARGS::Tog, noWtPredsEqually,
00240        "If specified, each predicate is not weighted equally. This means that "
00241        "high-arity predicates contribute more to the pseudo-log-likelihood "
00242        "than low-arity ones. If not specified, each predicate is given equal "
00243        "weight in the weighted pseudo-log-likelihood."),
00244 
00245   ARGS("startFromEmptyMLN", ARGS::Tog, startFromEmptyMLN,
00246        "If specified, start structure learning from an empty MLN. "
00247        "If the input .mln contains formulas, they will be added to the "
00248        "candidate clauses created in the first step of beam search. "
00249        "If not specified, begin structure learning from the input .mln file."),
00250        
00251   ARGS("tryAllFlips", ARGS::Tog, tryAllFlips,
00252        "If specified, the structure learning algorithm tries to flip "
00253        "the predicate signs of the formulas in the input .mln file "
00254        "in all possible ways"),
00255 
00256   ARGS("bestGainUnchangedLimit", ARGS::Opt, bestGainUnchangedLimit, 
00257        "[2] Beam search stops when the best clause found does not change "
00258        "in this number of iterations."),
00259 
00260   //ARGS("structGradDescent", ARGS::Tog, structGradDescent,
00261   //     "If set, structural gradient descent is used; "
00262   //     "otherwise beam search is used."),
00263 
00264   //ARGS("withEM", ARGS::Tog, withEM,
00265   //     "If set, relational structural EM is used to fill in missing truth values; "
00266   //     "otherwise missing truth values are set to false. Can only be used with "
00267   //     "the structural gradient descent algorithm"),
00268 
00269   ARGS()
00270 };
00271 
00272 
00273 bool checkParams()
00274 {
00275   bool ok = true;
00276   if (beamSize<=0) {cout<<"ERROR: beamSize must be positive"<<endl; ok =false;}
00277 
00278   if (minWt<0) { cout << "ERROR: minWt must be non-negative" << endl;ok =false;}
00279 
00280   if (penalty<0) { cout <<"ERROR: penalty must be non-negative"<<endl;ok=false;}
00281 
00282   if (maxVars<=0) { cout << "ERROR: maxVar must be positive" << endl;ok =false;}
00283 
00284   if (maxNumPredicates <= 0)
00285   { 
00286     cout << "ERROR: maxNumPredicates must be positive" << endl; ok = false;
00287   }
00288   
00289   if (cacheSize < 0) 
00290   { cout << "ERROR: cacheSize must be non-negative" << endl; ok = false; }
00291 
00292   if (ddelta <= 0 || ddelta > 1) 
00293   { cout << "ERROR: gamma must be between 0 and 1" << endl; ok = false; }
00294 
00295   if (epsilon <= 0 || epsilon >= 1) 
00296   { cout << "ERROR: epsilon must be between 0 and 1" << endl; ok = false; }
00297 
00298   if (fraction < 0 || fraction > 1) 
00299   { cout << "ERROR: fraction must be between 0 and 1" << endl; ok = false;}
00300 
00301   if (priorMean < 0) 
00302   { cout << "ERROR: priorMean must be non-negative" << endl; ok = false; }
00303 
00304   if (priorStdDev <= 0) 
00305   { cout << "ERROR: priorStdDev must be positive" << endl; ok = false; }
00306 
00307   if (lbMaxIter <= 0) 
00308   { cout << "ERROR: tightMaxIter must be positive" << endl;  ok = false; }
00309 
00310   if (lbConvThresh <= 0 || lbConvThresh >= 1) 
00311   { cout << "ERROR: tightConvThresh must be between 0 and 1" << endl; ok=false;}
00312 
00313   if (looseMaxIter <= 0) 
00314   { cout << "ERROR: looseMaxIter must be positive" << endl; ok = false; }
00315 
00316   if (looseConvThresh <= 0 || looseConvThresh >= 1) 
00317   { cout << "ERROR: looseConvThresh must be between 0 and 1" << endl; ok=false;}
00318   
00319   if (numEstBestClauses <= 0) 
00320   { cout << "ERROR: numClausesReEval must be positive" << endl; ok = false; }
00321 
00322   if (bestGainUnchangedLimit <= 0) 
00323   { cout << "ERROR: bestGainUnchangedLimit must be positive" << endl; ok=false;}
00324 
00325   if (!structGradDescent && withEM)
00326   { cout << "ERROR: EM can only be used with structural gradient descent" << endl; ok=false; }
00327   
00328   if (structGradDescent && nonEvidPredsStr == NULL)
00329   {
00330     cout << "ERROR: you must specify non-evidence predicates for "
00331          << "structural gradient descent" << endl;
00332     ok = false;
00333   }
00334   
00335   return ok;
00336 }
00337 
00338 
00339 //void extractFileNames(...) defined in learnwts.h
00340 //void createDomainsAndMLNs(...) defined in learnwts.h 
00341 //void deleteDomains(...) defined in learnwts.h 
00342 //bool extractPredNames(...) defined in infer.h
00343 
00344 
00345 int main(int argc, char* argv[])
00346 {
00347   ARGS::parse(argc,argv,&cout);
00348   Timer timer;
00349   double begSec, startSec = timer.time();
00350 
00351     //Compute the size in MB of the components of Term, Predicate, and Clause
00352     //that do not change. The sizes will be used when computing the running 
00353     //size of the cache of clauses.
00354   Term::computeFixedSizeB();
00355   Predicate::computeFixedSizeB();
00356   Clause::computeFixedSizeB();
00357   AuxClauseData::computeFixedSizeB();
00358   
00359 
00361 
00362   if (!checkParams()) return -1;
00363   
00364   //the second .mln file to the last one in inMLNFiles _may_ be used 
00365   //to hold constants, so they are held in constFilesArr. They will be
00366   //included into the first .mln file.
00367 
00368     //extract .mln and .db, file names
00369   Array<string> constFilesArr, dbFilesArr;
00370   extractFileNames(inMLNFiles, constFilesArr);
00371   assert(constFilesArr.size() >= 1);
00372   string inMLNFile = constFilesArr[0];
00373   constFilesArr.removeItem(0);
00374   extractFileNames(dbFiles, dbFilesArr);
00375 
00376   if (dbFilesArr.size() <= 0)
00377   { cout << "ERROR: must specify training data with -t flag."<<endl; return -1;}
00378 
00379     // if multiple databases, check the number of .db/.func files
00380   if (multipleDatabases) 
00381   {
00382       //if # .mln files containing constants and .db files are diff
00383     if ( (constFilesArr.size() > 0 && constFilesArr.size() != dbFilesArr.size()))
00384     {
00385       cout << "ERROR: when there are multiple databases, if .mln files "
00386            << "containing constants are specified, there must " 
00387            << "be the same number of them as .db files" << endl;
00388       return -1;
00389     }
00390   }
00391 
00392   StringHashArray tmpNEPredNames;
00393   Array<string> nonEvidPredNames;
00394   if (nonEvidPredsStr)
00395   {
00396     if(!extractPredNames(nonEvidPredsStr, NULL, tmpNEPredNames))
00397     {
00398       cout << "ERROR: failed to extract non-evidence predicate names." << endl;
00399       return -1;
00400     }
00401     for (int i = 0; i < tmpNEPredNames.size(); i++) 
00402       nonEvidPredNames.append(tmpNEPredNames[i]);
00403   }
00404 
00406 
00407   Array<Domain*> domains;
00408   Array<MLN*> mlns;
00409   StringHashArray* queryPredNames = NULL;
00410   if (structGradDescent)
00411   {
00412     queryPredNames = new StringHashArray();
00413     for (int i = 0; i < nonEvidPredNames.size(); i++) 
00414       queryPredNames->append(nonEvidPredNames[i]);
00415   }
00416 
00417   bool addUnitClauses = false;
00418     // TODO: Allow user to declare which preds are open-world
00419   bool allPredsExceptQueriesAreCW = true;
00420   bool mwsLazy = true;
00421   if (structGradDescent && withEM) allPredsExceptQueriesAreCW = false;
00422   begSec = timer.time();
00423   cout << "Parsing MLN and creating domains..." << endl;
00424   createDomainsAndMLNs(domains, mlns, multipleDatabases, inMLNFile, 
00425                        constFilesArr, dbFilesArr, queryPredNames,
00426                        addUnitClauses, priorMean, mwsLazy,
00427                        allPredsExceptQueriesAreCW, NULL, NULL);
00428   cout << "Parsing MLN and creating domains took ";
00429   Timer::printTime(cout, timer.time()-begSec); cout << endl << endl;
00430 
00431   /*
00432   cout << "Clause prior means:" << endl;
00433   cout << "_________________________________" << endl;
00434   mlns[0]->printClausePriorMeans(cout, domains[0]);
00435   cout << "_________________________________" << endl;
00436   cout << endl;
00437 
00438   cout << "Formula prior means:" << endl;
00439   cout << "_________________________________" << endl;
00440   mlns[0]->printFormulaPriorMeans(cout);
00441   cout << "_________________________________" << endl;
00442   cout << endl;
00443   //*/
00444 
00445 
00447 
00448   if (nonEvidPredNames.size() == 0) 
00449     domains[0]->getNonEqualPredicateNames(nonEvidPredNames);
00450   bool cacheClauses = (cacheSize > 0);
00451   bool reEvalBestCandidatesWithTightParams = true;
00452   
00453   StructLearn sl(&mlns, startFromEmptyMLN, outMLNFile, &domains, 
00454                  &nonEvidPredNames, maxVars, maxNumPredicates, cacheClauses,
00455                  cacheSize, tryAllFlips,
00456                  !noSampleClauses, ddelta, epsilon, 
00457                  minClauseSamples, maxClauseSamples,
00458                  !noPrior, priorMean, priorStdDev, 
00459                  !noWtPredsEqually, 
00460                  lbMaxIter, lbConvThresh, looseMaxIter, looseConvThresh, 
00461                  beamSize, bestGainUnchangedLimit, numEstBestClauses, 
00462                  minWt, penalty, 
00463                  !noSampleGndPreds,fraction,minGndPredSamples,maxGndPredSamples,
00464                  reEvalBestCandidatesWithTightParams, structGradDescent,
00465                  withEM);
00466   sl.run();
00467 
00468   
00470   
00471   deleteDomains(domains);
00472   for (int i = 0; i < mlns.size(); i++) delete mlns[i];
00473   PowerSet::deletePowerSet();
00474 
00475   cout << "Total time taken = "; 
00476   Timer::printTime(cout, timer.time()-startSec); cout << endl;
00477 }

Generated on Sun Jun 7 11:55:13 2009 for Alchemy by  doxygen 1.5.1