mcmc.h

00001 /*
00002  * All of the documentation and software included in the
00003  * Alchemy Software is copyrighted by Stanley Kok, Parag
00004  * Singla, Matthew Richardson, Pedro Domingos, Marc
00005  * Sumner, Hoifung Poon, Daniel Lowd, and Jue Wang.
00006  * 
00007  * Copyright [2004-09] Stanley Kok, Parag Singla, Matthew
00008  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00009  * Poon, Daniel Lowd, and Jue Wang. All rights reserved.
00010  * 
00011  * Contact: Pedro Domingos, University of Washington
00012  * (pedrod@cs.washington.edu).
00013  * 
00014  * Redistribution and use in source and binary forms, with
00015  * or without modification, are permitted provided that
00016  * the following conditions are met:
00017  * 
00018  * 1. Redistributions of source code must retain the above
00019  * copyright notice, this list of conditions and the
00020  * following disclaimer.
00021  * 
00022  * 2. Redistributions in binary form must reproduce the
00023  * above copyright notice, this list of conditions and the
00024  * following disclaimer in the documentation and/or other
00025  * materials provided with the distribution.
00026  * 
00027  * 3. All advertising materials mentioning features or use
00028  * of this software must display the following
00029  * acknowledgment: "This product includes software
00030  * developed by Stanley Kok, Parag Singla, Matthew
00031  * Richardson, Pedro Domingos, Marc Sumner, Hoifung
00032  * Poon, Daniel Lowd, and Jue Wang in the Department of
00033  * Computer Science and Engineering at the University of
00034  * Washington".
00035  * 
00036  * 4. Your publications acknowledge the use or
00037  * contribution made by the Software to your research
00038  * using the following citation(s): 
00039  * Stanley Kok, Parag Singla, Matthew Richardson and
00040  * Pedro Domingos (2005). "The Alchemy System for
00041  * Statistical Relational AI", Technical Report,
00042  * Department of Computer Science and Engineering,
00043  * University of Washington, Seattle, WA.
00044  * http://alchemy.cs.washington.edu.
00045  * 
00046  * 5. Neither the name of the University of Washington nor
00047  * the names of its contributors may be used to endorse or
00048  * promote products derived from this software without
00049  * specific prior written permission.
00050  * 
00051  * THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF WASHINGTON
00052  * AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
00053  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
00054  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00055  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
00056  * OF WASHINGTON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
00057  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
00058  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00059  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
00060  * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
00061  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
00062  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00063  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
00064  * IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00065  * 
00066  */
00067 #ifndef MCMC_H_
00068 #define MCMC_H_
00069 
00070 #include "inference.h"
00071 #include "mcmcparams.h"
00072 
// Compile-time switch for verbose MCMC tracing (per-chain and per-predicate
// messages). Set to true for more output.
static const bool mcmcdebug = false;
00075 
00080 class MCMC : public Inference
00081 {
00082  public:
00083 
00090   MCMC(VariableState* state, long int seed, const bool& trackClauseTrueCnts,
00091        MCMCParams* params, Array<Array<Predicate* >* >* queryFormulas = NULL)
00092     : Inference(state, seed, trackClauseTrueCnts, queryFormulas)
00093   {
00094       // User-set parameters
00095     numChains_ = params->numChains;
00096     burnMinSteps_ = params->burnMinSteps;
00097     burnMaxSteps_ = params->burnMaxSteps;
00098     minSteps_ = params->minSteps;
00099     maxSteps_ = params->maxSteps;
00100     maxSeconds_ = params->maxSeconds;
00101   }
00102 
00103 
00104   MCMC(HVariableState* state, long int seed, const bool& trackClauseTrueCnts,
00105           MCMCParams* params)
00106           : Inference(state, seed, trackClauseTrueCnts)
00107   {
00108           // User-set parameters
00109           numChains_ = params->numChains;
00110           burnMinSteps_ = params->burnMinSteps;
00111           burnMaxSteps_ = params->burnMaxSteps;
00112           minSteps_ = params->minSteps;
00113           maxSteps_ = params->maxSteps;
00114           maxSeconds_ = params->maxSeconds;
00115   }
00116 
00120   ~MCMC() {}
00121 
00122   /*double computeHybridClauseValue(int contClauseIdx)
00123   {
00124           return hstate_->hybridClauseCost_[contClauseIdx] * HybridClauseContPartValue(contClauseIdx) * HybridClauseDisPartValue(contClauseIdx);
00125   }*/
00126 
00127   double computeHybridClauseValue(int clauseIdx, int c)
00128   {
00129           double contClauseContPartValue = HybridClauseContPartValue(clauseIdx, c);
00130           double contClauseDisPartValue = HybridClauseDisPartValue(clauseIdx, c);
00131           return hstate_->hybridWts_[clauseIdx] * contClauseContPartValue * contClauseDisPartValue;
00132   }
00133 
00134   double HybridClauseContPartValue(int contClauseIdx, int c)
00135   {
00136           PolyNomial& pl = hstate_->GetHybridClausePolynomial(contClauseIdx);
00137 
00138           assert(hstate_->hybridContClause_[contClauseIdx].size() == pl.GetVarNum());
00139 
00140           Array<double> arVar;
00141           for(int i = 0; i < hstate_->hybridContClause_[contClauseIdx].size(); ++i)
00142           {
00143                   arVar.append(truthValuesCont_[hstate_->hybridContClause_[contClauseIdx][i] - 1][c]); // the order of the cont atom here should be the same as they are in the Pl
00144           }
00145           double v = pl.ComputePlValue(arVar);
00146           
00147           return v;
00148   }
00149 
00150   double HybridClauseDisPartValue(int contClauseIdx, int c)
00151   {
00152           bool bAndOr = hstate_->hybridConjunctionDisjunction_[contClauseIdx];    
00153           int numTrueLits = 0;
00154           int numFalseLits = 0;
00155           for(int j = 0; j < hstate_->hybridDisClause_[contClauseIdx].size(); ++j)
00156           {
00157                   int atomIdx = hstate_->hybridDisClause_[contClauseIdx][j];
00158                   if((atomIdx > 0) == truthValues_[abs(atomIdx)-1][c]) // true literal
00159                   {
00160                           numTrueLits ++;
00161                           if(!bAndOr) // disjunctive clause
00162                           {
00163                                   break;
00164                           }
00165                   }
00166                   else // false literal
00167                   {
00168                           numFalseLits ++;
00169                           if(bAndOr) // conjunctive clause
00170                           {
00171                                   break;
00172                           }
00173                   }
00174           }
00175           if(!bAndOr) // disjunctive clause
00176           {
00177                   return (numTrueLits > 0)?1.0:0.0;
00178           }
00179           else // conjunctive clause
00180           {
00181                   return (numFalseLits > 0)?0.0:1.0;
00182           }
00183           return 0.0;
00184   }
00185 
00186   void updateDisPredValue(int predIdx, int chainIdx, bool updateValue)
00187   {
00188           bool bBak = truthValues_[predIdx][chainIdx];
00189           Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00190           Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00191           Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00192 
00193           if(updateValue)
00194           {
00195                   if(!bBak)
00196                   {
00197                           // update numTrueLits of discrete clauses where it appears 
00198                           for(int j = 0; j < occDisPos.size(); ++j)
00199                           {
00200                                   int disClauseIdx = occDisPos[j];
00201                                   numTrueLits_[disClauseIdx][chainIdx] += 1;
00202                           }
00203 
00204                           for( int j = 0; j < occDisNeg.size(); ++j)
00205                           {
00206                                   int disClauseIdx = occDisNeg[j];
00207                                   numTrueLits_[disClauseIdx][chainIdx] -= 1;
00208                           }
00209                   }
00210           }
00211           else
00212           {
00213                   if(bBak)
00214                   {
00215                           for( int j = 0; j < occDisPos.size(); ++j)
00216                           {
00217                                   int disClauseIdx = occDisPos[j];
00218                                   numTrueLits_[disClauseIdx][chainIdx] -= 1;
00219                           }
00220 
00221                           for( int j = 0; j < occDisNeg.size(); ++j)
00222                           {
00223                                   int disClauseIdx = occDisNeg[j];
00224                                   numTrueLits_[disClauseIdx][chainIdx] += 1;
00225                           }
00226                   }     
00227           }
00228 
00229           truthValues_[predIdx][chainIdx] = updateValue;
00230 
00231           if(bBak != updateValue)
00232           {
00233                   for( int j = 0; j < occCont.size(); ++j)
00234                   {
00235                           int hybridClauseIdx = occCont[j];
00236                           hybridClauseDisPartValueMCMC_[hybridClauseIdx] = HybridClauseDisPartValue(hybridClauseIdx, chainIdx)==1.0?true:false;
00237                   }
00238           }  
00239   }
00240 
00241   void updateProposalContValue(int contPredIdx, int chainIdx)
00242   {
00243           Array<int>& occContCont = hstate_->hybridContOccurrence_[contPredIdx + 1];
00244           PolyNomial pl;
00245           for(int j = 0; j < occContCont.size(); j++)
00246           {
00247                   int contClauseIdx = occContCont[j];
00248                   PolyNomial pltmp = hstate_->hybridPls_[contClauseIdx];
00249                   //find the in idx here
00250                   int inIdx = -1;
00251                   Array<double> arVars;
00252                   for(int k = 0; k < hstate_->hybridContClause_[contClauseIdx].size(); k++)
00253                   {
00254                           arVars.append(truthValuesCont_[hstate_->hybridContClause_[contClauseIdx][k]-1][chainIdx]);
00255                           if (hstate_->hybridContClause_[contClauseIdx][k] == contPredIdx+1)
00256                           {
00257                                   //pltmp.ReduceToOneVar(,k)
00258                                   inIdx = k;
00259                           }                       
00260                   }
00261                   if (inIdx == -1)
00262                   {
00263                           cout <<  "faint" << endl;
00264                   }                       
00265 
00266                   pltmp.ReduceToOneVar(arVars, inIdx);
00267                   pl.AddPl(pltmp);
00268           }
00269           
00270           double miu = 0, stdev = 0;
00271           pl.GetGaussianPara(&miu, &stdev);
00272           truthValuesCont_[contPredIdx][chainIdx] = ExtRandom::gaussRandom(miu, stdev);
00273 
00274           for(int j = 0; j < occContCont.size(); j++)
00275           {
00276                   int hybridClauseIdx = occContCont[j];
00277                   hybridClauseContPartValueMCMC_[hybridClauseIdx][chainIdx] = HybridClauseContPartValue(hybridClauseIdx,chainIdx);
00278           }  
00279   }
00280 
00281   // Computes the probability of a ground discrete predicate in a chain, adapted to accommodate the hybrid case.
00282   double getProbabilityOfPredH(const int& predIdx, const int& chainIdx, const double& invTemp)
00283   {
00284           // Different for multi-chain
00285           if (numChains_ > 1)
00286           {
00287                   double wtDisAsTrue = 0, wtDisAsFalse = 0;
00288                   bool bBak = truthValues_[predIdx][chainIdx];
00289 
00290                   Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00291                   Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00292                   Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00293 
00294                   // for clauses where the variable appears as a true literal
00295                   for(int j = 0; j < occDisPos.size(); j++)
00296                   {
00297                           int disClauseIdx = occDisPos[j];
00298                           wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00299 
00300                           if(numTrueLits_[disClauseIdx][chainIdx] > 1)
00301                           {
00302                                   wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00303                           }
00304                           else if(numTrueLits_[disClauseIdx][chainIdx] == 1 && !bBak) // there was only 1 true literal and the only true literal is some one else
00305                           {
00306                                   wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00307                           }
00308                   }
00309 
00310                   // for clauses where the variable appears as a false literal
00311                   for(int j = 0; j < occDisNeg.size(); j++)
00312                   {
00313                           int disClauseIdx = occDisNeg[j];
00314                           if(numTrueLits_[disClauseIdx][chainIdx] > 1) // there were more than one true literals 
00315                           {
00316                                   wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00317                           }
00318                           else if(numTrueLits_[disClauseIdx][chainIdx] == 1 && bBak) // there was only 1 true literal and the only true literal is some one else
00319                           {
00320                                   wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00321                           }
00322                           wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00323                   }
00324 
00325                   for (int j = 0; j < occCont.size(); j++) {
00326                           int hybridClauseIdx = occCont[j];
00327                           double contPart = hybridClauseContPartValueMCMC_[hybridClauseIdx][chainIdx];
00328                           double wt = hstate_->hybridWts_[hybridClauseIdx];
00329 
00330                           truthValues_[predIdx][chainIdx] = true;
00331                           double disPartAsTrue = HybridClauseDisPartValue(hybridClauseIdx,chainIdx);
00332                           truthValues_[predIdx][chainIdx] = false;
00333                           double disPartAsFalse = HybridClauseDisPartValue(hybridClauseIdx,chainIdx);
00334                           truthValues_[predIdx][chainIdx] = bBak;
00335 
00336                           wtDisAsTrue += wt*disPartAsTrue*contPart;
00337                           wtDisAsFalse += wt*disPartAsFalse*contPart;
00338                   }
00339 
00340                   // get probabilities
00341                   double wtDiff = (wtDisAsFalse - wtDisAsTrue) * invTemp;
00342                   double prob;
00343                   if (wtDiff > 403.429)
00344                   {
00345                           prob = 0;
00346                   }
00347                   else if (wtDiff < -403.429)
00348                   {
00349                           prob = 1;
00350                   }
00351                   else
00352                   {
00353                           prob = 1 / ( 1 + exp(wtDiff));
00354                   }
00355                   return prob;
00356           }  // For chain id ends.
00357           else
00358           {
00359                   GroundPredicate* gndPred = hstate_->getGndPred(predIdx);
00360                   double wtDisAsTrue = 0, wtDisAsFalse = 0;
00361                   bool bBak = gndPred->getTruthValue();
00362 
00363                   Array<int>& occDisPos = hstate_->getPosOccurenceArray(predIdx + 1);
00364                   Array<int>& occDisNeg = hstate_->getNegOccurenceArray(predIdx + 1);
00365                   Array<int>& occCont = hstate_->getContDisOccurenceArray(predIdx + 1) ;
00366 
00367                   gndPred->setTruthValue(true);
00368                   //hstate_->setValueOfAtom(i+1, true);           
00369                   for(int j = 0; j < occDisPos.size(); j++)
00370                   {
00371                           int disClauseIdx = occDisPos[j];
00372                           wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];                      
00373                   }
00374 
00375                   for(int j = 0; j < occDisNeg.size(); j++)
00376                   {
00377                           int disClauseIdx = occDisNeg[j];
00378                           for(int k = 0; k < hstate_->clause_[disClauseIdx].size(); k++)
00379                           {
00380                                   int lit = hstate_->clause_[disClauseIdx][k];
00381                                   GroundPredicate* gndPredtmp =  hstate_->getGndPred(abs(lit)-1);
00382                                   if ((lit > 0) == gndPredtmp->getTruthValue()) //true literal
00383                                   {
00384                                           wtDisAsTrue += hstate_->clauseCost_[disClauseIdx];
00385                                           break;
00386                                   }
00387                           }
00388                   }
00389 
00390                   for (int j = 0; j < occCont.size(); j++)
00391                   {
00392                           int contClauseIdx = occCont[j];
00393                           wtDisAsTrue += hstate_->HybridClauseValue(contClauseIdx);
00394                                   // computeHybridClauseValue(contClauseIdx);
00395                   }
00396 
00397 
00398                   gndPred->setTruthValue(false);
00399                   for(int j = 0; j < occDisNeg.size(); j++)
00400                   {
00401                           int disClauseIdx = occDisNeg[j];
00402                           wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00403                   }
00404 
00405                   for(int j = 0; j < occDisPos.size(); j++)
00406                   {
00407                           int disClauseIdx = occDisPos[j];
00408                           for(int k = 0; k < hstate_->clause_[disClauseIdx].size(); k++)
00409                           {
00410                                   int lit = hstate_->clause_[disClauseIdx][k];
00411                                   GroundPredicate* gndPredtmp =  hstate_->getGndPred(abs(lit)-1);
00412                                   if ((lit > 0) == gndPredtmp->getTruthValue()) //true literal
00413                                   {
00414                                           wtDisAsFalse += hstate_->clauseCost_[disClauseIdx];
00415                                           break;
00416                                   }
00417                           }
00418                   }
00419 
00420                   for (int j = 0; j < occCont.size(); j++)
00421                   {
00422                           int contClauseIdx = occCont[j];
00423                           wtDisAsFalse += hstate_->HybridClauseValue(contClauseIdx);
00424                   }
00425 
00426                   gndPred->setTruthValue(bBak);
00427 
00428                   return 1.0 / ( 1.0 + exp((wtDisAsFalse - wtDisAsTrue) * invTemp));
00429           }
00430   }
00431 
00435   virtual void printNetwork(ostream& out)
00436   {
00437   } 
00438 
00442   void printProbabilities(ostream& out)
00443   {
00444     for (int i = 0; i < state_->getNumAtoms(); i++)
00445     {
00446       double prob = getProbTrue(i);
00447 
00448         // Uniform smoothing
00449       prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00450       state_->printGndPred(i, out);
00451       out << " " << prob << endl;
00452     }    
00453   }
00454 
00468   void getChangedPreds(vector<string>& changedPreds, vector<float>& probs,
00469                        vector<float>& oldProbs, const float& probDelta)
00470   {
00471     changedPreds.clear();
00472     probs.clear();
00473     int numAtoms = state_->getNumAtoms();
00474       // Atoms may have been added to the state, previous prob. was 0
00475     oldProbs.resize(numAtoms, 0);
00476     for (int i = 0; i < numAtoms; i++)
00477     {
00478       double prob = getProbTrue(i);
00479       if (abs(prob - oldProbs[i]) > probDelta)
00480       {
00481           // Truth value has changed: Store new value (not smoothed) in oldProbs
00482           // and add to two return vectors
00483         oldProbs[i] = prob;
00484           // Uniform smoothing
00485         prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00486         ostringstream oss(ostringstream::out);
00487         state_->printGndPred(i, oss);
00488         changedPreds.push_back(oss.str());
00489         probs.push_back(prob);
00490       }
00491     }    
00492   }
00493 
00494 
00495   double getProbabilityH(GroundPredicate* const& gndPred)
00496   {
00497           int idx = hstate_->getGndPredIndex(gndPred);
00498           double prob = 0.0;
00499           if (idx >= 0) prob = getProbTrue(idx);
00500           // Uniform smoothing
00501           return (prob*10000 + 1/2.0)/(10000 + 1.0);
00502   }
00503 
00510   double getProbability(GroundPredicate* const& gndPred)
00511   {
00512     int idx = state_->getGndPredIndex(gndPred);
00513     double prob = 0.0;
00514     if (idx >= 0) prob = getProbTrue(idx);
00515       // Uniform smoothing
00516     return (prob*10000 + 1/2.0)/(10000 + 1.0);
00517   }
00518 
00522   void printTruePreds(ostream& out)
00523   {
00524     for (int i = 0; i < state_->getNumAtoms(); i++)
00525     {
00526       double prob = getProbTrue(i);
00527 
00528         // Uniform smoothing
00529       prob = (prob*10000 + 1/2.0)/(10000 + 1.0);
00530       if (prob >= 0.5) state_->printGndPred(i, out);
00531     }    
00532   }
00533 
00534   void printTruePredsH(ostream& out)
00535   {
00536           for (int i = 0; i < hstate_->getNumAtoms(); i++)
00537           {
00538                   double prob = getProbTrue(i);
00539 
00540                   // Uniform smoothing
00541                   prob = (prob * 10000 + 1/2.0) / (10000 + 1.0);
00542                   if (prob >= 0.5) hstate_->printGndPred(i, out);
00543           }    
00544   }
00545 
00546 
00547  protected:
00548 
00557   void initTruthValuesAndWts(const int& numChains)
00558   {
00559     int numPreds = state_->getNumAtoms();
00560     truthValues_.growToSize(numPreds);
00561     wtsWhenFalse_.growToSize(numPreds);
00562     wtsWhenTrue_.growToSize(numPreds);
00563     for (int i = 0; i < numPreds; i++)
00564     {
00565       truthValues_[i].growToSize(numChains, false);
00566       wtsWhenFalse_[i].growToSize(numChains, 0);
00567       wtsWhenTrue_[i].growToSize(numChains, 0);
00568     }
00569     
00570     int numClauses = state_->getNumClauses();
00571     numTrueLits_.growToSize(numClauses);
00572     for (int i = 0; i < numClauses; i++)
00573     {
00574       numTrueLits_[i].growToSize(numChains, 0);
00575     }
00576   }
00577 
00582   void initNumTrue()
00583   {
00584     int numPreds = state_->getNumAtoms();
00585     numTrue_.growToSize(numPreds);
00586     for (int i = 0; i < numTrue_.size(); i++)
00587       numTrue_[i] = 0;
00588   }
00589 
00596   void initNumTrueLits(const int& numChains)
00597   {
00598       // Single chain
00599     if (numChains == 1) state_->resetMakeBreakCostWatch();
00600     for (int i = 0; i < state_->getNumClauses(); i++)
00601     {
00602       GroundClause* gndClause = state_->getGndClause(i);
00603       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
00604       {
00605         const int atomIdx = abs(state_->getAtomInClause(j, i)) - 1;
00606         const bool sense = gndClause->getGroundPredicateSense(j);
00607         if (numChains > 1)
00608         {
00609           for (int c = 0; c < numChains; c++)
00610           {
00611             if (truthValues_[atomIdx][c] == sense)
00612             {
00613               numTrueLits_[i][c]++;
00614               assert(numTrueLits_[i][c] <= state_->getNumAtoms());
00615             }
00616           }
00617         }
00618         else
00619         { // Single chain
00620           GroundPredicate* gndPred = state_->getGndPred(atomIdx);
00621           if (gndPred->getTruthValue() == sense)
00622             state_->incrementNumTrueLits(i);
00623           assert(state_->getNumTrueLits(i) <= state_->getNumAtoms());
00624         }        
00625       }
00626     }
00627   }
00628  
00636   void randomInitGndPredsTruthValues(const int& numChains)
00637   {
00638     for (int c = 0; c < numChains; c++)
00639     {
00640       if (mcmcdebug) cout << "Chain " << c << ":" << endl;
00641         // For each block: select one to set to true
00642       for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00643       {
00644           // If evidence atom exists, then all others are false
00645         if (state_->getDomain()->getBlockEvidence(i))
00646         {
00647             // If 2nd argument is -1, then all are set to false
00648           setOthersInBlockToFalse(c, -1, i);
00649           continue;
00650         }
00651 
00652         bool ok = false;
00653         while (!ok)
00654         {
00655           const Predicate* pred = state_->getDomain()->getRandomPredInBlock(i);
00656           GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00657           int idx = state_->getIndexOfGroundPredicate(gndPred);
00658             
00659           delete gndPred;
00660           delete pred;
00661 
00662           if (idx >= 0)
00663           {
00664               // Truth values are stored differently for multi-chain
00665             if (numChains_ > 1)
00666               truthValues_[idx][c] = true;
00667             else
00668             {
00669               GroundPredicate* gndPred = state_->getGndPred(i);
00670               gndPred->setTruthValue(true);
00671             }
00672             setOthersInBlockToFalse(c, idx, i);
00673             ok = true;
00674           }
00675         }
00676       }
00677       
00678         // Random tv for all not in blocks
00679       for (int i = 0; i < truthValues_.size(); i++)
00680       {
00681           // Predicates in blocks have been handled above
00682         if (state_->getBlockIndex(i) == -1)
00683         {
00684           bool tv = genTruthValueForProb(0.5);
00685             // Truth values are stored differently for multi-chain
00686           if (numChains_ > 1)
00687             truthValues_[i][c] = tv;
00688           else
00689           {
00690             GroundPredicate* gndPred = state_->getGndPred(i);
00691             gndPred->setTruthValue(tv);
00692           }
00693           if (mcmcdebug) cout << "Pred " << i << " set to " << tv << endl;
00694         }
00695       }
00696     }
00697   }
00698 
/**
 * Draws a Bernoulli truth value that is true with probability p.
 * p == 1.0 always yields true and p == 0.0 always yields false.
 *
 * POSIX random() returns values in [0, 2^31 - 1] regardless of RAND_MAX,
 * which describes rand(). The threshold is therefore scaled by random()'s
 * own maximum. This is identical to the old RAND_MAX scaling where
 * RAND_MAX == 2^31 - 1 (e.g. glibc) and correct on platforms where the two
 * differ.
 */
bool genTruthValueForProb(const double& p)
{
  if (p == 1.0) return true;
  if (p == 0.0) return false;
  const double kRandomMax = 2147483647.0;  // 2^31 - 1, POSIX random() maximum
  bool r = random() <= p * kRandomMax;
  return r;
}
00712 
00722   double getProbabilityOfPred(const int& predIdx, const int& chainIdx,
00723                               const double& invTemp)
00724   {
00725       // Different for multi-chain
00726     if (numChains_ > 1)
00727     {
00728       return 1.0 /
00729              ( 1.0 + exp((wtsWhenFalse_[predIdx][chainIdx] - 
00730                           wtsWhenTrue_[predIdx][chainIdx]) *
00731                           invTemp));      
00732     }
00733     else
00734     {
00735       GroundPredicate* gndPred = state_->getGndPred(predIdx);
00736       return 1.0 /
00737              ( 1.0 + exp((gndPred->getWtWhenFalse() - 
00738                           gndPred->getWtWhenTrue()) *
00739                           invTemp));
00740     }
00741   }
00742  
00751   void setOthersInBlockToFalse(const int& chainIdx, const int& atomIdx,
00752                                const int& blockIdx)
00753   {
00754     int blockSize = state_->getDomain()->getBlockSize(blockIdx);
00755     for (int i = 0; i < blockSize; i++)
00756     {
00757       const Predicate* pred = state_->getDomain()->getPredInBlock(i, blockIdx);
00758       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00759       int idx = state_->getIndexOfGroundPredicate(gndPred);
00760 
00761       delete gndPred;
00762       delete pred;
00763 
00764         // Pred is in the state
00765       if (idx >= 0 && idx != atomIdx)
00766         truthValues_[idx][chainIdx] = false;
00767     }
00768   }
00769 
00780   void performGibbsStep(const int& chainIdx, const bool& burningIn,
00781                         GroundPredicateHashArray& affectedGndPreds,
00782                         Array<int>& affectedGndPredIndices)
00783   {
00784     if (mcmcdebug) cout << "Gibbs step" << endl;
00785 
00786       // For each block: select one to set to true
00787     for (int i = 0; i < state_->getDomain()->getNumPredBlocks(); i++)
00788     {
00789         // If evidence atom exists, then all others stay false
00790       if (state_->getDomain()->getBlockEvidence(i)) continue;
00791 
00792       int chosen = gibbsSampleFromBlock(chainIdx, i, 1);
00793         // Truth values are stored differently for multi-chain
00794       bool truthValue;
00795       const Predicate* pred = state_->getDomain()->getPredInBlock(chosen, i);
00796       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
00797       int idx = state_->getIndexOfGroundPredicate(gndPred);
00798 
00799       delete gndPred;
00800       delete pred;
00801       
00802         // If gnd pred in state:
00803       if (idx >= 0)
00804       {
00805         gndPred = state_->getGndPred(idx);
00806         if (numChains_ > 1) truthValue = truthValues_[idx][chainIdx];
00807         else truthValue = gndPred->getTruthValue();
00808           // If chosen pred was false, then need to set previous true
00809           // one to false and update wts
00810         if (!truthValue)
00811         {
00812           int blockSize = state_->getDomain()->getBlockSize(i);
00813           for (int j = 0; j < blockSize; j++)
00814           {
00815               // Truth values are stored differently for multi-chain
00816             bool otherTruthValue;
00817             const Predicate* otherPred = 
00818               state_->getDomain()->getPredInBlock(j, i);
00819             GroundPredicate* otherGndPred =
00820               new GroundPredicate((Predicate*)otherPred);
00821             int otherIdx = state_->getIndexOfGroundPredicate(gndPred);
00822 
00823             delete otherGndPred;
00824             delete otherPred;
00825       
00826               // If gnd pred in state:
00827             if (otherIdx >= 0)
00828             {
00829               otherGndPred = state_->getGndPred(otherIdx);
00830               if (numChains_ > 1)
00831                 otherTruthValue = truthValues_[otherIdx][chainIdx];
00832               else
00833                 otherTruthValue = otherGndPred->getTruthValue();
00834               if (otherTruthValue)
00835               {
00836                   // Truth values are stored differently for multi-chain
00837                 if (numChains_ > 1)
00838                   truthValues_[otherIdx][chainIdx] = false;
00839                 else
00840                   otherGndPred->setTruthValue(false);
00841               
00842                 affectedGndPreds.clear();
00843                 affectedGndPredIndices.clear();
00844                 gndPredFlippedUpdates(otherIdx, chainIdx, affectedGndPreds,
00845                                       affectedGndPredIndices);
00846                 updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00847                                      chainIdx);
00848               }
00849             }
00850           }
00851             // Set truth value and update wts for chosen atom
00852             // Truth values are stored differently for multi-chain
00853           if (numChains_ > 1) truthValues_[idx][chainIdx] = true;
00854           else gndPred->setTruthValue(true);
00855           affectedGndPreds.clear();
00856           affectedGndPredIndices.clear();
00857           gndPredFlippedUpdates(idx, chainIdx, affectedGndPreds,
00858                                 affectedGndPredIndices);
00859           updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00860                                chainIdx);
00861         }
00862           // If in actual gibbs sampling phase, track the num of times
00863           // the ground predicate is set to true
00864         if (!burningIn) numTrue_[idx]++;
00865       }
00866     }
00867 
00868       // Now go through all preds not in blocks
00869     for (int i = 0; i < state_->getNumAtoms(); i++)
00870     {
00871         // Predicates in blocks have been handled above
00872       if (state_->getBlockIndex(i) >= 0) continue;
00873 
00874       if (mcmcdebug)
00875       {
00876         cout << "Chain " << chainIdx << ": Probability of pred "
00877              << i << " is " << getProbabilityOfPred(i, chainIdx, 1) << endl;
00878       }
00879       
00880       bool newAssignment
00881         = genTruthValueForProb(getProbabilityOfPred(i, chainIdx, 1));
00882 
00883         // Truth values are stored differently for multi-chain
00884       bool truthValue;
00885       GroundPredicate* gndPred = state_->getGndPred(i);
00886       if (numChains_ > 1) truthValue = truthValues_[i][chainIdx];
00887       else truthValue = gndPred->getTruthValue();
00888         // If gndPred is flipped, do updates & find all affected gndPreds
00889       if (newAssignment != truthValue)
00890       {
00891         if (mcmcdebug)
00892         {
00893           cout << "Chain " << chainIdx << ": Changing truth value of pred "
00894                << i << " to " << newAssignment << endl;
00895         }
00896         
00897         if (numChains_ > 1) truthValues_[i][chainIdx] = newAssignment;
00898         else gndPred->setTruthValue(newAssignment);
00899         affectedGndPreds.clear();
00900         affectedGndPredIndices.clear();
00901         gndPredFlippedUpdates(i, chainIdx, affectedGndPreds,
00902                               affectedGndPredIndices);
00903         updateWtsForGndPreds(affectedGndPreds, affectedGndPredIndices,
00904                              chainIdx);
00905       }
00906 
00907         // If in actual gibbs sampling phase, track the num of times
00908         // the ground predicate is set to true
00909       if (!burningIn && newAssignment) numTrue_[i]++;
00910     }
00911       // If keeping track of true clause groundings
00912     if (!burningIn && trackClauseTrueCnts_)
00913       state_->getNumClauseGndings(clauseTrueCnts_, true);
00914 
00915     if (mcmcdebug) cout << "End of Gibbs step" << endl;
00916   }
00917 
00926   void updateWtsForGndPredsH(GroundPredicateHashArray& gndPreds,
00927           Array<int>& gndPredIndices, const int& chainIdx)
00928   {
00929           if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
00930           // for each ground predicate whose MB has changed
00931           for (int g = 0; g < gndPreds.size(); g++)
00932           {
00933                   double wtIfNoChange = 0, wtIfInverted = 0, wt;
00934                   // Ground clauses in which this pred occurs
00935                   Array<int>& negGndClauses =
00936                           hstate_->getNegOccurenceArray(gndPredIndices[g] + 1);
00937                   Array<int>& posGndClauses =
00938                           hstate_->getPosOccurenceArray(gndPredIndices[g] + 1);
00939 
00940                   int gndClauseIdx;
00941                   bool sense;
00942                   if (mcmcdebug)
00943                   {
00944                           cout << "Ground clauses in which pred " << g << " occurs neg.: "
00945                                   << negGndClauses.size() << endl;
00946                           cout << "Ground clauses in which pred " << g << " occurs pos.: "
00947                                   << posGndClauses.size() << endl;
00948                   }
00949 
00950                   for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
00951                   {
00952                           if (i < negGndClauses.size())
00953                           {
00954                                   gndClauseIdx = negGndClauses[i];
00955                                   if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
00956                                   sense = false;
00957                           }
00958                           else
00959                           {
00960                                   gndClauseIdx = posGndClauses[i - negGndClauses.size()];
00961                                   if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
00962                                   sense = true;
00963                           }
00964 
00965                           GroundClause* gndClause = hstate_->getGndClause(gndClauseIdx);
00966                           if (gndClause->isHardClause())
00967                                   wt = hstate_->getClauseCost(gndClauseIdx);
00968                           else
00969                                   wt = gndClause->getWt();
00970                           // NumTrueLits are stored differently for multi-chain
00971                           int numSatLiterals;
00972                           if (numChains_ > 1)
00973                                   numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
00974                           else
00975                                   numSatLiterals = hstate_->getNumTrueLits(gndClauseIdx);
00976                           if (numSatLiterals > 1)
00977                           {
00978                                   // Some other literal is making it sat, so it doesn't matter
00979                                   // if pos. clause. If neg., nothing can be done to unsatisfy it.
00980                                   if (wt > 0)
00981                                   {
00982                                           wtIfNoChange += wt;
00983                                           wtIfInverted += wt;
00984                                   }
00985                           }
00986                           else 
00987                                   if (numSatLiterals == 1) 
00988                                   {
00989                                           if (wt > 0) wtIfNoChange += wt;
00990                                           // Truth values are stored differently for multi-chain
00991                                           bool truthValue;
00992                                           if (numChains_ > 1)
00993                                                   truthValue = truthValues_[gndPredIndices[g]][chainIdx];
00994                                           else
00995                                                   truthValue = gndPreds[g]->getTruthValue();
00996                                           // If the current truth value is the same as its sense in gndClause
00997                                           if (truthValue == sense) 
00998                                           {
00999                                                   // This gndPred is the only one making this function satisfied
01000                                                   if (wt < 0) wtIfInverted += abs(wt);
01001                                           }
01002                                           else 
01003                                           {
01004                                                   // Some other literal is making it satisfied
01005                                                   if (wt > 0) wtIfInverted += wt;
01006                                           }
01007                                   }
01008                                   else
01009                                           if (numSatLiterals == 0) 
01010                                           {
01011                                                   // None satisfy, so when gndPred switch to its negative, it'll satisfy
01012                                                   if (wt > 0) wtIfInverted += wt;
01013                                                   else if (wt < 0) wtIfNoChange += abs(wt);
01014                                           }
01015                   } // for each ground clause that gndPred appears in
01016 
01017                   if (mcmcdebug)
01018                   {
01019                           cout << "wtIfNoChange of pred " << g << ": "
01020                                   << wtIfNoChange << endl;
01021                           cout << "wtIfInverted of pred " << g << ": "
01022                                   << wtIfInverted << endl;
01023                   }
01024 
01025                   // Clause info is stored differently for multi-chain
01026                   if (numChains_ > 1)
01027                   {
01028                           if (truthValues_[gndPredIndices[g]][chainIdx]) 
01029                           {
01030                                   wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01031                                   wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01032                           }
01033                           else 
01034                           {
01035                                   wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01036                                   wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01037                           }
01038                   }
01039                   else
01040                   { // Single chain
01041                           if (gndPreds[g]->getTruthValue())
01042                           {
01043                                   gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01044                                   gndPreds[g]->setWtWhenFalse(wtIfInverted);
01045                           }
01046                           else
01047                           {
01048                                   gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01049                                   gndPreds[g]->setWtWhenTrue(wtIfInverted);            
01050                           }
01051                   }
01052           } // for each ground predicate whose MB has changed
01053           if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01054   }
01055 
01056   
01057 
01058 
01067   void updateWtsForGndPreds(GroundPredicateHashArray& gndPreds,
01068                             Array<int>& gndPredIndices,
01069                             const int& chainIdx)
01070   {
01071     if (mcmcdebug) cout << "Entering updateWtsForGndPreds" << endl;
01072       // for each ground predicate whose MB has changed
01073     for (int g = 0; g < gndPreds.size(); g++)
01074     {
01075       double wtIfNoChange = 0, wtIfInverted = 0, wt;
01076         // Ground clauses in which this pred occurs
01077       Array<int>& negGndClauses =
01078         state_->getNegOccurenceArray(gndPredIndices[g] + 1);
01079       Array<int>& posGndClauses =
01080         state_->getPosOccurenceArray(gndPredIndices[g] + 1);
01081       int gndClauseIdx;
01082       bool sense;
01083       
01084       if (mcmcdebug)
01085       {
01086         cout << "Ground clauses in which pred " << g << " occurs neg.: "
01087              << negGndClauses.size() << endl;
01088         cout << "Ground clauses in which pred " << g << " occurs pos.: "
01089              << posGndClauses.size() << endl;
01090       }
01091       
01092       for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01093       {
01094         if (i < negGndClauses.size())
01095         {
01096           gndClauseIdx = negGndClauses[i];
01097           if (mcmcdebug) cout << "Neg. in clause " << gndClauseIdx << endl;
01098           sense = false;
01099         }
01100         else
01101         {
01102           gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01103           if (mcmcdebug) cout << "Pos. in clause " << gndClauseIdx << endl;
01104           sense = true;
01105         }
01106         
01107         GroundClause* gndClause = state_->getGndClause(gndClauseIdx);
01108         if (gndClause->isHardClause())
01109           wt = state_->getClauseCost(gndClauseIdx);
01110         else
01111           wt = gndClause->getWt();
01112           // NumTrueLits are stored differently for multi-chain
01113         int numSatLiterals;
01114         if (numChains_ > 1)
01115           numSatLiterals = numTrueLits_[gndClauseIdx][chainIdx];
01116         else
01117           numSatLiterals = state_->getNumTrueLits(gndClauseIdx);
01118 
01119         if (numSatLiterals > 1)
01120         {
01121             // Some other literal is making it sat, so it doesn't matter
01122             // if pos. clause. If neg., nothing can be done to unsatisfy it.
01123           if (wt > 0)
01124           {
01125             wtIfNoChange += wt;
01126             wtIfInverted += wt;
01127           }
01128         }
01129         else 
01130         if (numSatLiterals == 1) 
01131         {
01132           if (wt > 0) wtIfNoChange += wt;
01133             // Truth values are stored differently for multi-chain
01134           bool truthValue;
01135           if (numChains_ > 1)
01136             truthValue = truthValues_[gndPredIndices[g]][chainIdx];
01137           else
01138             truthValue = gndPreds[g]->getTruthValue();
01139             // If the current truth value is the same as its sense in gndClause
01140           if (truthValue == sense) 
01141           {
01142             // This gndPred is the only one making this function satisfied
01143             if (wt < 0) wtIfInverted += abs(wt);
01144           }
01145           else 
01146           {
01147               // Some other literal is making it satisfied
01148             if (wt > 0) wtIfInverted += wt;
01149           }
01150         }
01151         else
01152         if (numSatLiterals == 0) 
01153         {
01154           // None satisfy, so when gndPred switch to its negative, it'll satisfy
01155           if (wt > 0) wtIfInverted += wt;
01156           else if (wt < 0) wtIfNoChange += abs(wt);
01157         }
01158       } // for each ground clause that gndPred appears in
01159 
01160       if (mcmcdebug)
01161       {
01162         cout << "wtIfNoChange of pred " << g << ": "
01163              << wtIfNoChange << endl;
01164         cout << "wtIfInverted of pred " << g << ": "
01165              << wtIfInverted << endl;
01166       }
01167 
01168         // Clause info is stored differently for multi-chain
01169       if (numChains_ > 1)
01170       {
01171         if (truthValues_[gndPredIndices[g]][chainIdx]) 
01172         {
01173           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01174           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01175         }
01176         else 
01177         {
01178           wtsWhenFalse_[gndPredIndices[g]][chainIdx] = wtIfNoChange;
01179           wtsWhenTrue_[gndPredIndices[g]][chainIdx] = wtIfInverted;
01180         }
01181       }
01182       else
01183       { // Single chain
01184         if (gndPreds[g]->getTruthValue())
01185         {
01186           gndPreds[g]->setWtWhenTrue(wtIfNoChange);
01187           gndPreds[g]->setWtWhenFalse(wtIfInverted);
01188         }
01189         else
01190         {
01191           gndPreds[g]->setWtWhenFalse(wtIfNoChange);
01192           gndPreds[g]->setWtWhenTrue(wtIfInverted);            
01193         }
01194       }
01195     } // for each ground predicate whose MB has changed
01196     if (mcmcdebug) cout << "Leaving updateWtsForGndPreds" << endl;
01197   }
01198 
01208   int gibbsSampleFromBlock(const int& chainIdx, const int& blockIndex,
01209                            const double& invTemp)
01210   {
01211     Array<double> numerators;
01212     double denominator = 0;
01213 
01214     int blockSize = state_->getDomain()->getBlockSize(blockIndex);
01215     for (int i = 0; i < blockSize; i++)
01216     {
01217       const Predicate* pred =
01218         state_->getDomain()->getPredInBlock(i, blockIndex);
01219       GroundPredicate* gndPred = new GroundPredicate((Predicate*)pred);
01220       int idx = state_->getIndexOfGroundPredicate(gndPred);
01221       
01222       delete gndPred;
01223       delete pred;
01224 
01225         // Prob is 0 if atom not in state
01226       double prob = 0.0;
01227         // Pred is in the state; otherwise, prob is zero
01228       if (idx >= 0)
01229         prob = getProbabilityOfPred(idx, chainIdx, invTemp);
01230 
01231       numerators.append(prob);
01232       denominator += prob;
01233     }
01234 
01235     double r = random();
01236     double numSum = 0.0;
01237     for (int i = 0; i < blockSize; i++)
01238     {
01239       numSum += numerators[i];
01240       if (r < ((numSum / denominator) * RAND_MAX))
01241       {
01242         return i;
01243       }
01244     }
01245     return blockSize - 1;
01246   }
01247 
01256   void gndPredFlippedUpdates(const int& gndPredIdx, const int& chainIdx,
01257                              GroundPredicateHashArray& affectedGndPreds,
01258                              Array<int>& affectedGndPredIndices)
01259   {
01260     if (mcmcdebug) cout << "Entering gndPredFlippedUpdates" << endl;
01261     int numAtoms = state_->getNumAtoms();
01262     GroundPredicate* gndPred = state_->getGndPred(gndPredIdx);
01263     affectedGndPreds.append(gndPred, numAtoms);
01264     affectedGndPredIndices.append(gndPredIdx);
01265     assert(affectedGndPreds.size() <= numAtoms);
01266 
01267     Array<int>& negGndClauses =
01268       state_->getNegOccurenceArray(gndPredIdx + 1);
01269     Array<int>& posGndClauses =
01270       state_->getPosOccurenceArray(gndPredIdx + 1);
01271     int gndClauseIdx;
01272     GroundClause* gndClause; 
01273     bool sense;
01274 
01275       // Find the Markov blanket of this ground predicate
01276     for (int i = 0; i < negGndClauses.size() + posGndClauses.size(); i++)
01277     {
01278       if (i < negGndClauses.size())
01279       {
01280         gndClauseIdx = negGndClauses[i];
01281         sense = false;
01282       }
01283       else
01284       {
01285         gndClauseIdx = posGndClauses[i - negGndClauses.size()];
01286         sense = true;
01287       }
01288       gndClause = state_->getGndClause(gndClauseIdx);
01289 
01290         // Different for multi-chain
01291       if (numChains_ > 1)
01292       {
01293         if (truthValues_[gndPredIdx][chainIdx] == sense)
01294           numTrueLits_[gndClauseIdx][chainIdx]++;
01295         else
01296           numTrueLits_[gndClauseIdx][chainIdx]--;
01297       }
01298       else
01299       { // Single chain
01300         if (gndPred->getTruthValue() == sense)
01301           state_->incrementNumTrueLits(gndClauseIdx);
01302         else
01303           state_->decrementNumTrueLits(gndClauseIdx);
01304       }
01305       
01306       for (int j = 0; j < gndClause->getNumGroundPredicates(); j++)
01307       {
01308         const GroundPredicateHashArray* gpha = state_->getGndPredHashArrayPtr();
01309         GroundPredicate* pred = 
01310           (GroundPredicate*)gndClause->getGroundPredicate(j,
01311             (GroundPredicateHashArray*)gpha);
01312         affectedGndPreds.append(pred, numAtoms);
01313         affectedGndPredIndices.append(
01314                                abs(gndClause->getGroundPredicateIndex(j)) - 1);
01315         assert(affectedGndPreds.size() <= numAtoms);
01316       }
01317     }
01318     if (mcmcdebug) cout << "Leaving gndPredFlippedUpdates" << endl;
01319   }
01320 
01321   double getProbTrue(const int& predIdx) const { return numTrue_[predIdx]; }
01322   
01323   void setProbTrue(const int& predIdx, const double& p)
01324   { 
01325     assert(p >= 0);
01326     numTrue_[predIdx] = p;
01327   }
01328 
01335   void saveLowStateToChain(const int& chainIdx)
01336   {
01337     for (int i = 0; i < state_->getNumAtoms(); i++)
01338       truthValues_[i][chainIdx] = state_->getValueOfLowAtom(i + 1);
01339   }
01340 
01346   void setMCMCParameters(MCMCParams* params)
01347   {
01348       // User-set parameters
01349     numChains_ = params->numChains;
01350     burnMinSteps_ = params->burnMinSteps;
01351     burnMaxSteps_ = params->burnMaxSteps;
01352     minSteps_ = params->minSteps;
01353     maxSteps_ = params->maxSteps;
01354     maxSeconds_ = params->maxSeconds;    
01355   }
01356 
01357   void scaleSamples(double factor)
01358   {
01359     minSteps_ = (int)(minSteps_ * factor);
01360     maxSteps_ = (int)(maxSteps_ * factor);
01361   }
01362   
01363  public: 
01364   // truth value assignment for continuous variables for each chain
01365   Array<Array<double> > truthValuesCont_;
01366   // storing values for the discrete parts of hybrid clauses for each chain
01367   Array<Array<bool> > hybridClauseDisPartValueMCMC_; 
01368   // storing values for the continuous parts of hybrid clauses for each chain
01369   Array<Array<double> > hybridClauseContPartValueMCMC_; 
01370   
01371  protected:
01372  
01374     // No. of chains which MCMC will use
01375   int numChains_;
01376     // Min. no. of burn-in steps MCMC will take per chain
01377   int burnMinSteps_;
01378     // Max. no. of burn-in steps MCMC will take per chain
01379   int burnMaxSteps_;
01380     // Min. no. of sampling steps MCMC will take per chain
01381   int minSteps_;
01382     // Max. no. of sampling steps MCMC will take per chain
01383   int maxSteps_;
01384     // Max. no. of seconds MCMC should run
01385   int maxSeconds_;
01387 
01388     // Truth values in each chain for each ground predicate (truthValues_[p][c])
01389   Array<Array<bool> > truthValues_;
01390     // Wts when false in each chain for each ground predicate
01391   Array<Array<double> > wtsWhenFalse_;
01392     // Wts when true in each chain for each ground predicate
01393   Array<Array<double> > wtsWhenTrue_;
01394 
01395     // Number of times each ground predicate is set to true
01396     // overloaded to hold probability that ground predicate is true
01397   Array<double> numTrue_; // numTrue_[p]
01398 
01399     // Num. of true (satisfying) literals in each chain for each ground clause
01400     // numTrueLits_[clause][chain]
01401   Array<Array<int> > numTrueLits_;
01402 };
01403 
01404 #endif /*MCMC_H_*/

Generated on Sun Jun 7 11:55:12 2009 for Alchemy by  doxygen 1.5.1