00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include "node.h"
00068
00069
00070
00071
00072
00073
00074
00075 void Node::addFactors(Array<Factor *> * const & allFactors,
00076 LinkIdToTwoWayMessageMap* const & lidToTWMsg)
00077 {
00078 Factor *factor;
00079 Link *link;
00080 double cnt;
00081 ClauseCounter * counter = (ClauseCounter *)superPred_->getClauseCounter();
00082 int numFactors = counter->getNumClauses();
00083 const Array<double> * cnts;
00084
00085 for (int findex = 0; findex < numFactors; findex++)
00086 {
00087 int fid = counter->getClauseId(findex);
00088 cnts = counter->getClauseCounts(findex);
00089 factor = (*allFactors)[fid];
00090
00091 for (int predIndex = 0; predIndex < cnts->size(); predIndex++)
00092 {
00093 cnt = (*cnts)[predIndex];
00094 if ((*cnts)[predIndex] == 0)
00095 continue;
00096
00097
00098 int reverseNodeIndex = factor->getNumLinks();
00099
00100 int reverseFactorIndex = getNumLinks();
00101
00102 link = new Link(this, factor, reverseNodeIndex, reverseFactorIndex,
00103 predIndex, cnt);
00104
00105
00106 LinkId *lid;
00107 TwoWayMessage *tmsg;
00108 double *nodeToFactorMsgs, *factorToNodeMsgs;
00109 LinkIdToTwoWayMessageMap::iterator lidToTMsgItr;
00110 int parentSuperPredId = getParentSuperPredId();
00111 int parentSuperClauseId = factor->getParentSuperClauseId();
00112
00113 lid = new LinkId(predId_, parentSuperPredId, parentSuperClauseId,
00114 predIndex);
00115 lidToTMsgItr = lidToTWMsg->find(lid);
00116 delete lid;
00117
00118 if (lidToTMsgItr != lidToTWMsg->end())
00119 {
00120 tmsg = lidToTMsgItr->second;
00121 nodeToFactorMsgs = tmsg->getNodeToFactorMessage();
00122 factorToNodeMsgs = tmsg->getFactorToNodeMessage();
00123 }
00124 else
00125 {
00126 nodeToFactorMsgs = NULL;
00127 factorToNodeMsgs = NULL;
00128 }
00129 this->addLink(link, nodeToFactorMsgs);
00130 factor->addLink(link,factorToNodeMsgs);
00131 }
00132 }
00133 }
00134
00135
00136 void Node::sendMessage()
00137 {
00138 Link *link;
00139 Factor *factor;
00140 double cnt;
00141 double *msgs;
00142 double *outMsgs = new double[2];
00143
00144 for (int lindex = 0; lindex < links_->size(); lindex++)
00145 {
00146 link = (*links_)[lindex];
00147 factor = link->getFactor();
00148 cnt = link->getCount();
00149
00150 msgs = (*msgsArr_)[lindex];
00151 for (int i = 0; i < 2; i++)
00152 {
00153 outMsgs[i] = msgProds_[i] - msgs[i];
00154 }
00155
00156
00157 factor->receiveMessage(outMsgs, link);
00158 }
00159 delete [] outMsgs;
00160 }
00161
00162
00163 void Node::sendAuxMessage()
00164 {
00165 Link *link;
00166 Factor *factor;
00167 double cnt;
00168
00169 double *outMsgs = new double[2];
00170
00171 for (int lindex = 0; lindex < auxLinks_->size(); lindex++)
00172 {
00173 link = (*auxLinks_)[lindex];
00174 factor = link->getFactor();
00175 cnt = link->getCount();
00176 for (int i = 0; i < 2; i++)
00177 {
00178 outMsgs[i] = msgProds_[i];
00179 }
00180
00181
00182 factor->receiveMessage(outMsgs, link);
00183 }
00184 delete [] outMsgs;
00185 }
00186
00187
00188
00189 void Node::moveToNextStep()
00190 {
00191 double * msgs;
00192 double * nextMsgs;
00193 double cnt;
00194
00195
00196 for (int i = 0; i < 2; i++)
00197 {
00198 msgProds_[i] = 0;
00199 }
00200
00201
00202
00203 for (int lindex = 0; lindex < links_->size(); lindex++)
00204 {
00205 msgs = (*msgsArr_)[lindex];
00206 nextMsgs = (*nextMsgsArr_)[lindex];
00207
00208
00209
00210
00211
00212
00213 cnt = (*links_)[lindex]->getCount();
00214 for (int i = 0; i < 2; i++)
00215 {
00216 msgs[i] = nextMsgs[i];
00217 nextMsgs[i] = 0;
00218
00219 msgProds_[i] = msgProds_[i] + cnt*msgs[i];
00220 }
00221 }
00222 }
00223