learnwts.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef LEARNWTS_H_NOV_23_2005
00068 #define LEARNWTS_H_NOV_23_2005
00069 
00070 #include <sys/time.h>
00071 #include "util.h"
00072 #include "timer.h"
00073 #include "fol.h"
00074 #include "mln.h"
00075 #include "indextranslator.h"
00076 
00077 
// Suffix appended to the name of the temporary MLN copy that
// createDomainAndMLN() writes, parses and then unlinks.
00078 extern const char* ZZ_TMP_FILE_POSTFIX; //defined in folhelper.h
// When true, createDomainsAndMLNs() makes every domain after the first reuse
// data structures of domain 0, and deleteDomains() detaches those structures
// from the later domains before deleting them.
00079 const bool DOMAINS_SHARE_DATA_STRUCT = true;
00080 
00088 void extractFileNames(const char* const & namesStr, Array<string>& namesArray)
00089 {
00090   if (namesStr == NULL) return;
00091   string s(namesStr);
00092   s = Util::trim(s);
00093   if (s.length() == 0) return;
00094   s.append(",");
00095   string::size_type cur = 0;
00096   string::size_type comma;
00097   string name;
00098   while (true)
00099   {
00100     comma = s.find(",", cur);
00101     if (comma == string::npos) return;
00102     name = s.substr(cur, comma-cur);
00103     namesArray.append(name);
00104     cur = comma+1;
00105   }
00106 }
00107 
00115 void extractArgs(const char* const & argsStr, int& argc, char** argv)
00116 {
00117   argc = 0;
00118   if (argsStr == NULL) return;
00119   string s(argsStr);
00120   s = Util::trim(s);
00121   if (s.length() == 0) return;
00122   s.append(" ");
00123   string::size_type cur = 0;
00124   string::size_type blank;
00125   string arg;
00126 
00127   while (true)
00128   {
00129     blank = s.find(" ", cur);
00130     if (blank == string::npos) return;
00131     arg = s.substr(cur, blank-cur);
00132     arg = Util::trim(arg);
00133     memset(argv[argc], '\0', 500);
00134     arg.copy(argv[argc], arg.length());
00135     argc++;
00136     cur = blank + 1;
00137   }
00138 }
00139 
00140 void createDomainAndMLN(Array<Domain*>& domains, Array<MLN*>& mlns,
00141                         const string& inMLNFile, ostringstream& constFiles,
00142                         ostringstream& dbFiles,
00143                         const StringHashArray* const & nonEvidPredNames,
00144                         const bool& addUnitClauses, const double& priorMean,
00145                         const bool& checkPlusTypes, const bool& mwsLazy,
00146                         const bool& allPredsExceptQueriesAreCW,
00147                         const StringHashArray* const & owPredNames,
00148                         const StringHashArray* const & cwPredNames)
00149 {
00150   string::size_type bslash = inMLNFile.rfind("/");
00151   string tmp = (bslash == string::npos) ? 
00152                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00153   char tmpInMLN[100];  
00154   sprintf(tmpInMLN, "%s%d%s",  tmp.c_str(), getpid(), ZZ_TMP_FILE_POSTFIX);
00155 
00156   ofstream out(tmpInMLN);
00157   ifstream in(inMLNFile.c_str());
00158   if (!out.good()) { cout<<"ERROR: failed to open "<<tmpInMLN <<endl; exit(-1);}
00159   if (!in.good())  { cout<<"ERROR: failed to open "<<inMLNFile<<endl; exit(-1);}
00160 
00161   string buffer;
00162   while(getline(in, buffer)) out << buffer << endl;
00163   in.close();
00164 
00165   out << constFiles.str() << endl 
00166       << dbFiles.str() << endl;
00167   out.close();
00168   
00169   // read the formulas from the input MLN
00170   Domain* domain = new Domain;
00171   MLN* mln = new MLN();
00172   
00173     // Unknown evidence atoms are filled in by EM
00174   bool warnAboutDupGndPreds = true;
00175   bool mustHaveWtOrFullStop = false;
00176   bool flipWtsOfFlippedClause = false;
00177   Domain* domain0 = (checkPlusTypes) ? domains[0] : NULL;
00178 
00179   bool ok = runYYParser(mln, domain, tmpInMLN, allPredsExceptQueriesAreCW, 
00180                         owPredNames, cwPredNames, nonEvidPredNames,
00181                         addUnitClauses,  warnAboutDupGndPreds, priorMean,
00182                         mustHaveWtOrFullStop, domain0, mwsLazy,
00183                         flipWtsOfFlippedClause);
00184 
00185   unlink(tmpInMLN);
00186   if (!ok) exit(-1);
00187   domains.append(domain);
00188   mlns.append(mln);
00189 }
00190 
00191 
00192 void createDomainsAndMLNs(Array<Domain*>& domains, Array<MLN*>& mlns, 
00193                           const bool& multipleDatabases,
00194                           const string& inMLNFile,
00195                           const Array<string>& constFilesArr,
00196                           const Array<string>& dbFilesArr,
00197                           const StringHashArray* const & nonEvidPredNames,
00198                           const bool& addUnitClauses, const double& priorMean,
00199                           const bool& mwsLazy,
00200                           const bool& allPredsExceptQueriesAreCW,
00201                           const StringHashArray* const & owPredNames,
00202                           const StringHashArray* const & cwPredNames)
00203 {
00204   if (!multipleDatabases)
00205   {
00206     ostringstream constFilesStream, dbFilesStream;
00207     for (int i = 0; i < constFilesArr.size(); i++) 
00208       constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00209     for (int i = 0; i < dbFilesArr.size(); i++)    
00210       dbFilesStream << "#include \"" << dbFilesArr[i] << "\"" << endl;
00211     createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream, 
00212                        dbFilesStream, nonEvidPredNames,
00213                        addUnitClauses, priorMean, false, mwsLazy,
00214                        allPredsExceptQueriesAreCW, owPredNames, cwPredNames);
00215   }
00216   else
00217   {   //if multiple databases
00218     for (int i = 0; i < dbFilesArr.size(); i++) // for each domain
00219     {
00220       cout << "parsing MLN and creating domain " << i << "..." << endl;
00221       ostringstream constFilesStream, dbFilesStream;
00222       if (constFilesArr.size() > 0)
00223         constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00224       dbFilesStream    << "#include \"" << dbFilesArr[i]    << "\"" << endl;
00225       
00226       bool checkPlusTypes = (i > 0);
00227 
00228       createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream,
00229                          dbFilesStream, nonEvidPredNames,
00230                          addUnitClauses, priorMean, checkPlusTypes, mwsLazy,
00231                          allPredsExceptQueriesAreCW, owPredNames, cwPredNames);
00232 
00233         // let the domains share data structures
00234       if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00235       {
00236                   // Add new constants to base domain as external
00237           // Assumption: all domains have same number and ordering of types
00238         int numTypes = domains[i]->getNumTypes();
00239         for (int j = 0; j < numTypes; j++)
00240         {
00241           const char* typeName = domains[i]->getTypeName(j);
00242           const Array<int>* constantsByType = domains[i]->getConstantsByType(j);
00243           for (int k = 0; k < constantsByType->size(); k++)
00244           {
00245             const char* constantName =
00246               domains[i]->getConstantName((*constantsByType)[k]);
00247               // Add external constant to base domain
00248             if (!domains[0]->isConstant(constantName))
00249             {
00250               domains[0]->addExternalConstant(constantName, typeName);
00251             }
00252           }
00253         }
00254  
00255           // Share all predicate templates
00256         const ClauseHashArray* carr = mlns[i]->getClauses();
00257         for (int j = 0; j < carr->size(); j++)
00258         {
00259           Clause* c = (*carr)[j];
00260 
00261             // Share all predicate templates
00262           for (int k = 0; k < c->getNumPredicates(); k++)
00263           {
00264             Predicate* p = c->getPredicate(k);
00265             const PredicateTemplate* t 
00266               = domains[0]->getPredicateTemplate(p->getName());
00267             assert(t);
00268             p->setTemplate((PredicateTemplate*)t);
00269           }
00270         } // for all clauses
00271 
00272           // All mlns can use the following data structures from mln0
00273         ((Domain*)domains[i])->replaceTypeDualMap((
00274                                          DualMap*)domains[0]->getTypeDualMap());
00275 /*
00276         ((Domain*)domains[i])->replaceStrToPredTemplateMapAndPredDualMap(
00277                   (StrToPredTemplateMap*) domains[0]->getStrToPredTemplateMap(),
00278                   (DualMap*) domains[0]->getPredDualMap());
00279 */
00280         ((Domain*)domains[i])->replaceStrToFuncTemplateMapAndFuncDualMap(
00281                   (StrToFuncTemplateMap*) domains[0]->getStrToFuncTemplateMap(),
00282                   (DualMap*) domains[0]->getFuncDualMap());
00283         ((Domain*)domains[i])->replaceEqualPredTemplate(
00284                         (PredicateTemplate*)domains[0]->getEqualPredTemplate());
00285         ((Domain*)domains[i])->replaceFuncSet(
00286                                         (FunctionSet*)domains[0]->getFuncSet());
00287 
00288       }
00289     } // for each domain
00290 
00291         // Reorder constants in all domains
00292         domains[0]->reorderConstants(mlns[0]);
00293     for (int i = 1; i < domains.size(); i++)
00294     {
00295       ((Domain*)domains[i])->reorderConstants(
00296                                    (ConstDualMap*)domains[0]->getConstDualMap(),
00297                           (Array<Array<int>*>*)domains[0]->getConstantsByType(),
00298                   (Array<Array<int>*>*)domains[0]->getExternalConstantsByType(),
00299                                               mlns[i]);
00300     }
00301 
00302       // Share clauses across all mlns as external clauses.
00303       // This happens when per-constant is used and dbs have diff. constants
00304     Array<Array<bool>*>* externalClausesPerMLN = new Array<Array<bool>*>;
00305     externalClausesPerMLN->growToSize(mlns.size());
00306     (*externalClausesPerMLN)[0] = new Array<bool>;
00307     (*externalClausesPerMLN)[0]->growToSize(mlns[0]->getNumClauses(), false);
00308     for (int i = 1; i < mlns.size(); i++)
00309     {
00310       const ClauseHashArray* carr = mlns[i]->getClauses();
00311       (*externalClausesPerMLN)[i] = new Array<bool>;
00312       for (int j = 0; j < carr->size(); j++)
00313       {
00314         Clause* c = (*carr)[j];
00315         if (!mlns[0]->containsClause(c))
00316         {
00317           string formulaString = mlns[i]->getParentFormula(j, 0);
00318           bool hasExist = mlns[i]->isExistClause(j);
00319           if (mlns[0]->appendExternalClause(formulaString, hasExist,
00320                                             new Clause(*c), domains[0], false))
00321           {
00322             (*externalClausesPerMLN)[0]->append(false);
00323           }
00324         }
00325       } // for all clauses
00326     } // for all mlns except base
00327 
00328     const ClauseHashArray* carr = mlns[0]->getClauses();
00329     for (int j = 0; j < carr->size(); j++)
00330     {
00331       Clause* c = (*carr)[j];
00332         // Do not add existentially qualified clauses
00333       //if (mlns[i]->isExistClause(j)) continue;
00334       for (int k = 1; k < mlns.size(); k++)
00335       {
00336         if (mlns[k]->containsClause(c))
00337           (*externalClausesPerMLN)[k]->append(false);
00338         else
00339           (*externalClausesPerMLN)[k]->append(true);
00340       }
00341     } // for all clauses
00342 
00343     for (int i = 1; i < mlns.size(); i++)
00344     {
00345       mlns[i]->replaceClauses(new
00346         ClauseHashArray(*(ClauseHashArray*)mlns[0]->getClauses()));
00347       mlns[i]->replaceMLNClauseInfos(new
00348         Array<MLNClauseInfo*>(*(Array<MLNClauseInfo*>*)mlns[0]->
00349           getMLNClauseInfos()));
00350       mlns[i]->replacePredIdToClausesMap(new
00351         Array<Array<IndexClause*>*>(*(Array<Array<IndexClause*>*>*)mlns[0]->
00352           getPredIdToClausesMap()));
00353       mlns[i]->replaceFormulaAndClausesArray(new
00354         FormulaAndClausesArray(*(FormulaAndClausesArray*)mlns[0]->
00355           getFormulaAndClausesArray()));
00356       mlns[i]->replaceExternalClause((*externalClausesPerMLN)[i]);
00357     }
00358 
00359     delete externalClausesPerMLN;
00360 
00361     for (int i = 1; i < mlns.size(); i++)
00362       mlns[i]->rehashClauses();
00363  } // if multiple databases
00364   
00365 
00366   //commented out: not true when there are domains with different constants
00367   //int numClauses = mlns[0]->getNumClauses();
00368   //for (int i = 1; i < mlns.size(); i++) 
00369   //  assert(mlns[i]->getNumClauses() == numClauses);
00370   //numClauses = 0; //avoid compilation warning
00371 }
00372 
00373 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00374                            Array<Domain*>& domains, const Array<double>& wts, 
00375                            IndexTranslator* const& indexTrans)
00376 {
00377     //assign the optimal weights belonging to clauses (and none of those 
00378     //belonging to existentially quantified formulas) to the MLNs
00379   double* wwts = (double*) wts.getItems();
00380   indexTrans->assignNonTiedClauseWtsToMLNs(++wwts);
00381 
00382 
00383     // output the predicate declaration
00384   out << "//predicate declarations" << endl;
00385   domains[0]->printPredicateTemplates(out);
00386   out << endl;
00387 
00388     // output the function declarations
00389   if (domains[0]->getNumFunctions() > 0) 
00390   {
00391     out << "//function declarations" << endl;
00392     domains[0]->printFunctionTemplates(out);
00393     out << endl;
00394   }
00395 
00396   mlns[0]->printMLNNonExistFormulas(out, domains[0]);
00397 
00398   const ClauseHashArray* clauseOrdering = indexTrans->getClauseOrdering();
00399   const StringHashArray* exFormOrdering = indexTrans->getExistFormulaOrdering();
00400   for (int i = 0; i < exFormOrdering->size(); i++)
00401   {
00402       // output the original formula and its weight
00403     out.width(0); out << "// "; out.width(6); 
00404     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00405     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00406     out << endl;
00407   }
00408 }
00409 
00410 
00411 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00412                            Array<Domain*>& domains, const Array<double>& wts)
00413 {
00414     // assign the optimal weights to the clauses in all MLNs
00415   for (int i = 0; i < mlns.size(); i++)
00416   {
00417     MLN* mln = mlns[i];
00418     const ClauseHashArray* clauses = mln->getClauses();
00419     for (int i = 0; i < clauses->size(); i++) 
00420       (*clauses)[i]->setWt(wts[i+1]);
00421   }
00422 
00423     // output the predicate declaration
00424   out << "//predicate declarations" << endl;
00425   domains[0]->printPredicateTemplates(out);
00426   out << endl;
00427 
00428     // output the function declarations
00429   if (domains[0]->getNumFunctions() > 0) 
00430   {
00431     // output the function declarations
00432     out << "//function declarations" << endl;
00433     domains[0]->printFunctionTemplates(out);
00434     out << endl;
00435   }
00436   mlns[0]->printMLN(out, domains[0]);
00437 }
00438 
00439 
00440 void deleteDomains(Array<Domain*>& domains)
00441 {
00442   for (int i = 0; i < domains.size(); i++) 
00443   {
00444     if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00445     {
00446       ((Domain*)domains[i])->setTypeDualMap(NULL);
00447       ((Domain*)domains[i])->setStrToPredTemplateMapAndPredDualMap(NULL, NULL);
00448       ((Domain*)domains[i])->setStrToFuncTemplateMapAndFuncDualMap(NULL, NULL);
00449       ((Domain*)domains[i])->setEqualPredTemplate(NULL);
00450       ((Domain*)domains[i])->setFuncSet(NULL);
00451 
00452         // Need this since it is pointing to domain0's copy
00453           ((Domain*)domains[i])->setConstDualMap(NULL);
00454           ((Domain*)domains[i])->setConstantsByType(NULL);
00455     }
00456     delete domains[i];
00457   }
00458 }
00459 
00460 
00461 #endif

Generated on Sun Jun 7 11:55:14 2009 for Alchemy by  doxygen 1.5.1