00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #ifndef VOTED_PERCEPTRON_H_OCT_30_2005
00067 #define VOTED_PERCEPTRON_H_OCT_30_2005
00068
00069 #include "infer.h"
00070 #include "clause.h"
00071 #include "timer.h"
00072 #include "indextranslator.h"
00073 #include "maxwalksat.h"
00074
// Compile-time switch: when true, the counting and gradient code in this
// header prints per-clause diagnostic counts to cout.
const bool vpdebug = true;
// Smallest magnitude given to an initial log-odds weight; a weight whose
// magnitude would be below this is replaced by EPSILON.
const double EPSILON = 0.00001;
00077
00082 class VotedPerceptron
00083 {
00084 public:
00085
00100 VotedPerceptron(const Array<Inference*>& inferences,
00101 const StringHashArray& nonEvidPredNames,
00102 IndexTranslator* const & idxTrans, const bool& lazyInference,
00103 const bool& rescaleGradient, const bool& withEM)
00104 : domainCnt_(inferences.size()), idxTrans_(idxTrans),
00105 lazyInference_(lazyInference), rescaleGradient_(rescaleGradient),
00106 withEM_(withEM)
00107 {
00108 cout << endl << "Constructing voted perceptron..." << endl << endl;
00109
00110 inferences_.append(inferences);
00111 logOddsPerDomain_.growToSize(domainCnt_);
00112 clauseCntPerDomain_.growToSize(domainCnt_);
00113
00114 for (int i = 0; i < domainCnt_; i++)
00115 {
00116 clauseCntPerDomain_[i] =
00117 inferences_[i]->getState()->getMLN()->getNumClauses();
00118 logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0);
00119 }
00120
00121 totalTrueCnts_.growToSize(domainCnt_);
00122 defaultTrueCnts_.growToSize(domainCnt_);
00123 relevantClausesPerDomain_.growToSize(domainCnt_);
00124
00125
00126 findRelevantClauses(nonEvidPredNames);
00127 findRelevantClausesFormulas();
00128
00129
00130 if (lazyInference_)
00131 {
00132 findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(nonEvidPredNames);
00133
00134 for (int i = 0; i < domainCnt_; i++)
00135 {
00136 const MLN* mln = inferences_[i]->getState()->getMLN();
00137 Array<double>& logOdds = logOddsPerDomain_[i];
00138 assert(mln->getNumClauses() == logOdds.size());
00139 for (int j = 0; j < mln->getNumClauses(); j++)
00140 ((Clause*) mln->getClause(j))->setWt(logOdds[j]);
00141 }
00142 }
00143
00144 else
00145 {
00146 initializeWts();
00147 }
00148
00149
00150 for (int i = 0; i < inferences_.size(); i++)
00151 inferences_[i]->init();
00152 }
00153
00154
00155 ~VotedPerceptron()
00156 {
00157 for (int i = 0; i < trainTrueCnts_.size(); i++)
00158 delete[] trainTrueCnts_[i];
00159 }
00160
00161
00162
00163 void setMeansStdDevs(const int& arrSize, const double* const & priorMeans,
00164 const double* const & priorStdDevs)
00165 {
00166 if (arrSize < 0)
00167 {
00168 usePrior_ = false;
00169 priorMeans_ = NULL;
00170 priorStdDevs_ = NULL;
00171 }
00172 else
00173 {
00174
00175 usePrior_ = true;
00176 priorMeans_ = priorMeans;
00177 priorStdDevs_ = priorStdDevs;
00178
00179
00180
00181
00182 }
00183 }
00184
00185
00186
00187 void learnWeights(double* const & weights, const int& numWeights,
00188 const int& maxIter, const double& learningRate,
00189 const double& momentum, bool initWithLogOdds)
00190 {
00191
00192 memset(weights, 0, numWeights*sizeof(double));
00193
00194 double* averageWeights = new double[numWeights];
00195 double* gradient = new double[numWeights];
00196 double* lastchange = new double[numWeights];
00197
00198
00199 if (initWithLogOdds)
00200 {
00201
00202 if (idxTrans_ == NULL)
00203 {
00204 for (int i = 0; i < domainCnt_; i++)
00205 {
00206 Array<double>& logOdds = logOddsPerDomain_[i];
00207 assert(numWeights == logOdds.size());
00208 for (int j = 0; j < logOdds.size(); j++) weights[j] += logOdds[j];
00209 }
00210 }
00211 else
00212 {
00213 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00214 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00215
00216 Array<int> numLogOdds;
00217 Array<double> wtsForDomain;
00218 numLogOdds.growToSize(numWeights);
00219 wtsForDomain.growToSize(numWeights);
00220
00221 for (int i = 0; i < domainCnt_; i++)
00222 {
00223 memset((int*)numLogOdds.getItems(), 0, numLogOdds.size()*sizeof(int));
00224 memset((double*)wtsForDomain.getItems(), 0,
00225 wtsForDomain.size()*sizeof(double));
00226
00227 Array<double>& logOdds = logOddsPerDomain_[i];
00228
00229
00230
00231 for (int j = 0; j < logOdds.size(); j++)
00232 {
00233 Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];
00234 for (int k = 0; k < idxDivs->size(); k++)
00235 {
00236 wtsForDomain[ (*idxDivs)[k].idx ] += logOdds[j];
00237 numLogOdds[ (*idxDivs)[k].idx ]++;
00238 }
00239 }
00240
00241 for (int j = 0; j < numWeights; j++)
00242 if (numLogOdds[j] > 0) weights[j] += wtsForDomain[j]/numLogOdds[j];
00243 }
00244 }
00245 }
00246
00247
00248 for (int i = 0; i < numWeights; i++)
00249 {
00250 weights[i] /= domainCnt_;
00251 averageWeights[i] = weights[i];
00252 lastchange[i] = 0.0;
00253 }
00254
00255 for (int iter = 1; iter <= maxIter; iter++)
00256 {
00257 cout << endl << "Iteration " << iter << " : " << endl << endl;
00258 cout << "Getting the gradient.. " << endl;
00259 getGradient(weights, gradient, numWeights);
00260 cout << endl;
00261
00262
00263 for (int w = 0; w < numWeights; w++)
00264 {
00265 double wchange = gradient[w] * learningRate + lastchange[w] * momentum;
00266 cout << "clause/formula " << w << ": wtChange = " << wchange;
00267 cout << " oldWt = " << weights[w];
00268 weights[w] += wchange;
00269 lastchange[w] = wchange;
00270 cout << " newWt = " << weights[w];
00271 averageWeights[w] = (iter * averageWeights[w] + weights[w])/(iter + 1);
00272 cout << " averageWt = " << averageWeights[w] << endl;
00273 }
00274
00275 }
00276
00277 cout << endl << "Learned Weights : " << endl;
00278 for (int w = 0; w < numWeights; w++)
00279 {
00280 weights[w] = averageWeights[w];
00281 cout << w << ":" << weights[w] << endl;
00282 }
00283
00284 delete [] averageWeights;
00285 delete [] gradient;
00286 delete [] lastchange;
00287
00288 resetDBs();
00289 }
00290
00291
00292 private:
00293
00297 void resetDBs()
00298 {
00299 if (!lazyInference_)
00300 {
00301 for (int i = 0; i < domainCnt_; i++)
00302 {
00303 VariableState* state = inferences_[i]->getState();
00304 Database* db = state->getDomain()->getDB();
00305
00306 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00307 const Array<TruthValue>* knePredValues = state->getKnePredValues();
00308 db->setValuesToGivenValues(knePreds, knePredValues);
00309
00310 const GroundPredicateHashArray* unePreds = state->getUnePreds();
00311 for (int predno = 0; predno < unePreds->size(); predno++)
00312 db->setValue((*unePreds)[predno], UNKNOWN);
00313 }
00314 }
00315 }
00316
00322 void findRelevantClauses(const StringHashArray& nonEvidPredNames)
00323 {
00324 for (int d = 0; d < domainCnt_; d++)
00325 {
00326 int clauseCnt = clauseCntPerDomain_[d];
00327 Array<bool>& relevantClauses = relevantClausesPerDomain_[d];
00328 relevantClauses.growToSize(clauseCnt);
00329 memset((bool*)relevantClauses.getItems(), false,
00330 relevantClauses.size()*sizeof(bool));
00331 const Domain* domain = inferences_[d]->getState()->getDomain();
00332 const MLN* mln = inferences_[d]->getState()->getMLN();
00333
00334 const Array<IndexClause*>* indclauses;
00335 const Clause* clause;
00336 int predid, clauseid;
00337 for (int i = 0; i < nonEvidPredNames.size(); i++)
00338 {
00339 predid = domain->getPredicateId(nonEvidPredNames[i].c_str());
00340
00341
00342 indclauses = mln->getClausesContainingPred(predid);
00343 if (indclauses)
00344 {
00345 for (int j = 0; j < indclauses->size(); j++)
00346 {
00347 clause = (*indclauses)[j]->clause;
00348 clauseid = mln->findClauseIdx(clause);
00349 relevantClauses[clauseid] = true;
00350
00351 }
00352
00353 }
00354 }
00355 }
00356 }
00357
00358
00359 void findRelevantClausesFormulas()
00360 {
00361 if (idxTrans_ == NULL)
00362 {
00363 Array<bool>& relevantClauses = relevantClausesPerDomain_[0];
00364 relevantClausesFormulas_.growToSize(relevantClauses.size());
00365 for (int i = 0; i < relevantClauses.size(); i++)
00366 relevantClausesFormulas_[i] = relevantClauses[i];
00367 }
00368 else
00369 {
00370 idxTrans_->setRelevantClausesFormulas(relevantClausesFormulas_,
00371 relevantClausesPerDomain_[0]);
00372 cout << "Relevant clauses/formulas:" << endl;
00373 idxTrans_->printRelevantClausesFormulas(cout, relevantClausesFormulas_);
00374 cout << endl;
00375 }
00376 }
00377
00378
00388 void calculateCounts(Array<double>& trueCnt, Array<double>& falseCnt,
00389 const int& domainIdx, const bool& hasUnknownPreds)
00390 {
00391 Clause* clause;
00392 double tmpUnknownCnt;
00393 int clauseCnt = clauseCntPerDomain_[domainIdx];
00394 Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00395 const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00396 const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00397
00398 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00399 {
00400 if (!relevantClauses[clauseno])
00401 {
00402 continue;
00403
00404 }
00405 clause = (Clause*) mln->getClause(clauseno);
00406 clause->getNumTrueFalseUnknownGroundings(domain, domain->getDB(),
00407 hasUnknownPreds,
00408 trueCnt[clauseno],
00409 falseCnt[clauseno],
00410 tmpUnknownCnt);
00411 assert(hasUnknownPreds || (tmpUnknownCnt==0));
00412 }
00413 }
00414
00415
00416 void initializeWts()
00417 {
00418 cout << "Initializing weights ..." << endl;
00419 Array<double *> trainFalseCnts;
00420 trainTrueCnts_.growToSize(domainCnt_);
00421 trainFalseCnts.growToSize(domainCnt_);
00422
00423 for (int i = 0; i < domainCnt_; i++)
00424 {
00425 int clauseCnt = clauseCntPerDomain_[i];
00426 VariableState* state = inferences_[i]->getState();
00427 const GroundPredicateHashArray* unePreds = state->getUnePreds();
00428 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00429
00430 trainTrueCnts_[i] = new double[clauseCnt];
00431 trainFalseCnts[i] = new double[clauseCnt];
00432
00433 int totalPreds = unePreds->size() + knePreds->size();
00434
00435
00436 Array<bool>* unknownPred = new Array<bool>;
00437 unknownPred->growToSize(totalPreds, false);
00438 for (int predno = 0; predno < totalPreds; predno++)
00439 {
00440 GroundPredicate* p;
00441 if (predno < unePreds->size())
00442 p = (*unePreds)[predno];
00443 else
00444 p = (*knePreds)[predno - unePreds->size()];
00445 TruthValue tv = state->getDomain()->getDB()->getValue(p);
00446
00447
00448 if (tv == TRUE)
00449 {
00450 state->setValueOfAtom(predno + 1, true);
00451 p->setTruthValue(true);
00452 }
00453 else
00454 {
00455 state->setValueOfAtom(predno + 1, false);
00456 p->setTruthValue(false);
00457
00458
00459 if (tv == UNKNOWN)
00460 {
00461 (*unknownPred)[predno] = true;
00462 }
00463 }
00464 }
00465
00466 state->initMakeBreakCostWatch(0);
00467
00468 state->getNumClauseGndingsWithUnknown(trainTrueCnts_[i], clauseCnt, true,
00469 unknownPred);
00470
00471
00472 state->getNumClauseGndingsWithUnknown(trainFalseCnts[i], clauseCnt, false,
00473 unknownPred);
00474 delete unknownPred;
00475 if (vpdebug)
00476 {
00477 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00478 {
00479 cout << clauseno << " : tc = " << trainTrueCnts_[i][clauseno]
00480 << " ** fc = " << trainFalseCnts[i][clauseno] << endl;
00481 }
00482 }
00483 }
00484
00485 double tc,fc;
00486 cout << "List of CNF Clauses : " << endl;
00487 for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00488 {
00489 if (!relevantClausesPerDomain_[0][clauseno])
00490 {
00491 for (int i = 0; i < domainCnt_; i++)
00492 {
00493 Array<double>& logOdds = logOddsPerDomain_[i];
00494 logOdds[clauseno] = 0.0;
00495 }
00496 continue;
00497 }
00498
00499 cout << clauseno << ":";
00500 const Clause* clause =
00501 inferences_[0]->getState()->getMLN()->getClause(clauseno);
00502
00503 clause->print(cout, inferences_[0]->getState()->getDomain());
00504 cout << endl;
00505
00506 tc = 0.0; fc = 0.0;
00507 for (int i = 0; i < domainCnt_;i++)
00508 {
00509 tc += trainTrueCnts_[i][clauseno];
00510 fc += trainFalseCnts[i][clauseno];
00511 }
00512
00513
00514
00515
00516 double weight = 0.0;
00517 double totalCnt = tc + fc;
00518
00519 if (totalCnt == 0)
00520 {
00521
00522 weight = EPSILON;
00523 }
00524 else
00525 {
00526 double prob = tc / (tc+fc);
00527 if (prob == 0) prob = 0.00001;
00528 if (prob == 1) prob = 0.99999;
00529 weight = log(prob/(1-prob));
00530
00531
00532
00533
00534
00535
00536 if (abs(weight) < EPSILON) weight = EPSILON;
00537
00538 }
00539 for (int i = 0; i < domainCnt_; i++)
00540 {
00541 Array<double>& logOdds = logOddsPerDomain_[i];
00542 logOdds[clauseno] = weight;
00543 }
00544 }
00545 cout << endl;
00546
00547 for (int i = 0; i < trainFalseCnts.size(); i++)
00548 delete[] trainFalseCnts[i];
00549 }
00550
00559 void findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(
00560 const StringHashArray& nonEvidPredNames)
00561 {
00562 bool hasUnknownPreds;
00563 Array<Array<double> > totalFalseCnts;
00564 Array<Array<double> > defaultFalseCnts;
00565 totalFalseCnts.growToSize(domainCnt_);
00566 defaultFalseCnts.growToSize(domainCnt_);
00567
00568 Array<Predicate*> gpreds;
00569 Array<Predicate*> ppreds;
00570 Array<TruthValue> gpredValues;
00571 Array<TruthValue> tmpValues;
00572
00573 for (int i = 0; i < domainCnt_; i++)
00574 {
00575 const Domain* domain = inferences_[i]->getState()->getDomain();
00576 int clauseCnt = clauseCntPerDomain_[i];
00577 domain->getDB()->setPerformingInference(false);
00578
00579
00580 gpreds.clear();
00581 gpredValues.clear();
00582 tmpValues.clear();
00583 for (int predno = 0; predno < nonEvidPredNames.size(); predno++)
00584 {
00585 ppreds.clear();
00586 int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
00587 Predicate::createAllGroundings(predid, domain, ppreds);
00588
00589 gpreds.append(ppreds);
00590 }
00591
00592 domain->getDB()->alterTruthValue(&gpreds, UNKNOWN, FALSE, &gpredValues);
00593
00594
00595
00596
00597 hasUnknownPreds = false;
00598
00599 Array<double>& trueCnt = totalTrueCnts_[i];
00600 Array<double>& falseCnt = totalFalseCnts[i];
00601 trueCnt.growToSize(clauseCnt);
00602 falseCnt.growToSize(clauseCnt);
00603 calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00604
00605
00606
00607 hasUnknownPreds = true;
00608
00609 domain->getDB()->setValuesToUnknown(&gpreds, &tmpValues);
00610
00611 Array<double>& dTrueCnt = defaultTrueCnts_[i];
00612 Array<double>& dFalseCnt = defaultFalseCnts[i];
00613 dTrueCnt.growToSize(clauseCnt);
00614 dFalseCnt.growToSize(clauseCnt);
00615 calculateCounts(dTrueCnt, dFalseCnt, i, hasUnknownPreds);
00616
00617
00618
00619
00620
00621
00622
00623
00624 for (int predno = 0; predno < gpreds.size(); predno++)
00625 delete gpreds[predno];
00626
00627 domain->getDB()->setPerformingInference(true);
00628 }
00629
00630
00631 for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00632 {
00633 double tc = 0;
00634 double fc = 0;
00635 for (int i = 0; i < domainCnt_; i++)
00636 {
00637 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00638 Array<double>& logOdds = logOddsPerDomain_[i];
00639
00640 if (!relevantClauses[clauseno]) { logOdds[clauseno] = 0; continue; }
00641 tc += totalTrueCnts_[i][clauseno] - defaultTrueCnts_[i][clauseno];
00642 fc += totalFalseCnts[i][clauseno] - defaultFalseCnts[i][clauseno];
00643
00644 if (vpdebug)
00645 cout << clauseno << " : tc = " << tc << " ** fc = "<< fc <<endl;
00646 }
00647
00648 double weight = 0.0;
00649
00650 if ((tc + fc) == 0)
00651 {
00652
00653 }
00654 else
00655 {
00656 double prob = tc / (tc+fc);
00657 if (prob == 0) prob = 0.00001;
00658 if (prob == 1) prob = 0.99999;
00659 weight = log(prob / (1-prob));
00660
00661
00662
00663 if (abs(weight) < EPSILON) weight = EPSILON;
00664
00665 }
00666
00667
00668 for(int i = 0; i < domainCnt_; i++)
00669 {
00670 Array<double>& logOdds = logOddsPerDomain_[i];
00671 logOdds[clauseno] = weight;
00672 }
00673 }
00674 }
00675
00676
00680 void infer()
00681 {
00682 for (int i = 0; i < domainCnt_; i++)
00683 {
00684 VariableState* state = inferences_[i]->getState();
00685 state->setGndClausesWtsToSumOfParentWts();
00686
00687
00688 state->init();
00689 inferences_[i]->infer();
00690 state->saveLowStateToGndPreds();
00691 }
00692 }
00693
00698 void fillInMissingValues()
00699 {
00700 assert(withEM_);
00701 cout << "Filling in missing data ..." << endl;
00702
00703
00704 Array<Array<TruthValue> > ueValues;
00705 ueValues.growToSize(domainCnt_);
00706 for (int i = 0; i < domainCnt_; i++)
00707 {
00708 VariableState* state = inferences_[i]->getState();
00709 const Domain* domain = state->getDomain();
00710 const GroundPredicateHashArray* knePreds = state->getKnePreds();
00711 const Array<TruthValue>* knePredValues = state->getKnePredValues();
00712
00713
00714 domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00715
00716
00717 state->setGndClausesWtsToSumOfParentWts();
00718
00719 state->init();
00720 inferences_[i]->infer();
00721 state->saveLowStateToGndPreds();
00722
00723 if (vpdebug)
00724 {
00725 cout << "Inferred following values: " << endl;
00726 inferences_[i]->printProbabilities(cout);
00727 }
00728
00729
00730 if (lazyInference_)
00731 {
00732 Array<double>& trueCnt = totalTrueCnts_[i];
00733 Array<double> falseCnt;
00734 bool hasUnknownPreds = false;
00735 falseCnt.growToSize(trueCnt.size());
00736 calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00737 }
00738 else
00739 {
00740 int clauseCnt = clauseCntPerDomain_[i];
00741 state->initMakeBreakCostWatch(0);
00742
00743 const Array<double>* clauseTrueCnts =
00744 inferences_[i]->getClauseTrueCnts();
00745 assert(clauseTrueCnts->size() == clauseCnt);
00746 for (int j = 0; j < clauseCnt; j++)
00747 trainTrueCnts_[i][j] = (*clauseTrueCnts)[j];
00748 }
00749
00750
00751
00752
00753
00754 Array<TruthValue> tmpValues;
00755 tmpValues.growToSize(knePreds->size());
00756 domain->getDB()->setValuesToUnknown(knePreds, &tmpValues);
00757 }
00758 cout << "Done filling in missing data" << endl;
00759 }
00760
00761 void getGradientForDomain(double* const & gradient, const int& domainIdx)
00762 {
00763 Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00764 int clauseCnt = clauseCntPerDomain_[domainIdx];
00765 double* trainCnts = NULL;
00766 double* inferredCnts = NULL;
00767 double* clauseTrainCnts = new double[clauseCnt];
00768 double* clauseInferredCnts = new double[clauseCnt];
00769 double trainCnt, inferredCnt;
00770 Array<double>& totalTrueCnts = totalTrueCnts_[domainIdx];
00771 Array<double>& defaultTrueCnts = defaultTrueCnts_[domainIdx];
00772 const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00773 const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00774
00775 memset(clauseTrainCnts, 0, clauseCnt*sizeof(double));
00776 memset(clauseInferredCnts, 0, clauseCnt*sizeof(double));
00777
00778 if (!lazyInference_)
00779 {
00780 if (!inferredCnts) inferredCnts = new double[clauseCnt];
00781
00782 const Array<double>* clauseTrueCnts =
00783 inferences_[domainIdx]->getClauseTrueCnts();
00784 assert(clauseTrueCnts->size() == clauseCnt);
00785 for (int i = 0; i < clauseCnt; i++)
00786 inferredCnts[i] = (*clauseTrueCnts)[i];
00787 trainCnts = trainTrueCnts_[domainIdx];
00788 }
00789
00790
00791 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00792 {
00793 if (!relevantClauses[clauseno]) continue;
00794
00795 if (lazyInference_)
00796 {
00797 Clause* clause = (Clause*) mln->getClause(clauseno);
00798
00799 trainCnt = totalTrueCnts[clauseno];
00800 inferredCnt =
00801 clause->getNumTrueGroundings(domain, domain->getDB(), false);
00802 trainCnt -= defaultTrueCnts[clauseno];
00803 inferredCnt -= defaultTrueCnts[clauseno];
00804
00805 clauseTrainCnts[clauseno] += trainCnt;
00806 clauseInferredCnts[clauseno] += inferredCnt;
00807 }
00808 else
00809 {
00810 clauseTrainCnts[clauseno] += trainCnts[clauseno];
00811 clauseInferredCnts[clauseno] += inferredCnts[clauseno];
00812 }
00813
00814 }
00815
00816 if (vpdebug)
00817 {
00818 cout << "net counts : " << endl;
00819 cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00820 }
00821
00822 for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00823 {
00824 if (!relevantClauses[clauseno]) continue;
00825
00826 if (vpdebug)
00827 cout << clauseno << ":\t\t" << clauseTrainCnts[clauseno] << "\t\t\t\t"
00828 << clauseInferredCnts[clauseno] << endl;
00829 if (rescaleGradient_ && clauseTrainCnts[clauseno] > 0)
00830 {
00831 gradient[clauseno] +=
00832 (clauseTrainCnts[clauseno] - clauseInferredCnts[clauseno])
00833 / clauseTrainCnts[clauseno];
00834 }
00835 else
00836 {
00837 gradient[clauseno] += clauseTrainCnts[clauseno] -
00838 clauseInferredCnts[clauseno];
00839 }
00840 }
00841
00842 delete[] clauseTrainCnts;
00843 delete[] clauseInferredCnts;
00844 }
00845
00846
00847
00848 void getGradient(double* const & weights, double* const & gradient,
00849 const int numWts)
00850 {
00851
00852
00853
00854
00855
00856 if (idxTrans_ == NULL)
00857 {
00858 int clauseCnt = clauseCntPerDomain_[0];
00859 for (int i = 0; i < domainCnt_; i++)
00860 {
00861 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00862 assert(clauseCntPerDomain_[i] == clauseCnt);
00863 const MLN* mln = inferences_[i]->getState()->getMLN();
00864
00865 for (int j = 0; j < clauseCnt; j++)
00866 {
00867 Clause* c = (Clause*) mln->getClause(j);
00868 if (relevantClauses[j]) c->setWt(weights[j]);
00869 else c->setWt(0);
00870 }
00871 }
00872 }
00873 else
00874 {
00875 Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00876 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00877 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00878
00879 for (int i = 0; i < domainCnt_; i++)
00880 {
00881 Array<double>& wts = (*wtsPerDomain)[i];
00882 memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00883
00884
00885 for (int j = 0; j < wts.size(); j++)
00886 {
00887 Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];
00888 for (int k = 0; k < idxDivs->size(); k++)
00889 wts[j] += weights[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00890 }
00891 }
00892
00893 for (int i = 0; i < domainCnt_; i++)
00894 {
00895 Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00896 int clauseCnt = clauseCntPerDomain_[i];
00897 Array<double>& wts = (*wtsPerDomain)[i];
00898 assert(wts.size() == clauseCnt);
00899 const MLN* mln = inferences_[i]->getState()->getMLN();
00900
00901 for (int j = 0; j < clauseCnt; j++)
00902 {
00903 Clause* c = (Clause*) mln->getClause(j);
00904 if (relevantClauses[j]) c->setWt(wts[j]);
00905 else c->setWt(0);
00906 }
00907 }
00908 }
00909
00910
00911 if (withEM_) fillInMissingValues();
00912 cout << "Running inference ..." << endl;
00913 infer();
00914 cout << "Done with inference" << endl;
00915
00916
00917 memset(gradient, 0, numWts*sizeof(double));
00918
00919
00920 if (idxTrans_ == NULL)
00921 {
00922 for (int i = 0; i < domainCnt_; i++)
00923 {
00924
00925 getGradientForDomain(gradient, i);
00926 }
00927 }
00928 else
00929 {
00930
00931 Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00932 const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain
00933 = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00934
00935 for (int i = 0; i < domainCnt_; i++)
00936 {
00937
00938
00939 Array<double>& grads = (*gradsPerDomain)[i];
00940 memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00941
00942 getGradientForDomain((double*)grads.getItems(), i);
00943
00944
00945 assert(grads.size() == clauseCntPerDomain_[i]);
00946 for (int j = 0; j < grads.size(); j++)
00947 {
00948 Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];
00949 for (int k = 0; k < idxDivs->size(); k++)
00950 gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00951 }
00952 }
00953 }
00954
00955
00956 if (usePrior_)
00957 {
00958 for (int i = 0; i < numWts; i++)
00959 {
00960 if (!relevantClausesFormulas_[i]) continue;
00961 double priorDerivative = -(weights[i]-priorMeans_[i])/
00962 (priorStdDevs_[i]*priorStdDevs_[i]);
00963
00964
00965 gradient[i] += priorDerivative;
00966
00967 }
00968 }
00969 }
00970
00971
00972 private:
00973 int domainCnt_;
00974
00975
00976 Array<Array<double> > logOddsPerDomain_;
00977 Array<int> clauseCntPerDomain_;
00978
00979
00980 Array<Array<double> > totalTrueCnts_;
00981 Array<Array<double> > defaultTrueCnts_;
00982
00983 Array<Array<bool> > relevantClausesPerDomain_;
00984 Array<bool> relevantClausesFormulas_;
00985
00986
00987 Array<double*> trainTrueCnts_;
00988
00989 bool usePrior_;
00990 const double* priorMeans_, * priorStdDevs_;
00991
00992 IndexTranslator* idxTrans_;
00993
00994 bool lazyInference_;
00995 bool rescaleGradient_;
00996 bool isQueryEvidence_;
00997
00998 Array<Inference*> inferences_;
00999
01000
01001 bool withEM_;
01002 };
01003
01004
01005 #endif