inference.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef INFERENCE_H_
00068 #define INFERENCE_H_
00069 
00070 #include "variablestate.h"
00071 #include "hvariablestate.h"
00072 
00073   // Seed substituted when a caller passes -1 (no seed specified); an arbitrary fixed value
00074 const long int DEFAULT_SEED = 2350877;
00075 
00081 class Inference
00082 {
00083  public:
00084 
00096   Inference(VariableState* state, long int seed,
00097             const bool& trackClauseTrueCnts,
00098             Array<Array<Predicate* >* >* queryFormulas = NULL)
00099       : seed_(seed), state_(state), saveAllCounts_(false),
00100         clauseTrueCnts_(NULL), clauseTrueSqCnts_(NULL),
00101         numSamples_(0),
00102         allClauseTrueCnts_(NULL), oldClauseTrueCnts_(NULL),
00103         oldAllClauseTrueCnts_(NULL), queryFormulas_(queryFormulas)
00104   {
00105       // If seed not specified, then init always to same random number
00106     if (seed_ == -1) seed_ = DEFAULT_SEED;
00107     srandom(seed_);
00108 
00109     trackClauseTrueCnts_ = trackClauseTrueCnts;
00110     if (trackClauseTrueCnts_ && state_)
00111     {
00112       int numClauses = state_->getMLN()->getNumClauses();
00113 
00114         // clauseTrueCnts_ and clauseTrueSqCnts_ will hold the true 
00115         // counts (and squared true counts) for each first-order clause
00116       clauseTrueCnts_ = new Array<double>(numClauses, 0);
00117       clauseTrueSqCnts_ = new Array<double>(numClauses, 0);
00118     }
00119     
00120     if (queryFormulas_)
00121       qfProbs_ = new Array<double>(queryFormulas_->size(), 0);
00122     else
00123       qfProbs_ = NULL;
00124   }
00125   
00126   Inference(HVariableState* state, long int seed,
00127             const bool& trackClauseTrueCnts)
00128       : seed_(seed), hstate_(state), saveAllCounts_(false),
00129         clauseTrueCnts_(NULL), clauseTrueCntsCont_(NULL), 
00130         clauseTrueSqCnts_(NULL), numSamples_(0),
00131         allClauseTrueCnts_(NULL), oldClauseTrueCnts_(NULL),
00132         oldAllClauseTrueCnts_(NULL)
00133   {
00134         // If seed not specified, then init always to same random number
00135       if (seed_ == -1) seed_ = DEFAULT_SEED;
00136       srandom(seed_);
00137           
00138           trackClauseTrueCnts_ = trackClauseTrueCnts;
00139           if (trackClauseTrueCnts_ && hstate_)
00140           {
00141                   // clauseTrueCnts_ will hold the true counts for each first-order
00142           // clause
00143                   clauseTrueCnts_ = new Array<double>;
00144                   clauseTrueCnts_->growToSize(hstate_->getMLN()->getNumClauses(), 0);
00145                   clauseTrueCntsCont_ = new Array<double>;
00146                   clauseTrueCntsCont_->growToSize(hstate_->getNumContFormulas(), 0);
00147           }
00148   }
00149   
00153   virtual ~Inference()
00154   {
00155     delete clauseTrueCnts_;
00156     delete clauseTrueSqCnts_;
00157     delete allClauseTrueCnts_;
00158 
00159     delete oldAllClauseTrueCnts_;
00160     delete oldClauseTrueCnts_;
00161     
00162     if (qfProbs_) delete qfProbs_;
00163   }
00164 
00165 
00166   void saveAllCounts(bool saveCounts=true)
00167   {
00168     if (saveAllCounts_ == saveCounts)
00169       return;
00170 
00171     saveAllCounts_ = saveCounts;
00172     if (saveCounts)
00173     {
00174       allClauseTrueCnts_ = new Array<Array<double> >;
00175       oldAllClauseTrueCnts_ = new Array<Array<double> >;
00176     }
00177     else
00178     {
00179       delete allClauseTrueCnts_;
00180       delete oldAllClauseTrueCnts_;
00181       allClauseTrueCnts_ = NULL;
00182       oldAllClauseTrueCnts_ = NULL;
00183     }
00184   }
00185 
00186 
00190   virtual void init() = 0;
00191 
00195   virtual void infer() = 0;
00196 
00200   virtual void printNetwork(ostream& out) = 0; 
00201 
00205   virtual void printProbabilities(ostream& out) = 0;
00206   
00213   virtual void getChangedPreds(vector<string>& changedPreds,
00214                                vector<float>& probs,
00215                                vector<float>& oldProbs,
00216                                const float& probDelta) = 0;
00217 
00218   
00222   virtual void printTruePreds(ostream& out) = 0;
00223   virtual void printTruePredsH(ostream& out) = 0;
00224 
00228   virtual double getProbability(GroundPredicate* const& gndPred) = 0;
00229   virtual double getProbabilityH(GroundPredicate* const& gndPred) = 0;
00230 
00234   void printQFProbs(ostream& out, Domain* domain)
00235   {
00236     if (qfProbs_)
00237     {
00238       for (int i = 0; i < queryFormulas_->size(); i++)
00239       {
00240         Array<Predicate* >* formula = (*queryFormulas_)[i];
00241         for (int j = 0; j < formula->size(); j++)
00242         {
00243           (*formula)[j]->printWithStrVar(out, domain);
00244           if (j != formula->size() - 1) out << " ^ ";
00245         }
00246         out << " " << (*qfProbs_)[i] << endl;
00247       }
00248     }
00249   }
00250 
00251   long int getSeed() { return seed_; }
00252   void setSeed(long int s) { seed_ = s; }
00253   
00254   VariableState* getState() { return state_; }
00255   void setState(VariableState* s) { state_ = s; }
00256 
00257   HVariableState* getHState() { return hstate_; }
00258   void setHState(HVariableState* s) { hstate_ = s; }
00259   
00265   virtual void scaleSamples(double factor) { /* Override this... */ }
00266 
00267 
00268   
00269   const Array<double>* getClauseTrueCnts()   { return clauseTrueCnts_; }
00270   const Array<double>* getClauseTrueSqCnts() { return clauseTrueSqCnts_; }
00271   int getNumSamples() const { return numSamples_; }
00272 
00273   // Compute the full Hessian matrix from the stored inferred counts.
00274   // Only works when saveAllCounts(true) has been called before inference.
00275   // Caller is responsible for deleting Hessian.
00276   //
00277   // WARNING: The size of the Hessian is equal to the number of weights
00278   // squared.  Therefore, this is not practical in a model with thousands
00279   // of weights!
00280   const Array<Array<double> >* getHessian()
00281   {
00282     int numClauses = state_->getMLN()->getNumClauses();
00283     int numSamples = allClauseTrueCnts_->size();
00284 
00285     // Allocate Hessian
00286     Array<Array<double> >* hessian = new Array<Array<double> >(numClauses);
00287     for (int i = 0; i < numClauses; i++)
00288       (*hessian)[i].growToSize(numClauses);
00289 
00290     // The i jth element is:
00291     // E[n_i] E[n_j] - E[n_i * n_j]
00292     // where n is the vector of all clause counts
00293 
00294     for (int i = 0; i < numClauses; i++)
00295     {
00296       for (int j = 0; j < numClauses; j++)
00297       {
00298         double ni = 0.0;
00299         double nj = 0.0;
00300         double ninj = 0.0;
00301         for (int s = 0; s < numSamples; s++)
00302         {
00303           ni += (*allClauseTrueCnts_)[s][i];
00304           nj += (*allClauseTrueCnts_)[s][j];
00305           ninj += (*allClauseTrueCnts_)[s][i] 
00306                 * (*allClauseTrueCnts_)[s][j]; 
00307         }
00308         double n = numSamples;
00309         (*hessian)[i][j] = ni/n * nj/n - ninj/n;
00310       }
00311     }
00312 
00313     return hessian;
00314   }
00315 
00316 
00317   // Alternate way to compute product of Hessian with vector
00318   // WARNING: much less efficient!!
00319   const Array<double>* getHessianVectorProduct2(Array<double>& v)
00320   {
00321     int numClauses = state_->getMLN()->getNumClauses();
00322     const Array<Array<double> >* hessian = getHessian();
00323     Array<double>* product = new Array<double>(numClauses,0);
00324 
00325     for (int clauseno = 0; clauseno < numClauses; clauseno++)
00326     {
00327       (*product)[clauseno] = 0.0;
00328       for (int i = 0; i < numClauses; i++)
00329         (*product)[clauseno] += (*hessian)[clauseno][i] * v[i];
00330     }
00331 
00332     delete hessian;
00333     return product;
00334   }
00335 
00336 
00337   const Array<double>* getHessianVectorProduct(const Array<double>& v)
00338   {
00339     int numClauses = state_->getMLN()->getNumClauses();
00340     int numSamples = allClauseTrueCnts_->size();
00341 
00342       // For minimizing the negative log likelihood, 
00343       // the ith element of H v is:
00344       //   E[n_i * vn] - E[n_i] E[vn]
00345       // where n is the vector of all clause counts
00346       // and vn is the dot product of v and n.
00347 
00348     double sumVN = 0;
00349     Array<double> sumN(numClauses, 0);
00350     Array<double> sumNiVN(numClauses, 0);
00351 
00352     // Get sufficient statistics from each sample, 
00353     // so we can compute expectations
00354     for (int s = 0; s < numSamples; s++) 
00355     {
00356       Array<double>& n = (*allClauseTrueCnts_)[s];
00357 
00358       // Compute v * n
00359       double vn = 0;
00360       for (int i = 0; i < numClauses; i++)
00361         vn += v[i] * n[i];
00362       
00363       // Tally v*n, n_i, and n_i v*n
00364       sumVN += vn;
00365       for (int i = 0; i < numClauses; i++)
00366       {
00367         sumN[i]    += n[i];
00368         sumNiVN[i] += n[i] * vn;
00369       }
00370     }
00371 
00372     // Compute actual product from the sufficient stats
00373     Array<double>* product = new Array<double>(numClauses,0);
00374     for (int clauseno = 0; clauseno < numClauses; clauseno++)
00375     {
00376       double E_vn = sumVN/numSamples;
00377       double E_ni = sumN[clauseno]/numSamples;
00378       double E_nivn = sumNiVN[clauseno]/numSamples;
00379       (*product)[clauseno] = E_nivn - E_ni * E_vn;
00380     }
00381 
00382     return product;
00383   }
00384 
00385 
00386   void resetCnts() 
00387   {
00388     for (int clauseno = 0; clauseno < clauseTrueCnts_->size(); clauseno++)
00389     {
00390       (*clauseTrueCnts_)[clauseno]   = 0;
00391       (*clauseTrueSqCnts_)[clauseno] = 0;
00392     }
00393     numSamples_ = 0;
00394 
00395     if (saveAllCounts_)
00396     {
00397       delete allClauseTrueCnts_;
00398       allClauseTrueCnts_ = new Array<Array<double> >;
00399     }
00400   }
00401 
00402 
00403   void saveCnts()
00404   {
00405     if (!saveAllCounts_)
00406       return;
00407 
00408     // DEBUG
00409     //cout << "Saving counts.  numSamples_ = " << numSamples_ << endl;
00410 
00411     delete oldAllClauseTrueCnts_;
00412     oldAllClauseTrueCnts_ = new Array<Array<double> > (*allClauseTrueCnts_);
00413 
00414     /* DEBUG
00415     cout << "old counts size: " << oldAllClauseTrueCnts_->size() << endl;
00416     cout << "marker1: " << (*allClauseTrueCnts_)[0][0] << endl;
00417     cout << "marker2: " << (*clauseTrueCnts_)[0] << endl;
00418     */
00419   }
00420 
00421 
00422   void restoreCnts()
00423   {
00424     if (!saveAllCounts_)
00425       return;
00426 
00427     resetCnts();
00428 
00429     *allClauseTrueCnts_ = *oldAllClauseTrueCnts_;
00430     for (int i = 0; i < allClauseTrueCnts_->size(); i++) 
00431     {
00432       int numcounts = (*allClauseTrueCnts_)[i].size();
00433       for (int j = 0; j < numcounts; j++)
00434       {
00435         double currcount = (*allClauseTrueCnts_)[i][j];
00436         (*clauseTrueCnts_)[j]   += currcount;
00437         (*clauseTrueSqCnts_)[j] += currcount * currcount;
00438       }
00439       numSamples_++;
00440     }
00441 
00442     /* DEBUG
00443     cout << "marker1: " << (*allClauseTrueCnts_)[0][0] << endl;
00444     cout << "marker2: " << (*clauseTrueCnts_)[0] << endl;
00445     cout << "numSamples_ = " << numSamples_ << endl;
00446     */
00447   }
00448 
00449 
00450   void tallyCntsFromState()
00451   {
00452     int numcounts = clauseTrueCnts_->size();
00453     Array<double> currCounts(numcounts, 0.0);
00454     state_->getNumClauseGndings(&currCounts, true);
00455 
00456     if (saveAllCounts_)
00457     {
00458       allClauseTrueCnts_->append(Array<double>());
00459       (*allClauseTrueCnts_)[numSamples_].growToSize(numcounts);
00460     }
00461 
00462     for (int i = 0; i < numcounts; i++)
00463     {
00464       if (saveAllCounts_)
00465         (*allClauseTrueCnts_)[numSamples_][i] = currCounts[i];
00466       
00467       (*clauseTrueCnts_)[i]   += currCounts[i];
00468       (*clauseTrueSqCnts_)[i] += currCounts[i] * currCounts[i];
00469     }
00470     numSamples_++;
00471   }
00472 
00473  protected:
00474   
00475     // Seed for randomizer. If not set, then date + time is used
00476   long int seed_;
00477     // State of atoms and clauses used during inference
00478     // Does not belong to inference (do not delete)
00479   VariableState* state_;
00480   HVariableState* hstate_;
00481   
00482     // Save all counts for all iterations
00483   bool saveAllCounts_;
00484     // Holds the average number of true groundings of each first-order clause
00485     // in the mln associated with this inference
00486   Array<double>* clauseTrueCnts_;
00487   Array<double>* clauseTrueCntsCont_;
00488   
00489     // Holds the average number of true groundings squared of each 
00490     // first-order clause in the mln associated with this inference
00491   Array<double>* clauseTrueSqCnts_;
00492     // Number of samples taken of the true counts
00493   int numSamples_;
00494     // Indicates if true counts for each first-order clause are being kept
00495   bool trackClauseTrueCnts_;
00496     // Where these counts are stored: (*allClauseTrueCnts_)[i][j] has the true
00497     // counts for clause j in sample i
00498   Array<Array<double> >* allClauseTrueCnts_;
00499 
00500   Array<double>* oldClauseTrueCnts_;
00501   Array<Array<double> >* oldAllClauseTrueCnts_;
00502   
00503   Array<Array<Predicate* >* >* queryFormulas_;
00504   Array<double>* qfProbs_;
00505 };
00506 
00507 #endif /*INFERENCE_H_*/

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1