Machine Learning Primer 04: Linear Regression Theory and a Multivariate Linear Regression Implementation in Java
Linear regression theory:
The linear regression formula is y = b + w*x, where w is the weight and b is the bias.
In an actual implementation the formula can be rewritten as y = w[0] * x[0] + w[1] * x[1] with x[0] = 1, which makes solving for the parameters convenient. With a small further change the formula becomes y = w[0] * x[0] + w[1] * x[1] + ... + w[n] * x[n], which is multivariate regression.
The parameters are optimized with gradient descent over many iterations; each pass computes the gradient of the error with respect to the parameters. The procedure has the following steps:
1. Predict with the current parameters on the training data:
        preY = sum(w[n] * x[n])
2. Compute the gradient:
        weight_gradient[n] = sum(2 * (preY - y) * x[n]) / N, where N is the number of training rows
3. Update the parameters:
        weight[n] = weight[n] - a * weight_gradient[n], where a is the learning rate, typically a small value in (0, 1], chosen according to the training data and how training behaves.
4. Iterate:
        Each iteration computes the gradient over the whole training set once and updates the parameters once; repeating this drives the error toward its minimum.
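The four steps above can be sketched on a toy example. The two data points, the learning rate, and the class and method names below are illustrative choices, not from the original post: two samples of y = 1 + 2x, starting weights of zero, and a single update step.

```java
public class GradientStepDemo {
    // One gradient-descent step for y = w[0] + w[1]*x, mirroring steps 1-3 above.
    // Each data row is {x, y}; returns the updated weight array.
    static double[] oneStep(double[][] data, double[] w, double lr) {
        double[] grad = new double[w.length];
        for (double[] row : data) {
            double[] x = {1.0, row[0]};              // x[0] = 1 for the bias term
            double preY = w[0] * x[0] + w[1] * x[1]; // step 1: predict
            for (int n = 0; n < w.length; n++) {     // step 2: accumulate gradient
                grad[n] += 2 * (preY - row[1]) * x[n] / data.length;
            }
        }
        double[] updated = new double[w.length];
        for (int n = 0; n < w.length; n++) {         // step 3: update
            updated[n] = w[n] - lr * grad[n];
        }
        return updated;
    }

    public static void main(String[] args) {
        // two points from y = 1 + 2x: (1, 3) and (2, 5)
        double[] w = oneStep(new double[][]{{1, 3}, {2, 5}}, new double[]{0, 0}, 0.1);
        System.out.println(w[0] + "," + w[1]);
    }
}
```

Working it by hand: both predictions start at 0, so the gradient is (-8, -13) and one step with a = 0.1 moves the weights from (0, 0) to (0.8, 1.3), already heading toward the true (1, 2).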
Linear regression implementation (in Java; the same code handles both simple and multivariate regression):
1. Read the data, stored as CSV: every column except the last holds a feature x, and the last column holds the target y.
```java
// imports used by the code in this post
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

public List<double[]> readTrainFile(String filepath) {
    File trainFile = new File(filepath);
    List<double[]> resultList = new ArrayList<double[]>();
    if (trainFile.exists()) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(trainFile));
            String line;
            while ((line = reader.readLine()) != null) {
                String[] strs = line.split(",");
                double[] lines = new double[strs.length];
                for (int i = 0; i < strs.length; i++) {
                    lines[i] = Double.parseDouble(strs[i]);
                }
                resultList.add(lines);
            }
            reader.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return resultList;
}
```

2. Train: set the learning rate and the number of iterations; returns the parameter array.
```java
public double[] train(String filepath, double learningRate, int iterationNum) {
    List<double[]> trainData = readTrainFile(filepath);
    double[] weights = new double[trainData.get(0).length];
    for (int i = 0; i < weights.length; i++) {
        weights[i] = 0;
    }
    weights = updateWeights(trainData, weights, learningRate, iterationNum);
    return weights;
}
```

3. Compute the weights: each iteration makes one pass over the data set, computes the gradient by gradient descent, and updates the weights by learning rate × gradient.
```java
public double[] updateWeights(List<double[]> trainData, double[] weights, double learningRate, int iterationNum) {
    for (int i = 0; i < iterationNum; i++) {
        double[] weights_gradient = new double[weights.length];
        for (int j = 0; j < trainData.size(); j++) {
            double[] line = trainData.get(j);
            double[] x = new double[line.length];
            x[0] = 1; // x[0] = 1 gives the bias term
            double y = line[line.length - 1];
            for (int n = 1; n < x.length; n++) {
                x[n] = line[n - 1];
            }
            // predict preY from the current parameters and this row
            double preY = 0.0;
            for (int n = 0; n < weights.length; n++) {
                preY += x[n] * weights[n];
            }
            for (int n = 0; n < weights.length; n++) {
                weights_gradient[n] += 2 * (preY - y) * x[n] / (double) trainData.size();
            }
        }
        // update the parameters
        for (int j = 0; j < weights.length; j++) {
            weights[j] = weights[j] - learningRate * weights_gradient[j];
        }
        // print the loss every 100 iterations
        if (i % 100 == 0) {
            double loss = computeError(trainData, weights);
            System.out.println(loss);
        }
    }
    return weights;
}
```

4. Compute the error (mean squared error):
```java
public double computeError(List<double[]> trainData, double[] weights) {
    double error = 0.0;
    for (int i = 0; i < trainData.size(); i++) {
        double[] line = trainData.get(i);
        double preY = 0.0;
        double[] x = new double[line.length];
        x[0] = 1;
        double y = line[line.length - 1];
        for (int n = 1; n < x.length; n++) {
            x[n] = line[n - 1];
        }
        for (int j = 0; j < line.length; j++) {
            preY += weights[j] * x[j];
        }
        error += (y - preY) * (y - preY);
    }
    return error / (double) trainData.size();
}
```

5. Test program: print the learned parameters.
```java
public static void main(String[] args) {
    MultiLineRegression lineRegression = new MultiLineRegression();
    double[] weights = lineRegression.train("E:/index/traindata.csv", 0.001, 1000);
    for (double w : weights) {
        System.out.print(w + ",");
    }
}
```

Summary: the same gradient-descent routine covers both simple and multivariate linear regression; only the number of feature columns in the training CSV changes.
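The training loop can be checked without a CSV file by running it on synthetic in-memory data. The sketch below is an illustrative variant, not part of the original post: `fit` repeats the same logic as `updateWeights` above on a `List<double[]>` built from y = 1 + 2x, and the learned bias and slope should come out close to 1 and 2.

```java
import java.util.ArrayList;
import java.util.List;

public class RegressionSanityCheck {
    // In-memory variant of updateWeights: rows are {x, y}, weights[0] is the bias.
    static double[] fit(List<double[]> data, double lr, int iters) {
        double[] w = new double[data.get(0).length];
        for (int i = 0; i < iters; i++) {
            double[] grad = new double[w.length];
            for (double[] line : data) {
                double[] x = new double[line.length];
                x[0] = 1; // bias term
                for (int n = 1; n < x.length; n++) x[n] = line[n - 1];
                double y = line[line.length - 1];
                double preY = 0.0;
                for (int n = 0; n < w.length; n++) preY += w[n] * x[n];
                for (int n = 0; n < w.length; n++)
                    grad[n] += 2 * (preY - y) * x[n] / data.size();
            }
            for (int n = 0; n < w.length; n++) w[n] -= lr * grad[n];
        }
        return w;
    }

    public static void main(String[] args) {
        // nine noise-free points from y = 1 + 2x on [0, 4]
        List<double[]> data = new ArrayList<>();
        for (double x = 0; x <= 4; x += 0.5) data.add(new double[]{x, 1 + 2 * x});
        double[] w = fit(data, 0.05, 5000);
        System.out.printf("bias=%.3f slope=%.3f%n", w[0], w[1]);
    }
}
```

Note the larger learning rate (0.05) than in the test program: on this small, well-scaled data set the loop converges quickly, while too large a rate would make the updates diverge.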