votedperceptron.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner and Hoifung Poon.
00006  * 
00007  * Copyright [2004-07] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00009  * Poon. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner and Hoifung
00032  * Poon in the Department of Computer Science and
00033  * Engineering at the University of Washington".
00034  * 
00035  * 4. Your publications acknowledge the use or
00036  * contribution made by the Software to your research
00037  * using the following citation(s): 
00038  * Stanley Kok, Parag Singla, Matthew Richardson and
00039  * Pedro Domingos (2005). "The Alchemy System for
00040  * Statistical Relational AI", Technical Report,
00041  * Department of Computer Science and Engineering,
00042  * University of Washington, Seattle, WA.
00043  * http://www.cs.washington.edu/ai/alchemy.
00044  * 
00045  * 5. Neither the name of the University of Washington nor
00046  * the names of its contributors may be used to endorse or
00047  * promote products derived from this software without
00048  * specific prior written permission.
00049  * 
00050  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00051  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00052  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00053  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00054  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00055  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00056  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00057  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00058  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00059  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00060  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00061  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00062  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00063  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00064  * 
00065  */
00066 #ifndef VOTED_PERCEPTRON_H_OCT_30_2005
00067 #define VOTED_PERCEPTRON_H_OCT_30_2005
00068 
00069 #include "infer.h"
00070 #include "clause.h"
00071 #include "timer.h"
00072 #include "indextranslator.h"
00073 #include "maxwalksat.h"
00074 
// Switch for the verbose per-clause count/weight diagnostics printed by
// VotedPerceptron.
const bool vpdebug = true;
// Small positive value substituted for initial clause weights whose
// magnitude falls below it (and for clauses with no groundings in the
// eager initialization).
const double EPSILON = 1e-5;
00077 
00082 class VotedPerceptron 
00083 {
00084  public:
00085 
00100   VotedPerceptron(const Array<Inference*>& inferences,
00101                   const StringHashArray& nonEvidPredNames,
00102                   IndexTranslator* const & idxTrans, const bool& lazyInference,
00103                   const bool& rescaleGradient, const bool& withEM)
00104     : domainCnt_(inferences.size()), idxTrans_(idxTrans),
00105       lazyInference_(lazyInference), rescaleGradient_(rescaleGradient),
00106       withEM_(withEM)
00107   { 
00108     cout << endl << "Constructing voted perceptron..." << endl << endl;
00109 
00110     inferences_.append(inferences);
00111     logOddsPerDomain_.growToSize(domainCnt_);
00112     clauseCntPerDomain_.growToSize(domainCnt_);
00113     
00114     for (int i = 0; i < domainCnt_; i++)
00115     {
00116       clauseCntPerDomain_[i] =
00117         inferences_[i]->getState()->getMLN()->getNumClauses();
00118       logOddsPerDomain_[i].growToSize(clauseCntPerDomain_[i], 0);
00119     }
00120 
00121     totalTrueCnts_.growToSize(domainCnt_);
00122     defaultTrueCnts_.growToSize(domainCnt_);
00123     relevantClausesPerDomain_.growToSize(domainCnt_);
00124     //relevantClausesFormulas_ is set in findRelevantClausesFormulas()
00125 
00126     findRelevantClauses(nonEvidPredNames);
00127     findRelevantClausesFormulas();
00128 
00129       // Initialize the clause wts for lazy version
00130     if (lazyInference_)
00131     {
00132       findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(nonEvidPredNames);
00133     
00134       for (int i = 0; i < domainCnt_; i++)
00135       {
00136         const MLN* mln = inferences_[i]->getState()->getMLN();
00137         Array<double>& logOdds = logOddsPerDomain_[i];
00138         assert(mln->getNumClauses() == logOdds.size());
00139         for (int j = 0; j < mln->getNumClauses(); j++)
00140           ((Clause*) mln->getClause(j))->setWt(logOdds[j]);
00141       }
00142     }
00143       // Initialize the clause wts for eager version
00144     else
00145     {      
00146       initializeWts();
00147     }
00148     
00149       // Initialize the inference / state
00150     for (int i = 0; i < inferences_.size(); i++)
00151       inferences_[i]->init();
00152   }
00153 
00154 
00155   ~VotedPerceptron() 
00156   {
00157     for (int i = 0; i < trainTrueCnts_.size(); i++)
00158       delete[] trainTrueCnts_[i];
00159   }
00160 
00161 
00162     // set the prior means and std devs.
00163   void setMeansStdDevs(const int& arrSize, const double* const & priorMeans, 
00164                        const double* const & priorStdDevs) 
00165   {
00166     if (arrSize < 0) 
00167     {
00168       usePrior_ = false;
00169       priorMeans_ = NULL;
00170       priorStdDevs_ = NULL;
00171     } 
00172     else 
00173     {
00174       //cout << "arr size = " << arrSize<<", clause count = "<<clauseCnt_<<endl;
00175       usePrior_ = true;
00176       priorMeans_ = priorMeans;
00177       priorStdDevs_ = priorStdDevs;
00178 
00179       //cout << "\t\t Mean \t\t Std Deviation" << endl;
00180       //for (int i = 0; i < arrSize; i++) 
00181       //  cout << i << "\t\t" << priorMeans_[i]<<"\t\t"<<priorStdDevs_[i]<<endl;
00182     }
00183   }
00184 
00185 
00186     // learn the weights
00187   void learnWeights(double* const & weights, const int& numWeights,
00188                     const int& maxIter, const double& learningRate,
00189                     const double& momentum, bool initWithLogOdds) 
00190   {
00191     //cout << "Learning weights discriminatively... " << endl;
00192     memset(weights, 0, numWeights*sizeof(double));
00193 
00194     double* averageWeights = new double[numWeights];
00195     double* gradient = new double[numWeights];
00196     double* lastchange = new double[numWeights];
00197 
00198       // Set the initial weight to the average log odds across domains/databases
00199     if (initWithLogOdds)
00200     {
00201         // If there is one db or the clauses for multiple databases line up
00202       if (idxTrans_ == NULL)
00203       {
00204         for (int i = 0; i < domainCnt_; i++)
00205         {
00206           Array<double>& logOdds = logOddsPerDomain_[i];
00207           assert(numWeights == logOdds.size());
00208           for (int j = 0; j < logOdds.size(); j++) weights[j] += logOdds[j];
00209         }
00210       }
00211       else
00212       { //the clauses for multiple databases do not line up
00213         const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00214           = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00215 
00216         Array<int> numLogOdds; 
00217         Array<double> wtsForDomain;
00218         numLogOdds.growToSize(numWeights);
00219         wtsForDomain.growToSize(numWeights);
00220       
00221         for (int i = 0; i < domainCnt_; i++)
00222         {
00223           memset((int*)numLogOdds.getItems(), 0, numLogOdds.size()*sizeof(int));
00224           memset((double*)wtsForDomain.getItems(), 0,
00225                  wtsForDomain.size()*sizeof(double));
00226 
00227           Array<double>& logOdds = logOddsPerDomain_[i];
00228         
00229             // Map the each log odds of a clause to the weight of a
00230             // clause/formula
00231           for (int j = 0; j < logOdds.size(); j++)
00232           {
00233             Array<IdxDiv>* idxDivs =(*cIdxToCFIdxsPerDomain)[i][j];          
00234             for (int k = 0; k < idxDivs->size(); k++)
00235             {
00236               wtsForDomain[ (*idxDivs)[k].idx ] += logOdds[j];
00237               numLogOdds[ (*idxDivs)[k].idx ]++;
00238             }
00239           }
00240 
00241           for (int j = 0; j < numWeights; j++)
00242             if (numLogOdds[j] > 0) weights[j] += wtsForDomain[j]/numLogOdds[j];  
00243         }
00244       }
00245     }
00246 
00247       // Initialize weights, averageWeights, lastchange
00248     for (int i = 0; i < numWeights; i++) 
00249     {      
00250       weights[i] /= domainCnt_;
00251       averageWeights[i] = weights[i];
00252       lastchange[i] = 0.0;
00253     }
00254 
00255     for (int iter = 1; iter <= maxIter; iter++) 
00256     {
00257       cout << endl << "Iteration " << iter << " : " << endl << endl;
00258       cout << "Getting the gradient.. " << endl;
00259       getGradient(weights, gradient, numWeights);
00260       cout << endl; 
00261 
00262         // Add gradient to weights
00263       for (int w = 0; w < numWeights; w++) 
00264       {
00265         double wchange = gradient[w] * learningRate + lastchange[w] * momentum;
00266         cout << "clause/formula " << w << ": wtChange = " << wchange;
00267         cout << "  oldWt = " << weights[w];
00268         weights[w] += wchange;
00269         lastchange[w] = wchange;
00270         cout << "  newWt = " << weights[w];
00271         averageWeights[w] = (iter * averageWeights[w] + weights[w])/(iter + 1);
00272         cout << "  averageWt = " << averageWeights[w] << endl;
00273       }
00274       // done with an iteration
00275     }
00276     
00277     cout << endl << "Learned Weights : " << endl;
00278     for (int w = 0; w < numWeights; w++) 
00279     {
00280       weights[w] = averageWeights[w];
00281       cout << w << ":" << weights[w] << endl;
00282     }
00283 
00284     delete [] averageWeights;
00285     delete [] gradient;
00286     delete [] lastchange;
00287     
00288     resetDBs();
00289   }
00290  
00291  
00292  private:
00293  
00297   void resetDBs() 
00298   {
00299     if (!lazyInference_)
00300     {
00301       for (int i = 0; i < domainCnt_; i++) 
00302       {
00303         VariableState* state = inferences_[i]->getState();
00304         Database* db = state->getDomain()->getDB();
00305           // Change known NE to original values
00306         const GroundPredicateHashArray* knePreds = state->getKnePreds();
00307         const Array<TruthValue>* knePredValues = state->getKnePredValues();      
00308         db->setValuesToGivenValues(knePreds, knePredValues);
00309           // Set unknown NE back to UKNOWN
00310         const GroundPredicateHashArray* unePreds = state->getUnePreds();
00311         for (int predno = 0; predno < unePreds->size(); predno++) 
00312           db->setValue((*unePreds)[predno], UNKNOWN);
00313       }
00314     }
00315   }
00316 
00322   void findRelevantClauses(const StringHashArray& nonEvidPredNames) 
00323   {
00324     for (int d = 0; d < domainCnt_; d++)
00325     {
00326       int clauseCnt = clauseCntPerDomain_[d];
00327       Array<bool>& relevantClauses = relevantClausesPerDomain_[d];
00328       relevantClauses.growToSize(clauseCnt);
00329       memset((bool*)relevantClauses.getItems(), false, 
00330              relevantClauses.size()*sizeof(bool));
00331       const Domain* domain = inferences_[d]->getState()->getDomain();
00332       const MLN* mln = inferences_[d]->getState()->getMLN();
00333     
00334       const Array<IndexClause*>* indclauses;
00335       const Clause* clause;
00336       int predid, clauseid;
00337       for (int i = 0; i < nonEvidPredNames.size(); i++)
00338       {
00339         predid = domain->getPredicateId(nonEvidPredNames[i].c_str());
00340         //cout << "finding the relevant clauses for predid = " << predid 
00341         //     << " in domain " << d << endl;
00342         indclauses = mln->getClausesContainingPred(predid);
00343         if (indclauses) 
00344         {
00345           for (int j = 0; j < indclauses->size(); j++) 
00346           {
00347             clause = (*indclauses)[j]->clause;                  
00348             clauseid = mln->findClauseIdx(clause);
00349             relevantClauses[clauseid] = true;
00350             //cout << clauseid << " ";
00351           }
00352           //cout<<endl;
00353         }
00354       }    
00355     }
00356   }
00357 
00358   
00359   void findRelevantClausesFormulas()
00360   {
00361     if (idxTrans_ == NULL)
00362     {
00363       Array<bool>& relevantClauses = relevantClausesPerDomain_[0];
00364       relevantClausesFormulas_.growToSize(relevantClauses.size());
00365       for (int i = 0; i < relevantClauses.size(); i++)
00366         relevantClausesFormulas_[i] = relevantClauses[i];
00367     }
00368     else
00369     {
00370       idxTrans_->setRelevantClausesFormulas(relevantClausesFormulas_,
00371                                             relevantClausesPerDomain_[0]);
00372       cout << "Relevant clauses/formulas:" << endl;
00373       idxTrans_->printRelevantClausesFormulas(cout, relevantClausesFormulas_);
00374       cout << endl;
00375     }
00376   }
00377 
00378 
00388   void calculateCounts(Array<double>& trueCnt, Array<double>& falseCnt,
00389                        const int& domainIdx, const bool& hasUnknownPreds) 
00390   {
00391     Clause* clause;
00392     double tmpUnknownCnt;
00393     int clauseCnt = clauseCntPerDomain_[domainIdx];
00394     Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00395     const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00396     const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00397 
00398     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00399     {
00400       if (!relevantClauses[clauseno]) 
00401       {
00402         continue;
00403         //cout << "\n\nthis is an irrelevant clause.." << endl;
00404       }
00405       clause = (Clause*) mln->getClause(clauseno);
00406       clause->getNumTrueFalseUnknownGroundings(domain, domain->getDB(), 
00407                                                hasUnknownPreds,
00408                                                trueCnt[clauseno],
00409                                                falseCnt[clauseno],
00410                                                tmpUnknownCnt);
00411       assert(hasUnknownPreds || (tmpUnknownCnt==0));
00412     }
00413   }
00414 
00415 
00416   void initializeWts()
00417   {
00418     cout << "Initializing weights ..." << endl;
00419     Array<double *> trainFalseCnts;
00420     trainTrueCnts_.growToSize(domainCnt_);
00421     trainFalseCnts.growToSize(domainCnt_);
00422   
00423     for (int i = 0; i < domainCnt_; i++)
00424     {
00425       int clauseCnt = clauseCntPerDomain_[i];
00426       VariableState* state = inferences_[i]->getState();
00427       const GroundPredicateHashArray* unePreds = state->getUnePreds();
00428       const GroundPredicateHashArray* knePreds = state->getKnePreds();
00429 
00430       trainTrueCnts_[i] = new double[clauseCnt];
00431       trainFalseCnts[i] = new double[clauseCnt];
00432 
00433       int totalPreds = unePreds->size() + knePreds->size();
00434         // Used to store gnd preds to be ignored in the count because they are
00435         // UNKNOWN
00436       Array<bool>* unknownPred = new Array<bool>;
00437       unknownPred->growToSize(totalPreds, false);
00438       for (int predno = 0; predno < totalPreds; predno++) 
00439       {
00440         GroundPredicate* p;
00441         if (predno < unePreds->size())
00442           p = (*unePreds)[predno];
00443         else
00444           p = (*knePreds)[predno - unePreds->size()];
00445         TruthValue tv = state->getDomain()->getDB()->getValue(p);
00446 
00447         //assert(tv != UNKNOWN);
00448         if (tv == TRUE)
00449         {
00450           state->setValueOfAtom(predno + 1, true);
00451           p->setTruthValue(true);
00452         }
00453         else
00454         {
00455           state->setValueOfAtom(predno + 1, false);
00456           p->setTruthValue(false);
00457             // Can have unknown truth values when using EM. We want to ignore
00458             // these when performing the counts
00459           if (tv == UNKNOWN)
00460           {
00461             (*unknownPred)[predno] = true;
00462           }
00463         }
00464       }
00465 
00466       state->initMakeBreakCostWatch(0);
00467       //cout<<"getting true cnts => "<<endl;
00468       state->getNumClauseGndingsWithUnknown(trainTrueCnts_[i], clauseCnt, true,
00469                                            unknownPred);
00470       //cout<<endl;
00471       //cout<<"getting false cnts => "<<endl;
00472       state->getNumClauseGndingsWithUnknown(trainFalseCnts[i], clauseCnt, false,
00473                                             unknownPred);
00474       delete unknownPred;
00475       if (vpdebug)
00476       {
00477         for (int clauseno = 0; clauseno < clauseCnt; clauseno++)
00478         {
00479           cout << clauseno << " : tc = " << trainTrueCnts_[i][clauseno]
00480                << " ** fc = " << trainFalseCnts[i][clauseno] << endl;
00481         }
00482       }
00483     }
00484 
00485     double tc,fc;
00486     cout << "List of CNF Clauses : " << endl;
00487     for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++)
00488     {
00489       if (!relevantClausesPerDomain_[0][clauseno])
00490       {
00491         for (int i = 0; i < domainCnt_; i++)
00492         {
00493           Array<double>& logOdds = logOddsPerDomain_[i];
00494           logOdds[clauseno] = 0.0;
00495         }
00496         continue;
00497       }
00498       //cout << endl << endl;
00499       cout << clauseno << ":";
00500       const Clause* clause =
00501         inferences_[0]->getState()->getMLN()->getClause(clauseno);
00502       //cout << (*fncArr)[clauseno]->formula <<endl;
00503       clause->print(cout, inferences_[0]->getState()->getDomain());
00504       cout << endl;
00505       
00506       tc = 0.0; fc = 0.0;
00507       for (int i = 0; i < domainCnt_;i++)
00508       {
00509         tc += trainTrueCnts_[i][clauseno];
00510         fc += trainFalseCnts[i][clauseno];
00511       }
00512         
00513       //cout << "true count  = " << tc << endl;
00514       //cout << "false count = " << fc << endl;
00515         
00516       double weight = 0.0;
00517       double totalCnt = tc + fc;
00518                 
00519       if (totalCnt == 0) 
00520       {
00521         //cout << "NOTE: Total count is 0 for clause " << clauseno << endl;
00522         weight = EPSILON;
00523       } 
00524       else 
00525       {
00526         double prob =  tc / (tc+fc);
00527         if (prob == 0) prob = 0.00001;
00528         if (prob == 1) prob = 0.99999;
00529         weight = log(prob/(1-prob));
00530           //if weight exactly equals 0, make it small non zero, so that clause  
00531           //is not ignored during the construction of the MRF
00532         //if(weight == 0) weight = 0.0001;
00533           //commented above - make sure all weights are positive in the
00534           //beginning
00535         //if(weight < EPSILON) weight = EPSILON;
00536         if (abs(weight) < EPSILON) weight = EPSILON;
00537           //cout << "Prob " << prob << " becomes weight of " << weight << endl;
00538       }
00539       for (int i = 0; i < domainCnt_; i++) 
00540       {
00541         Array<double>& logOdds = logOddsPerDomain_[i];
00542         logOdds[clauseno] = weight;
00543       }
00544     }
00545     cout << endl;
00546     
00547     for (int i = 0; i < trainFalseCnts.size(); i++)
00548       delete[] trainFalseCnts[i];
00549   }
00550 
00559   void findCountsInitializeWtsAndSetNonEvidPredsToUnknownInDB(
00560                                        const StringHashArray& nonEvidPredNames)
00561   {
00562     bool hasUnknownPreds;
00563     Array<Array<double> > totalFalseCnts; 
00564     Array<Array<double> > defaultFalseCnts;
00565     totalFalseCnts.growToSize(domainCnt_);
00566     defaultFalseCnts.growToSize(domainCnt_);
00567     
00568     Array<Predicate*> gpreds;
00569     Array<Predicate*> ppreds;
00570     Array<TruthValue> gpredValues;
00571     Array<TruthValue> tmpValues;
00572 
00573     for (int i = 0; i < domainCnt_; i++) 
00574     {
00575       const Domain* domain = inferences_[i]->getState()->getDomain();
00576       int clauseCnt = clauseCntPerDomain_[i];
00577       domain->getDB()->setPerformingInference(false);
00578 
00579       //cout << endl << "Getting the counts for the domain " << i << endl;
00580       gpreds.clear();
00581       gpredValues.clear();
00582       tmpValues.clear();
00583       for (int predno = 0; predno < nonEvidPredNames.size(); predno++) 
00584       {
00585         ppreds.clear();
00586         int predid = domain->getPredicateId(nonEvidPredNames[predno].c_str());
00587         Predicate::createAllGroundings(predid, domain, ppreds);
00588         //cout<<"size of gnd for pred " << predid << " = "<<ppreds.size()<<endl;
00589         gpreds.append(ppreds);
00590       }
00591       
00592       domain->getDB()->alterTruthValue(&gpreds, UNKNOWN, FALSE, &gpredValues);
00593           
00594       //cout <<"size of unknown set for domain "<<i<<" = "<<gpreds.size()<<endl;
00595       //cout << "size of the values " << i << " = " << gpredValues.size()<<endl;
00596         
00597       hasUnknownPreds = false;
00598       
00599       Array<double>& trueCnt = totalTrueCnts_[i];
00600       Array<double>& falseCnt = totalFalseCnts[i];
00601       trueCnt.growToSize(clauseCnt);
00602       falseCnt.growToSize(clauseCnt);
00603       calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00604 
00605       //cout << "got the total counts..\n\n\n" << endl;
00606       
00607       hasUnknownPreds = true;
00608 
00609       domain->getDB()->setValuesToUnknown(&gpreds, &tmpValues);
00610 
00611       Array<double>& dTrueCnt = defaultTrueCnts_[i];
00612       Array<double>& dFalseCnt = defaultFalseCnts[i];
00613       dTrueCnt.growToSize(clauseCnt);
00614       dFalseCnt.growToSize(clauseCnt);
00615       calculateCounts(dTrueCnt, dFalseCnt, i, hasUnknownPreds);
00616 
00617       //commented out: no need to revert the grounded non-evidence predicates
00618       //               to their initial values because we want to set ALL of
00619       //               them to UNKNOWN
00620       //assert(gpreds.size() == gpredValues.size());
00621       //domain->getDB()->setValuesToGivenValues(&gpreds, &gpredValues);
00622           
00623       //cout << "the ground predicates are :" << endl;
00624       for (int predno = 0; predno < gpreds.size(); predno++) 
00625         delete gpreds[predno];
00626 
00627       domain->getDB()->setPerformingInference(true);
00628     }
00629     //cout << endl << endl;
00630     //cout << "got the default counts..." << endl;     
00631     for (int clauseno = 0; clauseno < clauseCntPerDomain_[0]; clauseno++) 
00632     {
00633       double tc = 0;
00634       double fc = 0;
00635       for (int i = 0; i < domainCnt_; i++) 
00636       {
00637         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00638         Array<double>& logOdds = logOddsPerDomain_[i];
00639       
00640         if (!relevantClauses[clauseno]) { logOdds[clauseno] = 0; continue; }
00641         tc += totalTrueCnts_[i][clauseno] - defaultTrueCnts_[i][clauseno];
00642         fc += totalFalseCnts[i][clauseno] - defaultFalseCnts[i][clauseno];
00643 
00644         if (vpdebug)
00645           cout << clauseno << " : tc = " << tc << " ** fc = "<< fc <<endl;      
00646       }
00647       
00648       double weight = 0.0;
00649 
00650       if ((tc + fc) == 0) 
00651       {
00652         //cout << "NOTE: Total count is 0 for clause " << clauseno << endl;
00653       } 
00654       else 
00655       {
00656         double prob = tc / (tc+fc);
00657         if (prob == 0) prob = 0.00001;
00658         if (prob == 1) prob = 0.99999;
00659         weight = log(prob / (1-prob));
00660             //if weight exactly equals 0, make it small non zero, so that clause
00661             //is not ignored during the construction of the MRF
00662         //if (weight == 0) weight = 0.0001;
00663         if (abs(weight) < EPSILON) weight = EPSILON;
00664           //cout << "Prob " << prob << " becomes weight of " << weight << endl;
00665       }
00666       
00667         // Set logOdds in all domains to the weight calculated
00668       for(int i = 0; i < domainCnt_; i++) 
00669       { 
00670         Array<double>& logOdds = logOddsPerDomain_[i];
00671         logOdds[clauseno] = weight;
00672       }
00673     }
00674   }
00675  
00676   
00680   void infer() 
00681   {
00682     for (int i = 0; i < domainCnt_; i++) 
00683     {
00684       VariableState* state = inferences_[i]->getState();
00685       state->setGndClausesWtsToSumOfParentWts();
00686       //inferences_[i]->init();
00687         // MWS: Search is started from state at end of last iteration
00688       state->init();
00689       inferences_[i]->infer();
00690       state->saveLowStateToGndPreds();
00691     }
00692   }
00693 
00698   void fillInMissingValues()
00699   {
00700     assert(withEM_);
00701     cout << "Filling in missing data ..." << endl;
00702       // Get values of initial unknown preds by producing MAP state of
00703       // unknown preds given known evidence and non-evidence preds (VPEM)
00704     Array<Array<TruthValue> > ueValues;
00705     ueValues.growToSize(domainCnt_);
00706     for (int i = 0; i < domainCnt_; i++)
00707     {
00708       VariableState* state = inferences_[i]->getState();
00709       const Domain* domain = state->getDomain();
00710       const GroundPredicateHashArray* knePreds = state->getKnePreds();
00711       const Array<TruthValue>* knePredValues = state->getKnePredValues();
00712 
00713         // Mark known non-evidence preds as evidence
00714       domain->getDB()->setValuesToGivenValues(knePreds, knePredValues);
00715 
00716         // Infer missing values
00717       state->setGndClausesWtsToSumOfParentWts();
00718         // MWS: Search is started from state at end of last iteration
00719       state->init();
00720       inferences_[i]->infer();
00721       state->saveLowStateToGndPreds();
00722 
00723       if (vpdebug)
00724       {
00725         cout << "Inferred following values: " << endl;
00726         inferences_[i]->printProbabilities(cout);
00727       }
00728 
00729         // Compute counts
00730       if (lazyInference_)
00731       {
00732         Array<double>& trueCnt = totalTrueCnts_[i];
00733         Array<double> falseCnt;
00734         bool hasUnknownPreds = false;
00735         falseCnt.growToSize(trueCnt.size());
00736         calculateCounts(trueCnt, falseCnt, i, hasUnknownPreds);
00737       }
00738       else
00739       {
00740         int clauseCnt = clauseCntPerDomain_[i];
00741         state->initMakeBreakCostWatch(0);
00742         //cout<<"getting true cnts => "<<endl;
00743         const Array<double>* clauseTrueCnts =
00744           inferences_[i]->getClauseTrueCnts();
00745         assert(clauseTrueCnts->size() == clauseCnt);
00746         for (int j = 0; j < clauseCnt; j++)
00747           trainTrueCnts_[i][j] = (*clauseTrueCnts)[j];
00748       }
00749 
00750         // Set evidence values back
00751       //assert(uePreds.size() == ueValues[i].size());
00752       //domain->getDB()->setValuesToGivenValues(&uePreds, &ueValues[i]);
00753         // Set non-evidence values to unknown
00754       Array<TruthValue> tmpValues;
00755       tmpValues.growToSize(knePreds->size());
00756       domain->getDB()->setValuesToUnknown(knePreds, &tmpValues);
00757     }
00758     cout << "Done filling in missing data" << endl;    
00759   }
00760 
00761   void getGradientForDomain(double* const & gradient, const int& domainIdx)
00762   {
00763     Array<bool>& relevantClauses = relevantClausesPerDomain_[domainIdx];
00764     int clauseCnt = clauseCntPerDomain_[domainIdx];
00765     double* trainCnts = NULL;
00766     double* inferredCnts = NULL;
00767     double* clauseTrainCnts = new double[clauseCnt]; 
00768     double* clauseInferredCnts = new double[clauseCnt];
00769     double trainCnt, inferredCnt;
00770     Array<double>& totalTrueCnts = totalTrueCnts_[domainIdx];
00771     Array<double>& defaultTrueCnts = defaultTrueCnts_[domainIdx];    
00772     const MLN* mln = inferences_[domainIdx]->getState()->getMLN();
00773     const Domain* domain = inferences_[domainIdx]->getState()->getDomain();
00774 
00775     memset(clauseTrainCnts, 0, clauseCnt*sizeof(double));
00776     memset(clauseInferredCnts, 0, clauseCnt*sizeof(double));
00777 
00778     if (!lazyInference_)
00779     {
00780       if (!inferredCnts) inferredCnts = new double[clauseCnt];
00781 
00782       const Array<double>* clauseTrueCnts =
00783         inferences_[domainIdx]->getClauseTrueCnts();
00784       assert(clauseTrueCnts->size() == clauseCnt);
00785       for (int i = 0; i < clauseCnt; i++)
00786         inferredCnts[i] = (*clauseTrueCnts)[i];
00787       trainCnts = trainTrueCnts_[domainIdx];
00788     }
00789       //loop over all the training examples
00790     //cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00791     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00792     {
00793       if (!relevantClauses[clauseno]) continue;
00794       
00795       if (lazyInference_)
00796       {
00797         Clause* clause = (Clause*) mln->getClause(clauseno);
00798 
00799         trainCnt = totalTrueCnts[clauseno];
00800         inferredCnt =
00801           clause->getNumTrueGroundings(domain, domain->getDB(), false);
00802         trainCnt -= defaultTrueCnts[clauseno];
00803         inferredCnt -= defaultTrueCnts[clauseno];
00804       
00805         clauseTrainCnts[clauseno] += trainCnt;
00806         clauseInferredCnts[clauseno] += inferredCnt;
00807       }
00808       else
00809       {
00810         clauseTrainCnts[clauseno] += trainCnts[clauseno];
00811         clauseInferredCnts[clauseno] += inferredCnts[clauseno];
00812       }
00813       //cout << clauseno << ":\t\t" <<trainCnt<<"\t\t\t\t"<<inferredCnt<<endl;
00814     }
00815 
00816     if (vpdebug)
00817     {
00818       cout << "net counts : " << endl;
00819       cout << "\t\ttrain count\t\t\t\tinferred count" << endl << endl;
00820     }
00821 
00822     for (int clauseno = 0; clauseno < clauseCnt; clauseno++) 
00823     {
00824       if (!relevantClauses[clauseno]) continue;
00825       
00826       if (vpdebug)
00827         cout << clauseno << ":\t\t" << clauseTrainCnts[clauseno] << "\t\t\t\t"
00828              << clauseInferredCnts[clauseno] << endl;
00829       if (rescaleGradient_ && clauseTrainCnts[clauseno] > 0)
00830       {
00831         gradient[clauseno] += 
00832           (clauseTrainCnts[clauseno] - clauseInferredCnts[clauseno])
00833             / clauseTrainCnts[clauseno];
00834       }
00835       else
00836       {
00837         gradient[clauseno] += clauseTrainCnts[clauseno] - 
00838                               clauseInferredCnts[clauseno];
00839       }
00840     }
00841 
00842     delete[] clauseTrainCnts;
00843     delete[] clauseInferredCnts;
00844   }
00845 
00846 
00847     // Get the gradient 
00848   void getGradient(double* const & weights, double* const & gradient,
00849                    const int numWts) 
00850   {
00851     // Set the weights and run inference
00852     
00853     //cout << "New Weights = **** " << endl << endl;
00854     
00855       // If there is one db or the clauses for multiple databases line up
00856     if (idxTrans_ == NULL)
00857     {
00858       int clauseCnt = clauseCntPerDomain_[0];
00859       for (int i = 0; i < domainCnt_; i++)
00860       {
00861         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00862         assert(clauseCntPerDomain_[i] == clauseCnt);
00863         const MLN* mln = inferences_[i]->getState()->getMLN();
00864         
00865         for (int j = 0; j < clauseCnt; j++) 
00866         {
00867           Clause* c = (Clause*) mln->getClause(j);
00868           if (relevantClauses[j]) c->setWt(weights[j]);
00869           else                    c->setWt(0);
00870         }
00871       }
00872     }
00873     else
00874     {   // The clauses for multiple databases do not line up
00875       Array<Array<double> >* wtsPerDomain = idxTrans_->getWtsPerDomain();
00876       const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00877         = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00878       
00879       for (int i = 0; i < domainCnt_; i++)
00880       {
00881         Array<double>& wts = (*wtsPerDomain)[i];
00882         memset((double*)wts.getItems(), 0, wts.size()*sizeof(double));
00883 
00884           //map clause/formula weights to clause weights
00885         for (int j = 0; j < wts.size(); j++)
00886         {
00887           Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];          
00888           for (int k = 0; k < idxDivs->size(); k++)
00889             wts[j] += weights[ (*idxDivs)[k].idx ] / (*idxDivs)[k].div;
00890         }
00891       }
00892       
00893       for (int i = 0; i < domainCnt_; i++)
00894       {
00895         Array<bool>& relevantClauses = relevantClausesPerDomain_[i];
00896         int clauseCnt = clauseCntPerDomain_[i];
00897         Array<double>& wts = (*wtsPerDomain)[i];
00898         assert(wts.size() == clauseCnt);
00899         const MLN* mln = inferences_[i]->getState()->getMLN();
00900 
00901         for (int j = 0; j < clauseCnt; j++)
00902         {
00903           Clause* c = (Clause*) mln->getClause(j);
00904           if (relevantClauses[j]) c->setWt(wts[j]);
00905           else                   c->setWt(0);
00906         }
00907       }
00908     }
00909     //for (int i = 0; i < numWts; i++) cout << i << " : " << weights[i] << endl;
00910 
00911     if (withEM_) fillInMissingValues();
00912     cout << "Running inference ..." << endl;
00913     infer();
00914     cout << "Done with inference" << endl;
00915 
00916       // Compute the gradient
00917     memset(gradient, 0, numWts*sizeof(double));
00918 
00919       // There is one DB or the clauses of multiple DBs line up
00920     if (idxTrans_ == NULL)
00921     {
00922       for (int i = 0; i < domainCnt_; i++) 
00923       {           
00924         //cout << "For domain number " << i << endl << endl; 
00925         getGradientForDomain(gradient, i);        
00926       }
00927     }
00928     else
00929     {
00930         // The clauses for multiple databases do not line up
00931       Array<Array<double> >* gradsPerDomain = idxTrans_->getGradsPerDomain();
00932       const Array<Array<Array<IdxDiv>*> >* cIdxToCFIdxsPerDomain 
00933         = idxTrans_->getClauseIdxToClauseFormulaIdxsPerDomain();
00934      
00935       for (int i = 0; i < domainCnt_; i++) 
00936       {           
00937         //cout << "For domain number " << i << endl << endl; 
00938 
00939         Array<double>& grads = (*gradsPerDomain)[i];
00940         memset((double*)grads.getItems(), 0, grads.size()*sizeof(double));
00941         
00942         getGradientForDomain((double*)grads.getItems(), i);
00943         
00944           // map clause gradient to clause/formula gradients
00945         assert(grads.size() == clauseCntPerDomain_[i]);
00946         for (int j = 0; j < grads.size(); j++)
00947         {
00948           Array<IdxDiv>* idxDivs = (*cIdxToCFIdxsPerDomain)[i][j];          
00949           for (int k = 0; k < idxDivs->size(); k++)
00950             gradient[ (*idxDivs)[k].idx ] += grads[j] / (*idxDivs)[k].div;
00951         }
00952       }
00953     }
00954 
00955       // Add the deriative of the prior 
00956     if (usePrior_) 
00957     {
00958           for (int i = 0; i < numWts; i++) 
00959       {
00960         if (!relevantClausesFormulas_[i]) continue;
00961         double priorDerivative = -(weights[i]-priorMeans_[i])/
00962                                  (priorStdDevs_[i]*priorStdDevs_[i]);
00963         //cout << i << " : " << "gradient : " << gradient[i]
00964         //     << "  prior gradient : " << priorDerivative;
00965         gradient[i] += priorDerivative; 
00966             //cout << "  net gradient : " << gradient[i] << endl; 
00967       }
00968     }
00969   }
00970 
00971 
00972  private:
00973   int domainCnt_;  // number of domains (databases); getGradient loops over [0, domainCnt_)
00974   //Array<Domain*> domains_;  
00975   //Array<MLN*> mlns_;
00976   Array<Array<double> > logOddsPerDomain_;  // NOTE(review): not referenced in this part of the file; presumably per-domain log-odds -- confirm
00977   Array<int> clauseCntPerDomain_;  // number of clauses in each domain's MLN
00978 
00979         // Used in lazy version (presumably read as totalTrueCnts / defaultTrueCnts in the lazy branch of getGradientForDomain -- confirm)
00980   Array<Array<double> > totalTrueCnts_;  // per-domain clause true counts used as the lazy-path training counts (assumed)
00981   Array<Array<double> > defaultTrueCnts_;  // per-domain counts subtracted from both train and inferred counts on the lazy path (assumed)
00982 
00983   Array<Array<bool> > relevantClausesPerDomain_;  // [domain][clause]: false => clause gets weight 0 and contributes no gradient
00984   Array<bool> relevantClausesFormulas_;  // per weight index; the prior derivative is added only where this is true
00985 
00986         // Used to compute cnts from mrf (non-lazy path): trainTrueCnts_[d][c] is read as the training count of clause c in domain d
00987   Array<double*> trainTrueCnts_;
00988 
00989   bool usePrior_;  // if true, getGradient adds the derivative of the prior
00990   const double* priorMeans_, * priorStdDevs_;  // per-weight prior means and standard deviations (indexed 0..numWts-1 in getGradient)
00991 
00992   IdxTranslator* idxTrans_; //not owned by object; don't delete. NULL when there is one db or the clauses of all dbs line up
00993   
00994   bool lazyInference_;  // selects the lazy counting branch of getGradientForDomain
00995   bool rescaleGradient_;  // divide each gradient term by its training count when that count is positive
00996   bool isQueryEvidence_;
00997 
00998   Array<Inference*> inferences_;  // one per domain; supplies each domain's MLN (via getState()) and clause true counts
00999   
01000     // Using EM to fill in missing values? If so, getGradient calls fillInMissingValues() before inference.
01001   bool withEM_;
01002 };
01003 
01004 
01005 #endif

Generated on Tue Jan 16 05:30:03 2007 for Alchemy by  doxygen 1.5.1