Java机器学习库ML之四模型训练和预测示例
生活随笔
收集整理的這篇文章主要介紹了
Java机器学习库ML之四模型训练和预测示例
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
基于ML庫機器學習的步驟:
1)樣本數據導入;
2)樣本數據特征抽取和特征值處理(結合模型需要歸一化或離散化);這里本文沒有做處理,特征選擇和特征值處理本身就很大;
3)樣本集劃分訓練集和驗證集;
4)根據訓練集訓練模型;
5)用驗證集評價模型;
6)導入測試集,并用模型預測輸出預測結果;
package com.vip;import java.io.File;import be.abeel.util.Pair; import net.sf.javaml.classification.Classifier; import net.sf.javaml.classification.KNearestNeighbors; import net.sf.javaml.core.Dataset; import net.sf.javaml.core.DefaultDataset; import net.sf.javaml.core.DenseInstance; import net.sf.javaml.core.Instance; import net.sf.javaml.featureselection.scoring.GainRatio; import net.sf.javaml.sampling.Sampling; import net.sf.javaml.tools.data.FileHandler;public class VIPClassifer {public static void main(String[] args)throws Exception {if (args.length != 2) {System.err.println("Usage: 輸入訓練集和測試集路徑");System.exit(2);}/* Load a data set 前面13列是訓練特征,最后1列標記*/Dataset ori_data = FileHandler.loadDataset(new File(args[0]), 13, "\\s+");//特征評分,可獨立//GainRatio ga = new GainRatio(); //ga.build(ori_data); /* Apply the algorithm to the data set */ //for (int i = 0; i < ga.noAttributes(); i++) // System.out.println(ga.score(i)); //抽樣訓練集和驗證集Sampling s = Sampling.SubSampling;Pair<Dataset, Dataset> sam_data = s.sample(ori_data, (int) (ori_data.size() * 0.8));/*Dataset train_data = new DefaultDataset();//80%訓練Dataset test_data = new DefaultDataset();//20%驗證int sample=0;for(Instance inst:ori_data){double[] values = new double[] { inst.value(5),inst.value(6),inst.value(7),inst.value(8), inst.value(9),inst.value(16),inst.value(17)};Instance train_inst = new DenseInstance(values, inst.classValue()); if(sample<4){sample++;train_data.add(train_inst);}else {sample=0;test_data.add(train_inst); } }*///Contruct a KNN classifier that uses 5 neighbors to make a decision.Classifier knn = new KNearestNeighbors(5);knn.buildClassifier(sam_data.x());//驗證集int correct = 0, wrong = 0;/* Classify all instances and check with the correct class values */for (Instance inst : sam_data.y()) {Object predictedClassValue = knn.classify(inst);Object realClassValue = inst.classValue();if (predictedClassValue.equals(realClassValue))correct++;elsewrong++;}System.out.println("Correct predictions " + correct);System.out.println("Wrong predictions " + wrong);//模型預測/* Load a data set 前面13列是訓練特征,最后2列是uid和spuid聯合標識*/Dataset pre_data = FileHandler.loadDataset(new File(args[1]),"\\s+");System.out.println(pre_data.instance(0));Dataset out_data = new DefaultDataset();for(Instance inst:pre_data){double[] values = new double[13]; for(int i=0;i<13;i++) values[i]=inst.value(i);Instance pre_inst = new DenseInstance(values); //無標記,13列特征參與訓練Object pre_classvalue = knn.classify(pre_inst);//預測結果//pre_inst.setClassValue(pre_classvalue);//標注預測結果double[] u_spu_id=new double[]{inst.value(13),inst.value(14)};Instance out_inst = new DenseInstance(u_spu_id,pre_classvalue); //帶標記out_data.add(out_inst);}//輸出u_Id+spu_id+action_typeFileHandler.exportDataset(out_data, new File("/data1/DataFountain/output.txt"));} } //java -XX:-UseGCOverheadLimit -Xmx10240m -jar vip.jar train_features_new.txt test_features_new.txt在上面這個代碼框架內,可以用不同模型,如SVM、RF(隨機森林)等,也可以對特征值進行處理后選擇特征來訓練。
總結
以上是生活随笔為你收集整理的Java机器学习库ML之四模型训练和预测示例的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Java数据结构Map遍历和排序
- 下一篇: Java机器学习库ML之五样本不均衡