00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066 #include "clausesampler.h"
00067 #include "clause.h"
00068
00069
00070 double ClauseSampler::estimateNumTrueGroundings(Clause* const& clause,
00071 const Predicate*const& flippedGndPred,
00072 const Domain* const& domain,
00073 double numSamples)
00074 {
00075 Array<VarsGroundedType*>* vgtArr = new Array<VarsGroundedType*>;
00076 clause->createVarIdToVarsGroundedType(domain, vgtArr);
00077
00078 Database* db = domain->getDB();
00079 TrueFalseGroundingsStore* tfGndingsStore
00080 = domain->getTrueFalseGroundingsStore();
00081
00082 int numPreds = clause->getNumPredicates();
00083 assert(numPreds > 1);
00084 if (numSamples <= 0) numSamples = computeNumSamples(numPreds);
00085
00086
00087 Array<float> corrProbLimit(numPreds);
00088 double sumTrueGnds;
00089 Array<double> numTrueGndsPerPos;
00090 getProbInfo(clause, domain->getDB(), flippedGndPred, vgtArr, tfGndingsStore,
00091 corrProbLimit, sumTrueGnds, numTrueGndsPerPos);
00092
00093 double estX = 0, x = 0, x2 = 0, tmp, ave, stddev, sampNeeded;
00094 int i, numj;
00095 Array<int>* samp;
00096 Array<Array<int>*> samples;
00097 Array<double> sampleWt;
00098
00099 int s = 0;
00100 if (sumTrueGnds > 0)
00101 {
00102 for (; s < numSamples; s++)
00103 {
00104 i = choosePredPos(corrProbLimit);
00105 if (i < 0) continue;
00106 samp = chooseSample(clause, vgtArr, domain, i, flippedGndPred);
00107 numj = testSampleMembership(clause, vgtArr, db, *samp, sumTrueGnds);
00108
00109 assert(numj > 0);
00110 tmp = sumTrueGnds / (double)numj;
00111 estX += tmp;
00112 x += tmp;
00113 x2 += tmp*tmp;
00114 delete samp;
00115
00116 if (s > 0 && s % NUM_SAMPLES_TEST_CONV == 0)
00117 {
00118 ave = x/(s+1);
00119 stddev = sqrt(x2/(s+1) - ave*ave);
00120 sampNeeded = getNumSamplesNeeded(stddev, 1.0-delta_, epsilon_*ave);
00121 if (s+1 > sampNeeded) { s++; break; }
00122 x = 0;
00123 x2 = 0;
00124 }
00125 }
00126 }
00127
00128 clause->deleteVarIdToVarsGroundedType(vgtArr);
00129
00130 if (sumTrueGnds == 0) return 0;
00131 return estX/s;
00132 }
00133
00134
00135 void
00136 ClauseSampler::getNumGroundingWithEachPredGrounded(const Clause* const & clause,
00137 Array<double>& gndingsWithPredGnded,
00138 const Array<VarsGroundedType*>& vgtArr)
00139 {
00140 double totalNumGndings = 1;
00141 for (int i = 1; i < vgtArr.size(); i++)
00142 {
00143 if (vgtArr[i] == NULL) continue;
00144 totalNumGndings *= vgtArr[i]->numGndings;
00145 }
00146
00147
00148 Array<bool> varSeen; varSeen.growToSize(vgtArr.size(),false);
00149 for (int i = 0; i < clause->getNumPredicates(); i++)
00150 {
00151 double numGndings = 1;
00152 memset((void*)varSeen.getItems(), false, varSeen.size()*sizeof(bool));
00153 Predicate* pred = clause->getPredicate(i);
00154 for (int j = 0; j < pred->getNumTerms(); j++)
00155 {
00156 const Term* t = pred->getTerm(j);
00157 int varId = -(t->getId());
00158 if (varId > 0)
00159 {
00160 if (!varSeen[varId])
00161 {
00162 varSeen[varId] = true;
00163 numGndings *= vgtArr[varId]->numGndings;
00164 }
00165 }
00166 }
00167 gndingsWithPredGnded.append(totalNumGndings/numGndings);
00168 }
00169 }
00170
00171
00172 void ClauseSampler::getProbInfo(const Clause* const & clause,
00173 const Database* const & db,
00174 const Predicate* const & flippedGndPred,
00175 const Array<VarsGroundedType*>* const & vgtArr,
00176 TrueFalseGroundingsStore* const& tfGndingsStore,
00177 Array<float>& corrProbLimit,double& sumTrueGnds,
00178 Array<double>& numTrueGndsPerPos)
00179 {
00180 Predicate* pred;
00181 int numGndings;
00182 double si;
00183 sumTrueGnds = 0;
00184 Array<double> gndingsWithPredGnded;
00185 getNumGroundingWithEachPredGrounded(clause, gndingsWithPredGnded, *vgtArr);
00186
00187 for (int i = 0; i < clause->getNumPredicates(); i++)
00188 {
00189 pred = (Predicate*) clause->getPredicate(i);
00190 numGndings = tfGndingsStore->getNumTrueLiteralGroundings(pred,
00191 flippedGndPred);
00192 si = numGndings * gndingsWithPredGnded[i];
00193 numTrueGndsPerPos.append(si);
00194 sumTrueGnds += si;
00195 }
00196
00197 float p = 0;
00198 for (int i = 0; i < clause->getNumPredicates(); i++)
00199 {
00200 if (sumTrueGnds > 0)
00201 {
00202 p += numTrueGndsPerPos[i]/sumTrueGnds;
00203 corrProbLimit.append(p);
00204 }
00205 else
00206 corrProbLimit.append(0.0);
00207 }
00208 }
00209
00210
00211 Array<int>* ClauseSampler::chooseSample(const Clause* const & clause,
00212 const Array<VarsGroundedType*>* const & vgtArr,
00213 const Domain* const & domain,
00214 const int& predPos,
00215 const Predicate* const & flippedGndPred)
00216 {
00217 const Predicate* pred = clause->getPredicate(predPos);
00218 TrueFalseGroundingsStore* tfGndingsStore
00219 = domain->getTrueFalseGroundingsStore();
00220 Predicate* rtSensePred
00221 = tfGndingsStore->getRandomTrueLiteralGrounding(pred, flippedGndPred);
00222
00223 if (!rtSensePred)
00224 {
00225 cout << "ERROR: in ClauseSampler::chooseSample, no grounding found for ";
00226 pred->printAsInt(cout); cout << endl;
00227 exit(-1);
00228 }
00229
00230
00231 Array<int>* samp = new Array<int>;
00232 samp->growToSize(vgtArr->size(), -1);
00233
00234
00235 for (int i = 0; i < pred->getNumTerms(); i++)
00236 {
00237 int constId = rtSensePred->getTerm(i)->getId(); assert(constId >= 0);
00238 int varId = pred->getTerm(i)->getId();
00239 if (varId < 0) (*samp)[-varId] = constId;
00240 else assert(constId == varId);
00241 }
00242
00243 for (int i = 1; i < vgtArr->size(); i++)
00244 {
00245
00246 if ((*vgtArr)[i] == NULL || (*samp)[i] >= 0) continue;
00247 const Array<int>* constants
00248 = domain->getConstantsByType((*vgtArr)[i]->typeId);
00249 assert(constants->size() > 0);
00250 int idx = random_.randomOneOf(constants->size());
00251 (*samp)[i] = (*constants)[idx];
00252
00253 }
00254
00255 delete rtSensePred;
00256 return samp;
00257 }
00258
00259
00260 void ClauseSampler::groundClause(const Array<VarsGroundedType*>& vgtArr,
00261 const Array<int>& samp)
00262 {
00263 assert(samp.size() == vgtArr.size());
00264 for (int i = 1; i < vgtArr.size(); i++)
00265 {
00266 if (vgtArr[i] == NULL) continue;
00267 Array<Term*>& vars = vgtArr[i]->vars;
00268 for (int j = 0 ; j < vars.size(); j++) vars[j]->setId(samp[i]);
00269 }
00270 }
00271
00272
00273 int ClauseSampler::testSampleMembership(Clause* const & clause,
00274 Array<VarsGroundedType*>* const& vgtArr,
00275 const Database* const & db,
00276 const Array<int>& samp,
00277 const double& sumTrueGnds)
00278 {
00279 groundClause(*vgtArr, samp);
00280 int numj = 0;
00281 for (int i = 0; i < clause->getNumPredicates(); i++)
00282 {
00283 TruthValue tv = db->getValue(clause->getPredicate(i));
00284 bool sense = clause->getPredicate(i)->getSense();
00285 if (Database::sameTruthValueAndSense(tv,sense)) numj++;
00286 }
00287 clause->restoreVars(vgtArr);
00288 return numj;
00289 }
00290