DL:基于sklearn的加利福尼亚房价数据集实现GD算法
生活随笔
收集整理的這篇文章主要介紹了
DL:基于sklearn的加利福尼亚房价数据集实现GD算法
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
DL:基于sklearn的加利福尼亞房價數據集實現GD算法
?
?
目錄
輸出結果
代碼設計
?
?
輸出結果
? ? ?該數據包含9個變量的20640個觀測值,該數據集包含平均房屋價值作為目標變量和以下輸入變量(特征):平均收入、房屋平均年齡、平均房間、平均臥室、人口、平均占用、緯度和經度。
更新……
?
代碼設計
#DL:基于sklearn的加利福尼亞房價數據集實現GD算法 import tensorflow as tf import numpy as np from sklearn.datasets import fetch_california_housing from sklearn.preprocessing import StandardScaler scaler = StandardScaler() #將特征進行標準歸一化 #獲取房價數據 housing = fetch_california_housing() m,n = housing.data.shape print (housing.keys()) #輸出房價的key print (housing.feature_names) #輸出房價的特征: print (housing.target) print (housing.DESCR) housing_data_plus_bias = np.c_[np.ones((m,1)), housing.data] scaled_data = scaler. fit_transform(housing.data) data = np.c_[np.ones((m,1)),scaled_data] #設置參數 n_epoch = 1000 learning_rate = 0.01 #設置placeholder即灌入數據 X = tf.constant(data,dtype = tf.float32,name = "X") y = tf.constant(housing.target.reshape(-111),dtype=tf.float32,name='y') #theta理解為權重,random_uniform途中創建包含隨機值的節點即初始權重是隨機賦值的,理解為numpy的random函數 theta = tf.Variable(tf.random_uniform([n+1, 1], -1, 1),name='theta') y_pred = tf.matmul(X,theta,name='prediction') error = y_pred - y mse = tf.reduce_mean(tf.square(error),name='mse') #采用的成本函數是mse即Mean Squared Error均方誤差#計算梯度公式,關鍵一步 # #T1、手動求導 # gradient = 2/m * tf.matmul(tf.transpose(X),error) # training_op = tf.assign(theta,theta - learning_rate * gradient) #assign將新值賦值給一個變量的節點,即權重更新公式的迭代過程#T2、自動求導 optimizer = tf.train.GradientDescentOptimizer(learning_rate = learning_rate)#參數初始化,啟動session,將graph放入session進行每一步的更新 init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) for epoch in range(n_epoch): if epoch % 100 == 0: print ("Epoch",epoch, "MSE =", mse.eval()) # sess.run(training_op)print('best theta:',theta.eval())?
GitHub相關文章
DL:基于sklearn的加利福尼亞房價,數據集較多時采用mini-batch方式訓練會更快
TF保存模型:基于TF進行模型的保存與恢復加載,調用Save()函數即可
?
?
?
總結
以上是生活随笔為你收集整理的DL:基于sklearn的加利福尼亚房价数据集实现GD算法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TF之LoR:基于tensorflow利
- 下一篇: 成功解决AttributeError: