learnwts.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef LEARNWTS_H_NOV_23_2005
00067 #define LEARNWTS_H_NOV_23_2005
00068 
00069 #include <sys/time.h>
00070 #include "util.h"
00071 #include "timer.h"
00072 #include "fol.h"
00073 #include "mln.h"
00074 #include "indextranslator.h"
00075 
00076 
// Suffix appended to the name of the temporary copy of the input MLN file
// built in createDomainAndMLN(); the definition lives in fol.y.
extern const char* ZZ_TMP_FILE_POSTFIX;
// When true, every domain after the first one is made to reuse domains[0]'s
// type/predicate/function maps and templates (see createDomainsAndMLNs), and
// deleteDomains() detaches those shared pointers before deleting.
const bool DOMAINS_SHARE_DATA_STRUCT(true);
00079 
00087 void extractFileNames(const char* const & namesStr, Array<string>& namesArray)
00088 {
00089   if (namesStr == NULL) return;
00090   string s(namesStr);
00091   s = Util::trim(s);
00092   if (s.length() == 0) return;
00093   s.append(",");
00094   string::size_type cur = 0;
00095   string::size_type comma;
00096   string name;
00097   while (true)
00098   {
00099     comma = s.find(",", cur);
00100     if (comma == string::npos) return;
00101     name = s.substr(cur, comma-cur);
00102     namesArray.append(name);
00103     cur = comma+1;
00104   }
00105 }
00106 
00107 void createDomainAndMLN(Array<Domain*>& domains, Array<MLN*>& mlns,
00108                         const string& inMLNFile, ostringstream& constFiles,
00109                         ostringstream& dbFiles,
00110                         const StringHashArray* const & nonEvidPredNames,
00111                         const bool& addUnitClauses, const double& priorMean,
00112                         const bool& checkPlusTypes, const bool& mwsLazy,
00113                         const bool& allPredsExceptQueriesAreCW,
00114                         const StringHashArray* const & owPredNames)
00115 {
00116   string::size_type bslash = inMLNFile.rfind("/");
00117   string tmp = (bslash == string::npos) ? 
00118                inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00119   char tmpInMLN[100];
00120   sprintf(tmpInMLN, "%s%s",  tmp.c_str(), ZZ_TMP_FILE_POSTFIX);
00121 
00122   ofstream out(tmpInMLN);
00123   ifstream in(inMLNFile.c_str());
00124   if (!out.good()) { cout<<"ERROR: failed to open "<<tmpInMLN <<endl; exit(-1);}
00125   if (!in.good())  { cout<<"ERROR: failed to open "<<inMLNFile<<endl; exit(-1);}
00126 
00127   string buffer;
00128   while(getline(in, buffer)) out << buffer << endl;
00129   in.close();
00130 
00131   out << constFiles.str() << endl 
00132       << dbFiles.str() << endl;
00133   out.close();
00134   
00135   // read the formulas from the input MLN
00136   Domain* domain = new Domain;
00137   MLN* mln = new MLN();
00138   
00139     // Unknown evidence atoms are filled in by EM
00140   //bool allPredsExceptQueriesAreCW = true;
00141   bool warnAboutDupGndPreds = true;
00142   bool mustHaveWtOrFullStop = false;
00143   bool flipWtsOfFlippedClause = false;
00144   Domain* domain0 = (checkPlusTypes) ? domains[0] : NULL;
00145 
00146   bool ok = runYYParser(mln, domain, tmpInMLN, allPredsExceptQueriesAreCW, 
00147                         owPredNames, nonEvidPredNames, addUnitClauses, 
00148                         warnAboutDupGndPreds, priorMean, mustHaveWtOrFullStop,
00149                         domain0, mwsLazy, flipWtsOfFlippedClause);
00150 
00151   if (!ok) { unlink(tmpInMLN); exit(-1); }
00152   domains.append(domain);
00153   mlns.append(mln);
00154   unlink(tmpInMLN);
00155 }
00156 
00157 
00158 void createDomainsAndMLNs(Array<Domain*>& domains, Array<MLN*>& mlns, 
00159                           const bool& multipleDatabases,
00160                           const string& inMLNFile,
00161                           const Array<string>& constFilesArr,
00162                           const Array<string>& dbFilesArr,
00163                           const StringHashArray* const & nonEvidPredNames,
00164                           const bool& addUnitClauses, const double& priorMean,
00165                           const bool& mwsLazy,
00166                           const bool& allPredsExceptQueriesAreCW,
00167                           const StringHashArray* const & owPredNames)
00168 {
00169   if (!multipleDatabases)
00170   {
00171     ostringstream constFilesStream, dbFilesStream;
00172     for (int i = 0; i < constFilesArr.size(); i++) 
00173       constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00174     for (int i = 0; i < dbFilesArr.size(); i++)    
00175       dbFilesStream << "#include \"" << dbFilesArr[i] << "\"" << endl;
00176     createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream, 
00177                        dbFilesStream, nonEvidPredNames,
00178                        addUnitClauses, priorMean, false, mwsLazy,
00179                        allPredsExceptQueriesAreCW, owPredNames);
00180   }
00181   else
00182   {   //if multiple databases
00183     for (int i = 0; i < dbFilesArr.size(); i++) // for each domain
00184     {
00185       cout << "parsing MLN and creating domain " << i << "..." << endl;
00186       ostringstream constFilesStream, dbFilesStream;
00187       if (constFilesArr.size() > 0)
00188         constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00189       dbFilesStream    << "#include \"" << dbFilesArr[i]    << "\"" << endl;
00190       
00191       bool checkPlusTypes = (i > 0);
00192 
00193       createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream,
00194                          dbFilesStream, nonEvidPredNames,
00195                          addUnitClauses, priorMean, checkPlusTypes, mwsLazy,
00196                          allPredsExceptQueriesAreCW, owPredNames);
00197 
00198         // let the domains share data structures
00199       if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00200       {
00201         const ClauseHashArray* carr = mlns[i]->getClauses();
00202         for (int j = 0; j < carr->size(); j++)
00203         {
00204           Clause* c = (*carr)[j];
00205           for (int k = 0; k < c->getNumPredicates(); k++)
00206           {
00207             Predicate* p = c->getPredicate(k);
00208             const PredicateTemplate* t 
00209               = domains[0]->getPredicateTemplate(p->getName());
00210             assert(t);
00211             p->setTemplate((PredicateTemplate*)t);
00212           }
00213         }
00214 
00215         ((Domain*)domains[i])->replaceTypeDualMap((
00216                                          DualMap*)domains[0]->getTypeDualMap());
00217         ((Domain*)domains[i])->replaceStrToPredTemplateMapAndPredDualMap(
00218                   (StrToPredTemplateMap*) domains[0]->getStrToPredTemplateMap(),
00219                   (DualMap*) domains[0]->getPredDualMap());
00220         ((Domain*)domains[i])->replaceStrToFuncTemplateMapAndFuncDualMap(
00221                   (StrToFuncTemplateMap*) domains[0]->getStrToFuncTemplateMap(),
00222                   (DualMap*) domains[0]->getFuncDualMap());
00223         ((Domain*)domains[i])->replaceEqualPredTemplate(
00224                         (PredicateTemplate*)domains[0]->getEqualPredTemplate());
00225         ((Domain*)domains[i])->replaceFuncSet(
00226                                         (FunctionSet*)domains[0]->getFuncSet());
00227       }
00228 
00229     } // for each domain
00230   }
00231 
00232   //commented out: not true when there are domains with different constants
00233   //int numClauses = mlns[0]->getNumClauses();
00234   //for (int i = 1; i < mlns.size(); i++) 
00235   //  assert(mlns[i]->getNumClauses() == numClauses);
00236   //numClauses = 0; //avoid compilation warning
00237 }
00238 
00239 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00240                            Array<Domain*>& domains, const Array<double>& wts, 
00241                            IndexTranslator* const& indexTrans)
00242 {
00243     //assign the optimal weights belonging to clauses (and none of those 
00244     //belonging to existentially quantified formulas) to the MLNs
00245   double* wwts = (double*) wts.getItems();
00246   indexTrans->assignNonTiedClauseWtsToMLNs(++wwts);
00247 
00248 
00249     // output the predicate declaration
00250   out << "//predicate declarations" << endl;
00251   domains[0]->printPredicateTemplates(out);
00252   out << endl;
00253 
00254   // output the function declarations
00255   out << "//function declarations" << endl;
00256   domains[0]->printFunctionTemplates(out);
00257   out << endl;
00258 
00259   mlns[0]->printMLNNonExistFormulas(out, domains[0]);
00260 
00261   const ClauseHashArray* clauseOrdering = indexTrans->getClauseOrdering();
00262   const StringHashArray* exFormOrdering = indexTrans->getExistFormulaOrdering();
00263   for (int i = 0; i < exFormOrdering->size(); i++)
00264   {
00265       // output the original formula and its weight
00266     out.width(0); out << "// "; out.width(6); 
00267     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00268     out << wts[1+clauseOrdering->size()+i] << "  " <<(*exFormOrdering)[i]<<endl;
00269     out << endl;
00270   }
00271 }
00272 
00273 
00274 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns, 
00275                            Array<Domain*>& domains, const Array<double>& wts)
00276 {
00277     // assign the optimal weights to the clauses in all MLNs
00278   for (int i = 0; i < mlns.size(); i++)
00279   {
00280     MLN* mln = mlns[i];
00281     const ClauseHashArray* clauses = mln->getClauses();
00282     for (int i = 0; i < clauses->size(); i++) 
00283       (*clauses)[i]->setWt(wts[i+1]);
00284   }
00285 
00286     // output the predicate declaration
00287   out << "//predicate declarations" << endl;
00288   domains[0]->printPredicateTemplates(out);
00289   out << endl;
00290 
00291   // output the function declarations
00292   out << "//function declarations" << endl;
00293   domains[0]->printFunctionTemplates(out);
00294   out << endl;
00295   mlns[0]->printMLN(out, domains[0]);
00296 }
00297 
00298 
00299 void deleteDomains(Array<Domain*>& domains)
00300 {
00301   for (int i = 0; i < domains.size(); i++) 
00302   {
00303     if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00304     {
00305       ((Domain*)domains[i])->setTypeDualMap(NULL);
00306       ((Domain*)domains[i])->setStrToPredTemplateMapAndPredDualMap(NULL, NULL);
00307       ((Domain*)domains[i])->setStrToFuncTemplateMapAndFuncDualMap(NULL, NULL);
00308       ((Domain*)domains[i])->setEqualPredTemplate(NULL);
00309       ((Domain*)domains[i])->setFuncSet(NULL);
00310     }
00311     delete domains[i];
00312   }
00313 }
00314 
00315 
00316 #endif

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1