RankLib is an excellent open-source toolkit for the Learning to Rank field; among others it implements MART, RankNet, RankBoost, LambdaMART, and Random Forests. LambdaMART, published by Microsoft, is one of the Learning to Rank models most commonly used in IR. This post walks through the concrete LambdaMART implementation in RankLib to help connect the code with the method described in the paper. Everything below is based on RankLib version 2.3.

The basic principles of LambdaMART are covered in an earlier post: http://www.cnblogs.com/bentuwuying/p/6690836.html. Keep in mind that LambdaMART is built on MART, and MART is an ensemble of regression trees, so we first look at how RankLib implements a regression tree, and how that tree is fit given training data with labels.

1. regression tree

The steps for fitting a regression tree to the given training data can be summarized as follows:

RegressionTree
    nodes                # upper bound on the number of leaves of a tree
    minLeafSupport       # controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further
    root                 # root node
    leaves               # list of leaf nodes
    RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
                         # constructor: initializes the member variables
    fit                  # fit the regression tree to the training data
        create a queue used to split nodes in queue order (i.e. level order)
        initialize the root node of the tree
        root.split       # split the root node
            hist.findBestSplit    # method of the FeatureHistogram held by this Split: on top of the node's precomputed feature histogram, find the best split point, split, compute the feature histograms of the left and right children, and initialize both children
                if the node's deviance is 0, the split fails
                pick usedFeatures, the indices of the features to use for this split, according to samplingRate
                call the internal findBestSplit: on this node, choose the split feature and threshold among usedFeatures from the node's feature histogram
                    S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight
                    over all candidate (feature, threshold) pairs, take the one with maximal S; maximal S corresponds to minimal squared error, i.e. the best split point (see the derivation after this outline)
                check whether a split was found; if S == -1, the split fails
                route every training sample on this node to the left or right child according to the best split point
                build the feature histograms of the two children:
                    construct    # usually builds the left child's histogram; when built from the parent, the thresholds array is unchanged but sum and count are rebuilt
                    construct    # usually builds the right child's histogram
                compute the squared error of this node and of both children
                sp.set    # method of the Split owning this FeatureHistogram; called after a successful split to record the split's featureID, threshold and deviance. Only internal nodes are ever split and call this method, so only internal nodes have featureID != -1; leaves never call it, hence featureID == -1
                initialize the left child (sample indices routed left, left histogram, left squared error, left label sum) and attach it as this node's left child
                initialize the right child (sample indices routed right, right histogram, right squared error, right label sum) and attach it as this node's right child
        insert    # insert the two children into the queue for later traversal, keeping the queue ordered by squared error, largest first
        loop: keep splitting in queue order (level order), inserting the two children of every successful split back into the queue
        set the tree's leaves member by traversing from the root via root.leaves()
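Before the code, a quick note on the split criterion S from the outline (a short derivation in my own notation, with $n_L, n_R$ the child sizes and $S_L, S_R$ the child label sums — not part of RankLib itself): for a fixed node, the total squared error of the two children is

$$\sum_{i \in L}\Big(y_i - \frac{S_L}{n_L}\Big)^2 + \sum_{i \in R}\Big(y_i - \frac{S_R}{n_R}\Big)^2 = \sum_{i} y_i^2 - \Big(\frac{S_L^2}{n_L} + \frac{S_R^2}{n_R}\Big),$$

since $\sum_{i \in L}(y_i - S_L/n_L)^2 = \sum_{i \in L} y_i^2 - S_L^2/n_L$, and likewise on the right. The first term does not depend on the split, so maximizing $S = S_L^2/n_L + S_R^2/n_R$ is exactly minimizing the children's squared error. The same identity explains why the code below computes a node's deviance as sqSumResponse - sumResponse * sumResponse / n.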
Below are the class files involved in fitting a regression tree; the key parts carry detailed comments.

1. FeatureHistogram

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.WorkerThread;

/**
 * @author vdang
 */
// Feature histogram: per-feature statistics over the samples of a node, used to pick the best feature and split value at each split
public class FeatureHistogram {
    // Holds a split candidate: featureIdx, thresholdIdx, and the score S = sumLeft*sumLeft/countLeft + sumRight*sumRight/countRight used to judge the best split
    class Config {
        int featureIdx = -1;
        int thresholdIdx = -1;
        double S = -1;
    }

    //Parameter
    public static float samplingRate = 1; // sampling rate over the features considered at a split, so that not all features have to be used

    //Variables
    public int[] features = null; // array of feature ids (fid)
    public float[][] thresholds = null; // [feature][t]: first dimension is the feature (its position in features, not the feature id); second dimension holds the sorted distinct values of that feature over the training data — the candidate split values for this feature on this node
    public double[][] sum = null; // [feature][t]: sum of the labels of all training samples whose value on this feature is <= thresholds[i][t]; same shape as thresholds
    public double sumResponse = 0; // sum of the labels of all training samples
    public double sqSumResponse = 0; // sum of the squared labels of all training samples
    public int[][] count = null; // [feature][t]: number of training samples whose value on this feature is <= thresholds[i][t]; same shape as thresholds
    public int[][] sampleToThresholdMap = null; // [feature][sample]: for each training sample, the column index in the corresponding thresholds row that its value on this feature maps to

    //whether to re-use its parent's sum and count instead of cleaning up the parent and re-allocating for the children.
    //sum and count of any intermediate tree node (except for root) can be re-used.
    private boolean reuseParent = false;

    public FeatureHistogram()
    {
    }

    // construct (1-1): usually builds the feature histogram of the whole tree / the root node
    // samples: training data
    // labels: labels of the training data
    // sampleSortedIdx: sorted list of samples by each feature, so the best split can be found quickly; needs initializing only once, see init() in LambdaMART.java
    // features: feature set of the training data
    // thresholds: a table of candidate thresholds for each feature; we will select the best tree split from these candidates later on.
    //     Initialized in init() of LambdaMART.java; the last column of every row is Float.MAX_VALUE
    public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds)
    {
        this.features = features;
        this.thresholds = thresholds;

        sumResponse = 0;
        sqSumResponse = 0;

        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = new int[features.length][];

        // single- or multi-threaded
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(samples, labels, sampleSortedIdx, thresholds, 0, features.length-1);
        else
            p.execute(new Worker(this, samples, labels, sampleSortedIdx, thresholds), features.length);
    }
    // construct (1-2): called by (1-1)
    protected void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds, int start, int end)
    {
        for(int i=start;i<=end;i++) // for each feature
        {
            int fid = features[i]; // the feature id
            //get the list of samples associated with this node (sorted in ascending order with respect to the current feature)
            int[] idx = sampleSortedIdx[i];

            double sumLeft = 0; // running total feeding sumLabel
            float[] threshold = thresholds[i];
            double[] sumLabel = new double[threshold.length]; // one row of the sum array
            int[] c = new int[threshold.length]; // one row of the count array
            int[] stMap = new int[samples.length]; // one row of sampleToThresholdMap

            int last = -1;
            for(int t=0;t<threshold.length;t++) // for each candidate split value
            {
                int j=last+1;
                //find the first sample that exceeds the current threshold
                for(;j<idx.length;j++)
                {
                    int k = idx[j]; // index of this DataPoint in samples
                    if(samples[k].getFeatureValue(fid) > threshold[t])
                        break;
                    sumLeft += labels[k];
                    if(i == 0)
                    {
                        sumResponse += labels[k];
                        sqSumResponse += labels[k] * labels[k];
                    }
                    stMap[k] = t;
                }
                last = j-1;
                sumLabel[t] = sumLeft;
                c[t] = last+1;
            }
            sampleToThresholdMap[i] = stMap;
            sum[i] = sumLabel;
            count[i] = c;
        }
    }

    //update(1-1), update the histogram with these training labels (the feature histogram will be used to find the best tree split)
    protected void update(double[] labels)
    {
        sumResponse = 0;
        sqSumResponse = 0;

        // single- or multi-threaded
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            update(labels, 0, features.length-1);
        else
            p.execute(new Worker(this, labels), features.length);
    }

    //update(1-2): called by (1-1)
    protected void update(double[] labels, int start, int end)
    {
        for(int f=start;f<=end;f++)
            Arrays.fill(sum[f], 0);
        for(int k=0;k<labels.length;k++)
        {
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
                //count doesn't change, so no need to re-compute
            }
        }
        for(int f=start;f<=end;f++)
        {
            for(int t=1;t<thresholds[f].length;t++)
                sum[f][t] += sum[f][t-1];
        }
    }
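    // After the prefix pass above, sum[f][t] is cumulative: it covers every sample whose value on
    // feature f is <= thresholds[f][t] (count[f][t] likewise, built once in construct). This is what
    // lets findBestSplit read off left-child statistics for any candidate threshold in O(1) —
    // sumLeft = sum[i][t], countLeft = count[i][t] — and derive the right child by subtracting from the node totals.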

    // construct (2-1): usually builds the feature histogram of the LEFT child produced by splitting a parent node.
    // When constructed from the parent, the thresholds array is unchanged, but sum and count are rebuilt.
    // soi: indices of the training samples on this node
    public void construct(FeatureHistogram parent, int[] soi, double[] labels)
    {
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = 0;
        sqSumResponse = 0;
        sum = new double[features.length][];
        count = new int[features.length][];
        sampleToThresholdMap = parent.sampleToThresholdMap;

        // single- or multi-threaded
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, soi, labels, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, soi, labels), features.length);
    }

    // construct (2-2): called by (2-1)
    protected void construct(FeatureHistogram parent, int[] soi, double[] labels, int start, int end)
    {
        //init
        for(int i=start;i<=end;i++)
        {
            float[] threshold = thresholds[i];
            sum[i] = new double[threshold.length];
            count[i] = new int[threshold.length];
            Arrays.fill(sum[i], 0);
            Arrays.fill(count[i], 0);
        }

        //update
        for(int i=0;i<soi.length;i++)
        {
            int k = soi[i];
            for(int f=start;f<=end;f++)
            {
                int t = sampleToThresholdMap[f][k];
                sum[f][t] += labels[k];
                count[f][t] ++;
                if(f == 0)
                {
                    sumResponse += labels[k];
                    sqSumResponse += labels[k]*labels[k];
                }
            }
        }

        for(int f=start;f<=end;f++)
        {
            for(int t=1;t<thresholds[f].length;t++)
            {
                sum[f][t] += sum[f][t-1];
                count[f][t] += count[f][t-1];
            }
        }
    }

    // construct (3-1): usually builds the feature histogram of the RIGHT child produced by splitting a parent node
    public void construct(FeatureHistogram parent, FeatureHistogram leftSibling, boolean reuseParent)
    {
        this.reuseParent = reuseParent;
        this.features = parent.features;
        this.thresholds = parent.thresholds;
        sumResponse = parent.sumResponse - leftSibling.sumResponse;
        sqSumResponse = parent.sqSumResponse - leftSibling.sqSumResponse;

        if(reuseParent)
        {
            sum = parent.sum;
            count = parent.count;
        }
        else
        {
            sum = new double[features.length][];
            count = new int[features.length][];
        }
        sampleToThresholdMap = parent.sampleToThresholdMap;

        // single- or multi-threaded
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            construct(parent, leftSibling, 0, features.length-1);
        else
            p.execute(new Worker(this, parent, leftSibling), features.length);
    }

    // construct (3-2): called by (3-1)
    protected void construct(FeatureHistogram parent, FeatureHistogram leftSibling, int start, int end)
    {
        for(int f=start;f<=end;f++)
        {
            float[] threshold = thresholds[f];
            if(!reuseParent)
            {
                sum[f] = new double[threshold.length];
                count[f] = new int[threshold.length];
            }
            for(int t=0;t<threshold.length;t++)
            {
                sum[f][t] = parent.sum[f][t] - leftSibling.sum[f][t];
                count[f][t] = parent.count[f][t] - leftSibling.count[f][t];
            }
        }
    }
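    // The two histogram constructors above implement the usual histogram-subtraction trick: the left
    // child is built by scanning only its own samples (soi), while the right child is derived as
    // "parent minus left sibling" without touching a single sample, roughly halving the cost of each split.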

    // findBestSplit (1-2): called by (1-1). On this node, choose the split feature and threshold among usedFeatures from the node's feature histogram
    protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end)
    {
        Config cfg = new Config();
        int totalCount = count[start][count[start].length-1];
        for(int f=start;f<=end;f++)
        {
            int i = usedFeatures[f];
            float[] threshold = thresholds[i];

            for(int t=0;t<threshold.length;t++)
            {
                int countLeft = count[i][t];
                int countRight = totalCount - countLeft;
                if(countLeft < minLeafSupport || countRight < minLeafSupport)
                    continue;

                double sumLeft = sum[i][t];
                double sumRight = sumResponse - sumLeft;

                double S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight;
                // take the maximal S: it corresponds to the minimal squared error, i.e. the best split point
                if(cfg.S < S)
                {
                    cfg.S = S;
                    cfg.featureIdx = i;
                    cfg.thresholdIdx = t;
                }
            }
        }
        return cfg;
    }

    // findBestSplit (1-1): on top of the node's precomputed feature histogram, find the best split point, split,
    // compute the feature histograms of the left and right children, and initialize both children
    public boolean findBestSplit(Split sp, double[] labels, int minLeafSupport)
    {
        if(sp.getDeviance() >= 0.0 && sp.getDeviance() <= 0.0)//equals 0
            return false;//no need to split

        int[] usedFeatures = null;//index of the features to be used for tree splitting
        if(samplingRate < 1)//need to do sub sampling (feature sampling)
        {
            int size = (int)(samplingRate * features.length);
            usedFeatures = new int[size];
            //put all features into a pool
            List<Integer> fpool = new ArrayList<Integer>();
            for(int i=0;i<features.length;i++)
                fpool.add(i);
            //do sampling, without replacement
            Random r = new Random();
            for(int i=0;i<size;i++)
            {
                int sel = r.nextInt(fpool.size());
                usedFeatures[i] = fpool.get(sel);
                fpool.remove(sel);
            }
        }
        else//no sub-sampling, all features will be used
        {
            usedFeatures = new int[features.length];
            for(int i=0;i<features.length;i++)
                usedFeatures[i] = i;
        }

        //find the best split
        Config best = new Config();
        // single- or multi-threaded
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)
            best = findBestSplit(usedFeatures, minLeafSupport, 0, usedFeatures.length-1);
        else
        {
            WorkerThread[] workers = p.execute(new Worker(this, usedFeatures, minLeafSupport), usedFeatures.length);
            for(int i=0;i<workers.length;i++)
            {
                Worker wk = (Worker)workers[i];
                if(best.S < wk.cfg.S)
                    best = wk.cfg;
            }
        }

        if(best.S == -1)//unsplitable, for some reason...
            return false;

        //if(minS >= sp.getDeviance())
            //return null;

        double[] sumLabel = sum[best.featureIdx];
        int[] sampleCount = count[best.featureIdx];

        double s = sumLabel[sumLabel.length-1];
        int c = sampleCount[sumLabel.length-1];

        double sumLeft = sumLabel[best.thresholdIdx];
        int countLeft = sampleCount[best.thresholdIdx];

        double sumRight = s - sumLeft;
        int countRight = c - countLeft;

        int[] left = new int[countLeft];
        int[] right = new int[countRight];
        int l = 0;
        int r = 0;
        int k = 0;
        int[] idx = sp.getSamples();
        // route every training sample on this node to the left or right child according to the best split point
        for(int j=0;j<idx.length;j++)
        {
            k = idx[j];
            if(sampleToThresholdMap[best.featureIdx][k] <= best.thresholdIdx)//go to the left
                left[l++] = k;
            else//go to the right
                right[r++] = k;
        }

        // build the feature histograms of the two children
        FeatureHistogram lh = new FeatureHistogram();
        lh.construct(sp.hist, left, labels); // histogram of the left child
        FeatureHistogram rh = new FeatureHistogram();
        rh.construct(sp.hist, lh, !sp.isRoot()); // histogram of the right child (parent minus left sibling)
        double var = sqSumResponse - sumResponse * sumResponse / idx.length; // squared error of this node
        double varLeft = lh.sqSumResponse - lh.sumResponse * lh.sumResponse / left.length; // squared error of the left child
        double varRight = rh.sqSumResponse - rh.sumResponse * rh.sumResponse / right.length; // squared error of the right child

        sp.set(features[best.featureIdx], thresholds[best.featureIdx][best.thresholdIdx], var);
        sp.setLeft(new Split(left, lh, varLeft, sumLeft));
        sp.setRight(new Split(right, rh, varRight, sumRight));

        sp.clearSamples(); // release this node's sortedSampleIDs, samples, hist, etc.

        return true;
    }
    class Worker extends WorkerThread {
        FeatureHistogram fh = null;
        int type = -1;

        //find best split (type 0)
        int[] usedFeatures = null;
        int minLeafSup = -1;
        Config cfg = null;

        //update (type 1)
        double[] labels = null;

        //construct (type 2)
        FeatureHistogram parent = null;
        int[] soi = null;

        //construct (type 3)
        FeatureHistogram leftSibling = null;

        //construct (type 4)
        DataPoint[] samples;
        int[][] sampleSortedIdx;
        float[][] thresholds;

        public Worker()
        {
        }
        public Worker(FeatureHistogram fh, int[] usedFeatures, int minLeafSup)
        {
            type = 0;
            this.fh = fh;
            this.usedFeatures = usedFeatures;
            this.minLeafSup = minLeafSup;
        }
        public Worker(FeatureHistogram fh, double[] labels)
        {
            type = 1;
            this.fh = fh;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, int[] soi, double[] labels)
        {
            type = 2;
            this.fh = fh;
            this.parent = parent;
            this.soi = soi;
            this.labels = labels;
        }
        public Worker(FeatureHistogram fh, FeatureHistogram parent, FeatureHistogram leftSibling)
        {
            type = 3;
            this.fh = fh;
            this.parent = parent;
            this.leftSibling = leftSibling;
        }
        public Worker(FeatureHistogram fh, DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, float[][] thresholds)
        {
            type = 4;
            this.fh = fh;
            this.samples = samples;
            this.labels = labels;
            this.sampleSortedIdx = sampleSortedIdx;
            this.thresholds = thresholds;
        }
        public void run()
        {
            if(type == 0)
                cfg = fh.findBestSplit(usedFeatures, minLeafSup, start, end);
            else if(type == 1)
                fh.update(labels, start, end);
            else if(type == 2)
                fh.construct(parent, soi, labels, start, end);
            else if(type == 3)
                fh.construct(parent, leftSibling, start, end);
            else if(type == 4)
                fh.construct(samples, labels, sampleSortedIdx, thresholds, start, end);
        }
        public WorkerThread clone()
        {
            Worker wk = new Worker();
            wk.fh = fh;
            wk.type = type;

            //find best split (type 0)
            wk.usedFeatures = usedFeatures;
            wk.minLeafSup = minLeafSup;
            //wk.cfg = cfg;

            //update (type 1)
            wk.labels = labels;

            //construct (type 2)
            wk.parent = parent;
            wk.soi = soi;

            //construct (type 3)
            wk.leftSibling = leftSibling;

            //construct (type 4)
            wk.samples = samples;
            wk.sampleSortedIdx = sampleSortedIdx;
            wk.thresholds = thresholds;

            return wk;
        }
    }
}

2. Split

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 * @author vdang
 */
// Tree node class, used for:
//  1) split decisions during training (via the FeatureHistogram class)
//  2) storing the node's split rule (featureID, threshold) and its output (avgLabel), deviance, etc.
public class Split {
    //Key attributes of a split (tree node):
    //the split rule (featureID, threshold) and the node's output (avgLabel)
    private int featureID = -1;
    private float threshold = 0F;
    private double avgLabel = 0.0F;

    //Intermediate variables (ONLY used during learning)
    //*DO NOT* attempt to access them once the training is done
    private boolean isRoot = false;
    private double sumLabel = 0.0;
    private double sqSumLabel = 0.0;
    private Split left = null;
    private Split right = null;
    private double deviance = 0F;//mean squared error "S"
    private int[][] sortedSampleIDs = null;
    public int[] samples = null;// indices of the training samples on this node (training time only)
    public FeatureHistogram hist = null;// feature histogram of the training samples on this node (training time only)

    public Split()
    {
    }
    public Split(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public Split(int[][] sortedSampleIDs, double deviance, double sumLabel, double sqSumLabel)
    {
        this.sortedSampleIDs = sortedSampleIDs;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        this.sqSumLabel = sqSumLabel;
        avgLabel = sumLabel/sortedSampleIDs[0].length;
    }
    public Split(int[] samples, FeatureHistogram hist, double deviance, double sumLabel)
    {
        this.samples = samples;
        this.hist = hist;
        this.deviance = deviance;
        this.sumLabel = sumLabel;
        avgLabel = sumLabel/samples.length;
    }

    // Called after a successful split to record the split's featureID, threshold and deviance.
    // Only internal nodes are ever split and call this method, so only internal nodes have featureID != -1;
    // leaves never call it, hence featureID == -1
    public void set(int featureID, float threshold, double deviance)
    {
        this.featureID = featureID;
        this.threshold = threshold;
        this.deviance = deviance;
    }
    public void setLeft(Split s)
    {
        left = s;
    }
    public void setRight(Split s)
    {
        right = s;
    }
    public void setOutput(float output)
    {
        avgLabel = output;
    }

    public Split getLeft()
    {
        return left;
    }
    public Split getRight()
    {
        return right;
    }
    public double getDeviance()
    {
        return deviance;
    }
    public double getOutput()
    {
        return avgLabel;
    }

    // Return the list of all leaves under this node (usually the root).
    // Recursive: a leaf (featureID == -1) is added to the list; otherwise recurse via leaves(list)
    public List<Split> leaves()
100 {
101 ListSplit list new ArrayListSplit();
102 leaves(list);
103 return list;
104 }
105 private void leaves(ListSplit leaves)
106 {
107 if(featureID -1)
108 leaves.add(this);
109 else
110 {
111 left.leaves(leaves);
112 right.leaves(leaves);
113 }
114 }
115
116 //得到一个DataPoint在此节点(一般是根节点)下的最终落入每层都按照分裂规则进入下一层的叶子节点的输出值avgLabel值
117 public double eval(DataPoint dp)
118 {
119 Split n this;
120 while(n.featureID ! -1)
121 {
122 if(dp.getFeatureValue(n.featureID) n.threshold)
123 n n.left;
124 else
125 n n.right;
126 }
127 return n.avgLabel;
128 }
129
130 public String toString()
131 {
132 return toString();
133 }
134 public String toString(String indent)
135 {
136 String strOutput indent split \n;
137 strOutput getString(indent \t);
138 strOutput indent /split \n;
139 return strOutput;
140 }
141 public String getString(String indent)
142 {
143 String strOutput ;
144 if(featureID -1)
145 {
146 strOutput indent output avgLabel /output \n;
147 }
148 else
149 {
150 strOutput indent feature featureID /feature \n;
151 strOutput indent threshold threshold /threshold \n;
152 strOutput indent split pos\left\ \n;
153 strOutput left.getString(indent \t);
154 strOutput indent /split \n;
155 strOutput indent split pos\right\ \n;
156 strOutput right.getString(indent \t);
157 strOutput indent /split \n;
158 }
159 return strOutput;
160 }
161 //Internal functions(ONLY used during learning)
162 //*DO NOT* attempt to call them once the training is done
163 //*重要*训练时候在该节点上进行分裂调用了该节点的特征统计直方图对象的方法findBestSplit
164 public boolean split(double[] trainingLabels, int minLeafSupport)
165 {
166 return hist.findBestSplit(this, trainingLabels, minLeafSupport);
167 }
168 public int[] getSamples()
169 {
170 if(sortedSampleIDs ! null)
171 return sortedSampleIDs[0];
172 return samples;
173 }
174 public int[][] getSampleSortedIndex()
175 {
176 return sortedSampleIDs;
177 }
178 public double getSumLabel()
179 {
180 return sumLabel;
181 }
182 public double getSqSumLabel()
183 {
184 return sqSumLabel;
185 }
186 public void clearSamples()
187 {
188 sortedSampleIDs null;
189 samples null;
190 hist null;
191 }
192 public void setRoot(boolean isRoot)
193 {
194 this.isRoot isRoot;
195 }
196 public boolean isRoot()
197 {
198 return isRoot;
199 }
}

3. RegressionTree

package ciir.umass.edu.learning.tree;

import java.util.ArrayList;
import java.util.List;

import ciir.umass.edu.learning.DataPoint;

/**
 * @author vdang
 */
// Regression tree
public class RegressionTree {

    //Parameters
    protected int nodes = 10;//-1 for unlimited number of nodes (the size of the tree will then be controlled *ONLY* by minLeafSupport)
    protected int minLeafSupport = 1; // controls splitting: a node holding fewer than 2*minLeafSupport training samples is not split further

    //Member variables and functions
    protected Split root = null; // root node
    protected List<Split> leaves = null; // list of leaf nodes

    protected DataPoint[] trainingSamples = null;
    protected double[] trainingLabels = null;
    protected int[] features = null;
    protected float[][] thresholds = null; // [feature][t]: first dimension is the feature (its position in features, not the feature id); second dimension holds the sorted distinct values of that feature over the training data — the candidate split values
    protected int[] index = null;
    protected FeatureHistogram hist = null;

    public RegressionTree(Split root)
    {
        this.root = root;
        leaves = root.leaves();
    }
    public RegressionTree(int nLeaves, DataPoint[] trainingSamples, double[] labels, FeatureHistogram hist, int minLeafSupport)
    {
        this.nodes = nLeaves;
        this.trainingSamples = trainingSamples;
        this.trainingLabels = labels;
        this.hist = hist;
        this.minLeafSupport = minLeafSupport;
        index = new int[trainingSamples.length];
        for(int i=0;i<trainingSamples.length;i++)
            index[i] = i;
    }

    /**
     * Fit the tree from the specified training data
     */
    public void fit()
    {
        List<Split> queue = new ArrayList<Split>(); // nodes are split in queue order (i.e. level order)
        root = new Split(index, hist, Float.MAX_VALUE, 0); // root node of the regression tree
        root.setRoot(true);
        root.split(trainingLabels, minLeafSupport); // split the root once, producing two children
        insert(queue, root.getLeft()); // enqueue the left child for later traversal
        insert(queue, root.getRight()); // enqueue the right child for later traversal
        // keep splitting in queue order (level order); each successful split enqueues its two children
        int taken = 0;
        while( (nodes == -1 || taken + queue.size() < nodes) && queue.size() > 0)
        {
            Split leaf = queue.get(0);
            queue.remove(0);

            if(leaf.getSamples().length < 2 * minLeafSupport)
            {
                taken++;
                continue;
            }

            if(!leaf.split(trainingLabels, minLeafSupport))//unsplitable (i.e. variance(s)=0; or after-split variance is higher than before); each dequeued node is split once, producing two children
                taken++;
            else
            {
                insert(queue, leaf.getLeft()); // enqueue the left child for later traversal
                insert(queue, leaf.getRight()); // enqueue the right child for later traversal
            }
        }
        leaves = root.leaves();
    }

    /**
     * Get the tree output for the input sample
     * @param dp
     * @return
     */
    public double eval(DataPoint dp)
    {
        return root.eval(dp);
    }
    /**
     * Retrieve all leaf nodes in the tree
     * @return
     */
    public List<Split> leaves()
    {
        return leaves;
    }
    /**
     * Clear samples associated with each leaf (when they are no longer necessary) in order to save memory
     */
    public void clearSamples()
    {
        trainingSamples = null;
        trainingLabels = null;
        features = null;
        thresholds = null;
        index = null;
        hist = null;
        for(int i=0;i<leaves.size();i++)
            leaves.get(i).clearSamples();
    }

    /**
     * Generate the string representation of the tree
     */
    public String toString()
    {
        if(root != null)
            return root.toString();
        return "";
    }
    public String toString(String indent)
    {
        if(root != null)
            return root.toString(indent);
        return "";
    }

    public double variance()
    {
        double var = 0;
        for(int i=0;i<leaves.size();i++)
            var += leaves.get(i).getDeviance();
        return var;
    }
    protected void insert(List<Split> ls, Split s)
    {
        int i=0;
        while(i < ls.size())
        {
            if(ls.get(i).getDeviance() > s.getDeviance()) // keep the queue ordered by squared error, largest first
                i++;
            else
                break;
        }
        ls.add(i, s);
    }
}

2. LambdaMart

The LambdaMART training process can be summarized as follows:

LambdaMart
    init
        initialize the training-data arrays: martSamples, modelScores, pseudoResponses, weights
        sort the samples by each feature (sortedIdx), so that the best split can be found quickly when fitting trees
        initialize the 2-D thresholds array: the first dimension is the feature (its position in features, not the feature id); the second dimension holds the sorted distinct values of that feature over the training data — the candidate split values for that feature
        hist.construct    # build a FeatureHistogram over the whole training set from the data and the thresholds array; it is used for splitting at the root
            initializes:
            sum              # [feature][t]: sum of the labels of all samples whose value on this feature is <= thresholds[i][t]; same shape as thresholds
            count            # [feature][t]: number of samples whose value on this feature is <= thresholds[i][t]; same shape as thresholds
            sampleToThresholdMap    # [feature][sample]: the column in the corresponding thresholds row that each sample's value on this feature maps to
            sumResponse      # sum of the labels of all training samples
            sqSumResponse    # sum of the squared labels of all training samples
    learn
        initialize an Ensemble object, ensemble
        start the gradient boosting process, i.e. fit one regression tree per iteration:
            computePseudoResponses    # compute the pseudo responses (the gradients, i.e. the lambdas) each instance must fit in this iteration
                computed with the LambdaMART gradient formula (restated as a comment above computePseudoResponses in the code below)
            hist.update      # update the feature histogram with this iteration's pseudo responses (lambdas); only the per-instance labels changed, everything else (e.g. the features) did not
            initialize a regression tree from the training data and the feature histogram
            rt.fit           # fit the regression tree to this iteration's pseudo responses (lambdas)
            add the tree fitted in this iteration to the ensemble
            updateTreeOutput # update the outputs of the leaves of the tree fitted in this iteration
            update modelScores, the predictions on the training instances after the new tree joined the ensemble
            computeModelScoreOnTraining     # ranking metric (e.g. NDCG) of the updated model over the whole training data
            update modelScoresOnValidation, the predictions on the validation instances after the new tree joined the ensemble
            computeModelScoreOnValidation   # ranking metric (e.g. NDCG) of the updated model over the whole validation data
            update the best validation score seen so far (bestScoreOnValidationData) and the index of the best model (bestModelOnValidation)
            if the validation score has not improved for a given number of consecutive iterations, stop
        roll back to the best model observed on the validation data
        compute the ranking metric of the best model on the training data and the validation data

Below is the code of the LambdaMART training process; the key parts carry detailed comments.

1. LambdaMART

package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;

import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * @author vdang
 *
 * This class implements LambdaMART.
 * Q. Wu, C.J.C. Burges, K. Svore and J. Gao. Adapting Boosting for Information Retrieval Measures.
 * Journal of Information Retrieval, 2007.
 */
public class LambdaMART extends Ranker {
    //Parameters
    public static int nTrees = 1000;//the number of trees
    public static float learningRate = 0.1F;//or shrinkage
    public static int nThreshold = 256;
    public static int nRoundToStopEarly = 100;//If no performance gain on the *VALIDATION* data is observed in #rounds, stop the training process right away.
    public static int nTreeLeaves = 10;
    public static int minLeafSupport = 1;

    //for debugging
    public static int gcCycle = 100;

    //Local variables
    protected float[][] thresholds = null;
    protected Ensemble ensemble = null;
    protected double[] modelScores = null;//on training data

    protected double[][] modelScoresOnValidation = null;
    protected int bestModelOnValidation = Integer.MAX_VALUE-2;

    //Training instances prepared for MART
    protected DataPoint[] martSamples = null;//Need initializing only once
    protected int[][] sortedIdx = null;//sorted list of samples in martSamples by each feature -- Need initializing only once
    protected FeatureHistogram hist = null;
    protected double[] pseudoResponses = null;//different for each iteration
    protected double[] weights = null;//different for each iteration

    public LambdaMART()
    {
    }
    public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer)
    {
        super(samples, features, scorer);
    }

    public void init()
    {
        PRINT("Initializing... ");
        //initialize samples for MART
        int dpCount = 0;
        for(int i=0;i<samples.size();i++)
        {
            RankList rl = samples.get(i);
            dpCount += rl.size();
        }
        int current = 0;
        martSamples = new DataPoint[dpCount];
        modelScores = new double[dpCount];
        pseudoResponses = new double[dpCount];
        weights = new double[dpCount];
        for(int i=0;i<samples.size();i++)
        {
            RankList rl = samples.get(i);
            for(int j=0;j<rl.size();j++)
            {
                martSamples[current+j] = rl.get(j);
                modelScores[current+j] = 0.0F;
                pseudoResponses[current+j] = 0.0F;
                weights[current+j] = 0;
            }
            current += rl.size();
        }

        //sort (MART) samples by each feature so that we can quickly retrieve a sorted list of samples by any feature later on
        sortedIdx = new int[features.length][];
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            sortSamplesByFeature(0, features.length-1);
        else//multi-thread
        {
            int[] partition = p.partition(features.length);
            for(int i=0;i<partition.length-1;i++)
                p.execute(new SortWorker(this, partition[i], partition[i+1]-1));
            p.await();
        }

        //Create a table of candidate thresholds (for each feature). Later on, we will select the best tree split from these candidates
        thresholds = new float[features.length][];
        for(int f=0;f<features.length;f++)
        {
            //For this feature, keep track of the list of unique values and the max/min
            List<Float> values = new ArrayList<Float>();
            float fmax = Float.NEGATIVE_INFINITY;
            float fmin = Float.MAX_VALUE;
            for(int i=0;i<martSamples.length;i++)
            {
                int k = sortedIdx[f][i];//get samples sorted with respect to this feature
                float fv = martSamples[k].getFeatureValue(features[f]);
                values.add(fv);
                if(fmax < fv)
                    fmax = fv;
                if(fmin > fv)
                    fmin = fv;
                //skip all samples with the same feature value
                int j=i+1;
                while(j < martSamples.length)
                {
                    if(martSamples[sortedIdx[f][j]].getFeatureValue(features[f]) > fv)
                        break;
                    j++;
                }
                i = j-1;//[i, j] gives the range of samples with the same feature value
            }

            if(values.size() <= nThreshold || nThreshold == -1)
            {
                thresholds[f] = new float[values.size()+1];
                for(int i=0;i<values.size();i++)
                    thresholds[f][i] = values.get(i);
                thresholds[f][values.size()] = Float.MAX_VALUE;
            }
            else
            {
                float step = (Math.abs(fmax - fmin))/nThreshold;
                thresholds[f] = new float[nThreshold+1];
                thresholds[f][0] = fmin;
                for(int j=1;j<nThreshold;j++)
                    thresholds[f][j] = thresholds[f][j-1] + step;
                thresholds[f][nThreshold] = Float.MAX_VALUE;
            }
        }
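        // A made-up example of the table built above (not from any real dataset): if a feature takes the
        // values {0.1, 0.3, 0.3, 0.7} over four training samples and nThreshold >= 4 (or -1), then
        //     thresholds[f] = {0.1, 0.3, 0.7, Float.MAX_VALUE}
        // and after hist.construct() the cumulative tables read count[f] = {1, 3, 4, 4}, with sum[f][t]
        // the label sum of the samples counted in count[f][t]. The trailing Float.MAX_VALUE column means
        // "send everything left", so the histogram always has a column holding the totals of the whole node.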

        if(validationSamples != null)
        {
            modelScoresOnValidation = new double[validationSamples.size()][];
            for(int i=0;i<validationSamples.size();i++)
            {
                modelScoresOnValidation[i] = new double[validationSamples.get(i).size()];
                Arrays.fill(modelScoresOnValidation[i], 0);
            }
        }

        //compute the feature histogram (this is used to speed up the procedure of finding the best tree split later on)
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
        //we no longer need the sorted indexes of samples
        sortedIdx = null;

        System.gc();
        PRINTLN("[Done]");
    }
    public void learn()
    {
        ensemble = new Ensemble();

        PRINTLN("---------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------------");
        PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name()+"-T", scorer.name()+"-V"});
        PRINTLN("---------------------------------");

        //Start the gradient boosting process
        for(int m=0; m<nTrees; m++)
        {
            PRINT(new int[]{7}, new String[]{(m+1)+""});

            //Compute lambdas (which act as the pseudo responses)
            //Create training instances for MART:
            //  - Each document is a training sample
            //  - The lambda for this document serves as its training label
            computePseudoResponses();

            //update the histogram with these training labels (the feature histogram will be used to find the best tree split)
            hist.update(pseudoResponses);

            //Fit a regression tree to the current pseudo responses
            RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();

            //Add this tree to the ensemble (our model)
            ensemble.add(rt, learningRate);

            //update the outputs of the tree (with gamma computed using the Newton-Raphson method)
            updateTreeOutput(rt);

            //Update the model's outputs on all training samples
            List<Split> leaves = rt.leaves();
            for(int i=0;i<leaves.size();i++)
            {
                Split s = leaves.get(i);
                int[] idx = s.getSamples();
                for(int j=0;j<idx.length;j++)
                    modelScores[idx[j]] += learningRate * s.getOutput();
            }
            //clear references to data that is no longer used
            rt.clearSamples();

            //beg the garbage collector to work...
            if(m % gcCycle == 0)
                System.gc();//this call is expensive. We shouldn't do it too often.

            //Evaluate the current model
            scoreOnTrainingData = computeModelScoreOnTraining();
            //**** NOTE ****
            //The above function to evaluate the current model on the training data is equivalent to a single call:
            //
            //        scoreOnTrainingData = scorer.score(rank(samples));
            //
            //However, this function is more efficient since it uses the cached outputs of the model (as opposed to re-evaluating the model
            //on the entire training set).

            PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});

            //Evaluate the current model on the validation data (if available)
            if(validationSamples != null)
            {
                //Update the model's scores on all validation samples
                for(int i=0;i<modelScoresOnValidation.length;i++)
                    for(int j=0;j<modelScoresOnValidation[i].length;j++)
                        modelScoresOnValidation[i][j] += learningRate * rt.eval(validationSamples.get(i).get(j));

                //again, equivalent to scoreOnValidation=scorer.score(rank(validationSamples)), but more efficient since we use the cached model's outputs
                double score = computeModelScoreOnValidation();

                PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                if(score > bestScoreOnValidationData)
                {
                    bestScoreOnValidationData = score;
                    bestModelOnValidation = ensemble.treeCount()-1;
                }
            }

            PRINTLN("");

            //Should we stop early?
            if(m - bestModelOnValidation > nRoundToStopEarly)
                break;
        }

        //Rollback to the best model observed on the validation data
        while(ensemble.treeCount() > bestModelOnValidation+1)
            ensemble.remove(ensemble.treeCount()-1);

        //Finishing up
        scoreOnTrainingData = scorer.score(rank(samples));
        PRINTLN("---------------------------------");
        PRINTLN("Finished successfully.");
        PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if(validationSamples != null)
        {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }
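    // In formulas, learn() is the standard gradient boosting update: F_0(x) = 0 and
    // F_m(x) = F_{m-1}(x) + learningRate * gamma_m(x), where gamma_m(x) is the output of the leaf of
    // tree m that x falls into, set by updateTreeOutput() via a one-step Newton estimate.
    // modelScores[] caches F_m over the training set, so each iteration only adds the new tree's
    // shrunken leaf outputs instead of re-evaluating the entire ensemble.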
    public double eval(DataPoint dp)
    {
        return ensemble.eval(dp);
    }
    public Ranker createNew()
    {
        return new LambdaMART();
    }
    public String toString()
    {
        return ensemble.toString();
    }
    public String model()
    {
        String output = "## " + name() + "\n";
        output += "## No. of trees = " + nTrees + "\n";
        output += "## No. of leaves = " + nTreeLeaves + "\n";
        output += "## No. of threshold candidates = " + nThreshold + "\n";
        output += "## Learning rate = " + learningRate + "\n";
        output += "## Stop early = " + nRoundToStopEarly + "\n";
        output += "\n";
        output += toString();
        return output;
    }
    @Override
    public void loadFromString(String fullText)
    {
        try {
            String content = "";
            //String model = "";
            StringBuffer model = new StringBuffer();
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            while((content = in.readLine()) != null)
            {
                content = content.trim();
                if(content.length() == 0)
                    continue;
                if(content.indexOf("##")==0)
                    continue;
                //actual model component
                //model += content;
                model.append(content);
            }
            in.close();
            //load the ensemble
            ensemble = new Ensemble(model.toString());
            features = ensemble.getFeatures();
        }
        catch(Exception ex)
        {
            throw RankLibError.create("Error in LambdaMART::load(): ", ex);
        }
    }
    public void printParameters()
    {
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Min leaf support: " + minLeafSupport);
        PRINTLN("Learning rate: " + learningRate);
        PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
    }
    public String name()
    {
        return "LambdaMART";
    }
    public Ensemble getEnsemble()
    {
        return ensemble;
    }

    protected void computePseudoResponses()
    {
        Arrays.fill(pseudoResponses, 0F);
        Arrays.fill(weights, 0);
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            computePseudoResponses(0, samples.size()-1, 0);
        else //multi-threading
        {
            List<LambdaComputationWorker> workers = new ArrayList<LambdaMART.LambdaComputationWorker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            int current = 0;
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                LambdaComputationWorker wk = new LambdaComputationWorker(this, partition[i], partition[i+1]-1, current);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);

                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }

            //wait for all workers to complete before we move on to the next stage
            p.await();
        }
    }
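    // The pairwise update implemented below follows the LambdaMART gradient (with sigma = 1 in this code):
    // for two documents i, j of the same query with label(i) > label(j),
    //     rho      = 1 / (1 + exp(s_i - s_j))            (s = current model scores)
    //     lambda_i += rho * |deltaNDCG(i,j)|,   lambda_j -= rho * |deltaNDCG(i,j)|
    //     w_i      += rho * (1 - rho) * |deltaNDCG(i,j)| (second-order term, consumed by updateTreeOutput)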
    protected void computePseudoResponses(int start, int end, int current)
    {
        int cutoff = scorer.getK();
        //compute the lambda for each document (a.k.a "pseudo response")
        for(int i=start;i<=end;i++)
        {
            RankList orig = samples.get(i);
            int[] idx = MergeSorter.sort(modelScores, current, current+orig.size()-1, false);
            RankList rl = new RankList(orig, idx, current);
            double[][] changes = scorer.swapChange(rl);
            //NOTE: j, k are indices in the sorted (by modelScore) list, not the original
            //      need to map back with idx[j] and idx[k]
            for(int j=0;j<rl.size();j++)
            {
                DataPoint p1 = rl.get(j);
                int mj = idx[j];
                for(int k=0;k<rl.size();k++)
                {
                    if(j > cutoff && k > cutoff)//swapping this pair won't result in any change in target measures since they're below the cut-off point
                        break;
                    DataPoint p2 = rl.get(k);
                    int mk = idx[k];
                    if(p1.getLabel() > p2.getLabel())
                    {
                        double deltaNDCG = Math.abs(changes[j][k]);
                        if(deltaNDCG > 0)
                        {
                            double rho = 1.0 / (1 + Math.exp(modelScores[mj] - modelScores[mk]));
                            double lambda = rho * deltaNDCG;
                            pseudoResponses[mj] += lambda;
                            pseudoResponses[mk] -= lambda;
                            double delta = rho * (1.0 - rho) * deltaNDCG;
                            weights[mj] += delta;
                            weights[mk] += delta;
                        }
                    }
                }
            }
            current += orig.size();
        }
    }
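    // Each leaf output set below is the one-step Newton estimate gamma = sum(lambda_k) / sum(w_k) over
    // the samples falling into that leaf: the first-order numerator (pseudoResponses) and second-order
    // denominator (weights) were accumulated in computePseudoResponses() above. A leaf with zero total
    // weight outputs 0.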
    protected void updateTreeOutput(RegressionTree rt)
    {
        List<Split> leaves = rt.leaves();
        for(int i=0;i<leaves.size();i++)
        {
            float s1 = 0F;
            float s2 = 0F;
            Split s = leaves.get(i);
            int[] idx = s.getSamples();
            for(int j=0;j<idx.length;j++)
            {
                int k = idx[j];
                s1 += pseudoResponses[k];
                s2 += weights[k];
            }
            if(s2 == 0)
                s.setOutput(0);
            else
                s.setOutput(s1/s2);
        }
    }
    protected int[] sortSamplesByFeature(DataPoint[] samples, int fid)
    {
        double[] score = new double[samples.length];
        for(int i=0;i<samples.length;i++)
            score[i] = samples[i].getFeatureValue(fid);
        int[] idx = MergeSorter.sort(score, true);
        return idx;
    }
    /**
     * This function is equivalent to the inherited function rank(...), but it uses the cached model's outputs instead of computing them from scratch.
     * @param rankListIndex
     * @param current
     * @return
     */
    protected RankList rank(int rankListIndex, int current)
    {
        RankList orig = samples.get(rankListIndex);
        double[] scores = new double[orig.size()];
        for(int i=0;i<scores.length;i++)
            scores[i] = modelScores[current+i];
        int[] idx = MergeSorter.sort(scores, false);
        return new RankList(orig, idx);
    }
    protected float computeModelScoreOnTraining()
    {
        /*float s = 0;
        int current = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            s = computeModelScoreOnTraining(0, samples.size()-1, current);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(samples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1, current);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);

                if(i < partition.length-2)
                    for(int j=partition[i]; j<=partition[i+1]-1;j++)
                        current += samples.get(j).size();
            }
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                s += workers.get(i).score;
        }*/
        float s = computeModelScoreOnTraining(0, samples.size()-1, 0);
        s = s / samples.size();
        return s;
    }
    protected float computeModelScoreOnTraining(int start, int end, int current)
    {
        float s = 0;
        int c = current;
        for(int i=start;i<=end;i++)
        {
            s += scorer.score(rank(i, c));
            c += samples.get(i).size();
        }
        return s;
    }
    protected float computeModelScoreOnValidation()
    {
        /*float score = 0;
        MyThreadPool p = MyThreadPool.getInstance();
        if(p.size() == 1)//single-thread
            score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        else
        {
            List<Worker> workers = new ArrayList<Worker>();
            //divide the entire dataset into chunks of equal size for each worker thread
            int[] partition = p.partition(validationSamples.size());
            for(int i=0;i<partition.length-1;i++)
            {
                //execute the worker
                Worker wk = new Worker(this, partition[i], partition[i+1]-1);
                workers.add(wk);//keep it so we can get back results from it later on
                p.execute(wk);
            }
            //wait for all workers to complete before we move on to the next stage
            p.await();
            for(int i=0;i<workers.size();i++)
                score += workers.get(i).score;
        }*/
        float score = computeModelScoreOnValidation(0, validationSamples.size()-1);
        return score/validationSamples.size();
    }
    protected float computeModelScoreOnValidation(int start, int end)
    {
        float score = 0;
        for(int i=start;i<=end;i++)
        {
            int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
            score += scorer.score(new RankList(validationSamples.get(i), idx));
        }
        return score;
    }

    protected void sortSamplesByFeature(int fStart, int fEnd)
    {
        for(int i=fStart;i<=fEnd; i++)
            sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
    }
    //For multi-threading processing
    class SortWorker implements Runnable {
        LambdaMART ranker = null;
        int start = -1;
        int end = -1;
        SortWorker(LambdaMART ranker, int start, int end)
        {
            this.ranker = ranker;
            this.start = start;
            this.end = end;
        }
        public void run()
        {
            ranker.sortSamplesByFeature(start, end);
        }
    }
    class LambdaComputationWorker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }
        public void run()
        {
            ranker.computePseudoResponses(rlStart, rlEnd, martStart);
        }
    }
    class Worker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        int type = -1;

        //compute score on validation
        float score = 0;

        Worker(LambdaMART ranker, int rlStart, int rlEnd)
        {
            type = 3;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
        }
        Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart)
        {
            type = 4;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }
        public void run()
        {
            if(type == 4)
                score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
            else if(type == 3)
                score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
        }
    }
}

Copyright notice: this article belongs to 笨兔勿应 and was published at http://www.cnblogs.com/bentuwuying. If you repost it, please credit the source; using it for commercial purposes without the author's consent will be pursued under the law.

Reposted from: https://www.cnblogs.com/bentuwuying/p/6701027.html