00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #ifndef LEARNWTS_H_NOV_23_2005
00068 #define LEARNWTS_H_NOV_23_2005
00069
00070 #include <sys/time.h>
00071 #include "util.h"
00072 #include "timer.h"
00073 #include "fol.h"
00074 #include "mln.h"
00075 #include "indextranslator.h"
00076
00077
// Suffix appended to the name of the temporary MLN file that
// createDomainAndMLN() writes; only declared here, the definition lives in
// another translation unit.
00078 extern const char* ZZ_TMP_FILE_POSTFIX;
// When true, createDomainsAndMLNs() makes every domain after the first reuse
// data structures obtained from domains[0], and deleteDomains() clears those
// pointers on the later domains before deleting them.
00079 const bool DOMAINS_SHARE_DATA_STRUCT = true;
00080
00088 void extractFileNames(const char* const & namesStr, Array<string>& namesArray)
00089 {
00090 if (namesStr == NULL) return;
00091 string s(namesStr);
00092 s = Util::trim(s);
00093 if (s.length() == 0) return;
00094 s.append(",");
00095 string::size_type cur = 0;
00096 string::size_type comma;
00097 string name;
00098 while (true)
00099 {
00100 comma = s.find(",", cur);
00101 if (comma == string::npos) return;
00102 name = s.substr(cur, comma-cur);
00103 namesArray.append(name);
00104 cur = comma+1;
00105 }
00106 }
00107
00115 void extractArgs(const char* const & argsStr, int& argc, char** argv)
00116 {
00117 argc = 0;
00118 if (argsStr == NULL) return;
00119 string s(argsStr);
00120 s = Util::trim(s);
00121 if (s.length() == 0) return;
00122 s.append(" ");
00123 string::size_type cur = 0;
00124 string::size_type blank;
00125 string arg;
00126
00127 while (true)
00128 {
00129 blank = s.find(" ", cur);
00130 if (blank == string::npos) return;
00131 arg = s.substr(cur, blank-cur);
00132 arg = Util::trim(arg);
00133 memset(argv[argc], '\0', 500);
00134 arg.copy(argv[argc], arg.length());
00135 argc++;
00136 cur = blank + 1;
00137 }
00138 }
00139
00140 void createDomainAndMLN(Array<Domain*>& domains, Array<MLN*>& mlns,
00141 const string& inMLNFile, ostringstream& constFiles,
00142 ostringstream& dbFiles,
00143 const StringHashArray* const & nonEvidPredNames,
00144 const bool& addUnitClauses, const double& priorMean,
00145 const bool& checkPlusTypes, const bool& mwsLazy,
00146 const bool& allPredsExceptQueriesAreCW,
00147 const StringHashArray* const & owPredNames,
00148 const StringHashArray* const & cwPredNames)
00149 {
00150 string::size_type bslash = inMLNFile.rfind("/");
00151 string tmp = (bslash == string::npos) ?
00152 inMLNFile:inMLNFile.substr(bslash+1,inMLNFile.length()-bslash-1);
00153 char tmpInMLN[100];
00154 sprintf(tmpInMLN, "%s%d%s", tmp.c_str(), getpid(), ZZ_TMP_FILE_POSTFIX);
00155
00156 ofstream out(tmpInMLN);
00157 ifstream in(inMLNFile.c_str());
00158 if (!out.good()) { cout<<"ERROR: failed to open "<<tmpInMLN <<endl; exit(-1);}
00159 if (!in.good()) { cout<<"ERROR: failed to open "<<inMLNFile<<endl; exit(-1);}
00160
00161 string buffer;
00162 while(getline(in, buffer)) out << buffer << endl;
00163 in.close();
00164
00165 out << constFiles.str() << endl
00166 << dbFiles.str() << endl;
00167 out.close();
00168
00169
00170 Domain* domain = new Domain;
00171 MLN* mln = new MLN();
00172
00173
00174 bool warnAboutDupGndPreds = true;
00175 bool mustHaveWtOrFullStop = false;
00176 bool flipWtsOfFlippedClause = false;
00177 Domain* domain0 = (checkPlusTypes) ? domains[0] : NULL;
00178
00179 bool ok = runYYParser(mln, domain, tmpInMLN, allPredsExceptQueriesAreCW,
00180 owPredNames, cwPredNames, nonEvidPredNames,
00181 addUnitClauses, warnAboutDupGndPreds, priorMean,
00182 mustHaveWtOrFullStop, domain0, mwsLazy,
00183 flipWtsOfFlippedClause);
00184
00185 unlink(tmpInMLN);
00186 if (!ok) exit(-1);
00187 domains.append(domain);
00188 mlns.append(mln);
00189 }
00190
00191
00192 void createDomainsAndMLNs(Array<Domain*>& domains, Array<MLN*>& mlns,
00193 const bool& multipleDatabases,
00194 const string& inMLNFile,
00195 const Array<string>& constFilesArr,
00196 const Array<string>& dbFilesArr,
00197 const StringHashArray* const & nonEvidPredNames,
00198 const bool& addUnitClauses, const double& priorMean,
00199 const bool& mwsLazy,
00200 const bool& allPredsExceptQueriesAreCW,
00201 const StringHashArray* const & owPredNames,
00202 const StringHashArray* const & cwPredNames)
00203 {
00204 if (!multipleDatabases)
00205 {
00206 ostringstream constFilesStream, dbFilesStream;
00207 for (int i = 0; i < constFilesArr.size(); i++)
00208 constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00209 for (int i = 0; i < dbFilesArr.size(); i++)
00210 dbFilesStream << "#include \"" << dbFilesArr[i] << "\"" << endl;
00211 createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream,
00212 dbFilesStream, nonEvidPredNames,
00213 addUnitClauses, priorMean, false, mwsLazy,
00214 allPredsExceptQueriesAreCW, owPredNames, cwPredNames);
00215 }
00216 else
00217 {
00218 for (int i = 0; i < dbFilesArr.size(); i++)
00219 {
00220 cout << "parsing MLN and creating domain " << i << "..." << endl;
00221 ostringstream constFilesStream, dbFilesStream;
00222 if (constFilesArr.size() > 0)
00223 constFilesStream << "#include \"" << constFilesArr[i] << "\"" << endl;
00224 dbFilesStream << "#include \"" << dbFilesArr[i] << "\"" << endl;
00225
00226 bool checkPlusTypes = (i > 0);
00227
00228 createDomainAndMLN(domains, mlns, inMLNFile, constFilesStream,
00229 dbFilesStream, nonEvidPredNames,
00230 addUnitClauses, priorMean, checkPlusTypes, mwsLazy,
00231 allPredsExceptQueriesAreCW, owPredNames, cwPredNames);
00232
00233
00234 if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00235 {
00236
00237
00238 int numTypes = domains[i]->getNumTypes();
00239 for (int j = 0; j < numTypes; j++)
00240 {
00241 const char* typeName = domains[i]->getTypeName(j);
00242 const Array<int>* constantsByType = domains[i]->getConstantsByType(j);
00243 for (int k = 0; k < constantsByType->size(); k++)
00244 {
00245 const char* constantName =
00246 domains[i]->getConstantName((*constantsByType)[k]);
00247
00248 if (!domains[0]->isConstant(constantName))
00249 {
00250 domains[0]->addExternalConstant(constantName, typeName);
00251 }
00252 }
00253 }
00254
00255
00256 const ClauseHashArray* carr = mlns[i]->getClauses();
00257 for (int j = 0; j < carr->size(); j++)
00258 {
00259 Clause* c = (*carr)[j];
00260
00261
00262 for (int k = 0; k < c->getNumPredicates(); k++)
00263 {
00264 Predicate* p = c->getPredicate(k);
00265 const PredicateTemplate* t
00266 = domains[0]->getPredicateTemplate(p->getName());
00267 assert(t);
00268 p->setTemplate((PredicateTemplate*)t);
00269 }
00270 }
00271
00272
00273 ((Domain*)domains[i])->replaceTypeDualMap((
00274 DualMap*)domains[0]->getTypeDualMap());
00275
00276
00277
00278
00279
00280 ((Domain*)domains[i])->replaceStrToFuncTemplateMapAndFuncDualMap(
00281 (StrToFuncTemplateMap*) domains[0]->getStrToFuncTemplateMap(),
00282 (DualMap*) domains[0]->getFuncDualMap());
00283 ((Domain*)domains[i])->replaceEqualPredTemplate(
00284 (PredicateTemplate*)domains[0]->getEqualPredTemplate());
00285 ((Domain*)domains[i])->replaceFuncSet(
00286 (FunctionSet*)domains[0]->getFuncSet());
00287
00288 }
00289 }
00290
00291
00292 domains[0]->reorderConstants(mlns[0]);
00293 for (int i = 1; i < domains.size(); i++)
00294 {
00295 ((Domain*)domains[i])->reorderConstants(
00296 (ConstDualMap*)domains[0]->getConstDualMap(),
00297 (Array<Array<int>*>*)domains[0]->getConstantsByType(),
00298 (Array<Array<int>*>*)domains[0]->getExternalConstantsByType(),
00299 mlns[i]);
00300 }
00301
00302
00303
00304 Array<Array<bool>*>* externalClausesPerMLN = new Array<Array<bool>*>;
00305 externalClausesPerMLN->growToSize(mlns.size());
00306 (*externalClausesPerMLN)[0] = new Array<bool>;
00307 (*externalClausesPerMLN)[0]->growToSize(mlns[0]->getNumClauses(), false);
00308 for (int i = 1; i < mlns.size(); i++)
00309 {
00310 const ClauseHashArray* carr = mlns[i]->getClauses();
00311 (*externalClausesPerMLN)[i] = new Array<bool>;
00312 for (int j = 0; j < carr->size(); j++)
00313 {
00314 Clause* c = (*carr)[j];
00315 if (!mlns[0]->containsClause(c))
00316 {
00317 string formulaString = mlns[i]->getParentFormula(j, 0);
00318 bool hasExist = mlns[i]->isExistClause(j);
00319 if (mlns[0]->appendExternalClause(formulaString, hasExist,
00320 new Clause(*c), domains[0], false))
00321 {
00322 (*externalClausesPerMLN)[0]->append(false);
00323 }
00324 }
00325 }
00326 }
00327
00328 const ClauseHashArray* carr = mlns[0]->getClauses();
00329 for (int j = 0; j < carr->size(); j++)
00330 {
00331 Clause* c = (*carr)[j];
00332
00333
00334 for (int k = 1; k < mlns.size(); k++)
00335 {
00336 if (mlns[k]->containsClause(c))
00337 (*externalClausesPerMLN)[k]->append(false);
00338 else
00339 (*externalClausesPerMLN)[k]->append(true);
00340 }
00341 }
00342
00343 for (int i = 1; i < mlns.size(); i++)
00344 {
00345 mlns[i]->replaceClauses(new
00346 ClauseHashArray(*(ClauseHashArray*)mlns[0]->getClauses()));
00347 mlns[i]->replaceMLNClauseInfos(new
00348 Array<MLNClauseInfo*>(*(Array<MLNClauseInfo*>*)mlns[0]->
00349 getMLNClauseInfos()));
00350 mlns[i]->replacePredIdToClausesMap(new
00351 Array<Array<IndexClause*>*>(*(Array<Array<IndexClause*>*>*)mlns[0]->
00352 getPredIdToClausesMap()));
00353 mlns[i]->replaceFormulaAndClausesArray(new
00354 FormulaAndClausesArray(*(FormulaAndClausesArray*)mlns[0]->
00355 getFormulaAndClausesArray()));
00356 mlns[i]->replaceExternalClause((*externalClausesPerMLN)[i]);
00357 }
00358
00359 delete externalClausesPerMLN;
00360
00361 for (int i = 1; i < mlns.size(); i++)
00362 mlns[i]->rehashClauses();
00363 }
00364
00365
00366
00367
00368
00369
00370
00371 }
00372
00373 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns,
00374 Array<Domain*>& domains, const Array<double>& wts,
00375 IndexTranslator* const& indexTrans)
00376 {
00377
00378
00379 double* wwts = (double*) wts.getItems();
00380 indexTrans->assignNonTiedClauseWtsToMLNs(++wwts);
00381
00382
00383
00384 out << "//predicate declarations" << endl;
00385 domains[0]->printPredicateTemplates(out);
00386 out << endl;
00387
00388
00389 if (domains[0]->getNumFunctions() > 0)
00390 {
00391 out << "//function declarations" << endl;
00392 domains[0]->printFunctionTemplates(out);
00393 out << endl;
00394 }
00395
00396 mlns[0]->printMLNNonExistFormulas(out, domains[0]);
00397
00398 const ClauseHashArray* clauseOrdering = indexTrans->getClauseOrdering();
00399 const StringHashArray* exFormOrdering = indexTrans->getExistFormulaOrdering();
00400 for (int i = 0; i < exFormOrdering->size(); i++)
00401 {
00402
00403 out.width(0); out << "// "; out.width(6);
00404 out << wts[1+clauseOrdering->size()+i] << " " <<(*exFormOrdering)[i]<<endl;
00405 out << wts[1+clauseOrdering->size()+i] << " " <<(*exFormOrdering)[i]<<endl;
00406 out << endl;
00407 }
00408 }
00409
00410
00411 void assignWtsAndOutputMLN(ostream& out, Array<MLN*>& mlns,
00412 Array<Domain*>& domains, const Array<double>& wts)
00413 {
00414
00415 for (int i = 0; i < mlns.size(); i++)
00416 {
00417 MLN* mln = mlns[i];
00418 const ClauseHashArray* clauses = mln->getClauses();
00419 for (int i = 0; i < clauses->size(); i++)
00420 (*clauses)[i]->setWt(wts[i+1]);
00421 }
00422
00423
00424 out << "//predicate declarations" << endl;
00425 domains[0]->printPredicateTemplates(out);
00426 out << endl;
00427
00428
00429 if (domains[0]->getNumFunctions() > 0)
00430 {
00431
00432 out << "//function declarations" << endl;
00433 domains[0]->printFunctionTemplates(out);
00434 out << endl;
00435 }
00436 mlns[0]->printMLN(out, domains[0]);
00437 }
00438
00439
00440 void deleteDomains(Array<Domain*>& domains)
00441 {
00442 for (int i = 0; i < domains.size(); i++)
00443 {
00444 if (DOMAINS_SHARE_DATA_STRUCT && i > 0)
00445 {
00446 ((Domain*)domains[i])->setTypeDualMap(NULL);
00447 ((Domain*)domains[i])->setStrToPredTemplateMapAndPredDualMap(NULL, NULL);
00448 ((Domain*)domains[i])->setStrToFuncTemplateMapAndFuncDualMap(NULL, NULL);
00449 ((Domain*)domains[i])->setEqualPredTemplate(NULL);
00450 ((Domain*)domains[i])->setFuncSet(NULL);
00451
00452
00453 ((Domain*)domains[i])->setConstDualMap(NULL);
00454 ((Domain*)domains[i])->setConstantsByType(NULL);
00455 }
00456 delete domains[i];
00457 }
00458 }
00459
00460
00461 #endif