TensorFlow Study Notes (11): Reading Your Own Data for Training
1. Linear relationship
Reading the data from a CSV file
x,y
1,2
4,5
6,11
3,6
4,7
5,12
7,13
10,21
11,23
24,50
45,89
50,101
55,111
60,123
70,139
80,164
85,171
90,192
95,190
100,199
200,401
1000,2000
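The following is a minimal sketch of the data-loading step. It assumes pandas and NumPy are installed. So that the snippet is self-contained, it parses the x,y values above from an in-memory string: `io.StringIO` stands in for the CSV file path, which is purely an illustration choice. It turns each column into an (n, 1) array with `to_numpy()`, the replacement for `as_matrix()`, which newer pandas versions no longer provide. It then applies the same z-score normalization as the `normalize` function in the code. Note that there are two different standard deviations here. NumPy's `std()` divides by n, while `DataFrame.describe()` divides by n - 1. That is why the first normalized x value in the results is about -0.442 rather than (1 - 91.14) / 208.74.

```python
import io

import numpy as np
import pandas as pd

# The CSV contents listed above; io.StringIO stands in for the file path.
CSV_TEXT = """x,y
1,2
4,5
6,11
3,6
4,7
5,12
7,13
10,21
11,23
24,50
45,89
50,101
55,111
60,123
70,139
80,164
85,171
90,192
95,190
100,199
200,401
1000,2000
"""

def load_columns(csv_source):
    """Read the x and y columns as float arrays of shape (n, 1)."""
    dataset = pd.read_csv(csv_source)
    x = dataset["x"].to_numpy(dtype=float).reshape(-1, 1)
    y = dataset["y"].to_numpy(dtype=float).reshape(-1, 1)
    return x, y

def normalize(train):
    """Z-score normalization; NumPy's std() uses ddof=0 (divides by n)."""
    return (train - train.mean()) / train.std()

x_data, y_data = load_columns(io.StringIO(CSV_TEXT))
x_norm = normalize(x_data)

print(x_data.shape)                      # (22, 1)
print(x_data.std())                      # population std (ddof=0), about 203.94
print(pd.Series(x_data.ravel()).std())   # sample std (ddof=1) as in describe(), about 208.74
print(x_norm[0, 0])                      # about -0.4419732
```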
Code:
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 15:43:41 2017
@author: ESRI
"""
import pandas as pd
# Written for the TensorFlow 1.x graph API (placeholder / Session).
# Under TensorFlow 2.x, replace the next line with:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()
import tensorflow as tf
import matplotlib.pyplot as plt

# Read the data
dataset = pd.read_csv('E:\\testData\\network.csv')
# Summary statistics
print(dataset.describe())
# First 5 rows
print(dataset.head())
# Shape of the data
print(dataset.shape)
# Extract x and y as column vectors of shape (n, 1)
# (.values replaces as_matrix(), which was removed in pandas 1.0)
X_data = dataset['x'].values.reshape(-1, 1)
Y_data = dataset['y'].values.reshape(-1, 1)

# Add one fully connected layer
def add_layer(inputs, in_size, out_size, activation_function=None):
    # add one more layer and return the output of this layer
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b)
    return outputs

# Z-score normalization (NumPy's std() divides by n)
def normalize(train):
    mean, std = train.mean(), train.std()
    train = (train - mean) / std
    return train

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# Normalize the data
X = normalize(X_data)
Y = normalize(Y_data)
# print the normalized x values
print(X)
# 3-layer network: input, one hidden layer of 10 ReLU units, output
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# Compute the loss (mean squared error, since there is one output column)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=1))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# important step
# init = tf.initialize_all_variables()  # deprecated name
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Visualize the results
# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(X, Y)
plt.ion()
plt.show()
lines = None
for i in range(8000):
    # training
    sess.run(train_step, feed_dict={xs: X, ys: Y})
    if i % 50 == 0:
        print(sess.run(loss, feed_dict={xs: X, ys: Y}))
        # remove the previous prediction curve before drawing the new one
        if lines is not None:
            lines[0].remove()
        prediction_value = sess.run(prediction, feed_dict={xs: X})
        # plot the prediction
        lines = ax.plot(X, prediction_value, 'r-', lw=5)
        plt.pause(0.1)
# keep the final plot open
plt.ioff()
plt.show()
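To make the shapes in `add_layer` and the loss expression concrete, here is a NumPy-only sketch of the same forward computation. It is an illustration, not TensorFlow code. The helper names `dense` and `relu` are hypothetical, and the inputs are stand-in values rather than the real data. With inputs of shape (n, 1), the hidden layer produces (n, 10) and the output layer produces (n, 1). The bias of shape (1, out_size) is broadcast across all n rows. Because the prediction has a single column, `reduce_sum(..., axis=1)` sums over one element, so the loss reduces to the plain mean squared error.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(inputs, in_size, out_size, activation=None):
    """NumPy analogue of add_layer: inputs @ W + b, then an optional activation."""
    weights = rng.standard_normal((in_size, out_size))  # like tf.random_normal
    biases = np.zeros((1, out_size)) + 0.1              # like tf.zeros([1, out_size]) + 0.1
    wx_plus_b = inputs @ weights + biases               # (n, out_size); biases broadcast over rows
    return wx_plus_b if activation is None else activation(wx_plus_b)

def relu(z):
    return np.maximum(z, 0.0)

# Stand-in inputs and targets (not the real data), shape (22, 1).
x = np.linspace(-0.5, 4.5, 22).reshape(-1, 1)
y = x.copy()

hidden = dense(x, 1, 10, activation=relu)   # shape (22, 10)
prediction = dense(hidden, 10, 1)           # shape (22, 1)

# The loss as written in the TensorFlow code ...
loss_tf_style = np.mean(np.sum(np.square(y - prediction), axis=1))
# ... equals the plain mean squared error when there is a single output column.
mse = np.mean(np.square(y - prediction))

print(hidden.shape, prediction.shape)   # (22, 10) (22, 1)
print(np.isclose(loss_tf_style, mse))   # True
```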
Results:
                 x            y
count    22.000000    22.000000
mean     91.136364   183.181818
std     208.740051   417.486314
min       1.000000     2.000000
25%       6.250000    12.250000
50%      47.500000    95.000000
75%      83.750000   169.250000
max    1000.000000  2000.000000
      x    y
0     1    2
1     4    5
2     6   11
3     3    6
4     4    7
5     5   12
6     7   13
7    10   21
8    11   23
9    24   50
10   45   89
11   50  101
12   55  111
13   60  123
14   70  139
15   80  164
16   85  171
17   90  192
18   95  190
19  100  199
(22, 2)
[[-0.4419732 ]
 [-0.42726305]
 [-0.41745628]
 ...,
 [ 0.04346181]
 [ 0.53380021]
 [ 4.45650743]]

Loss, printed every 50 training steps:
11.6106
0.00154685
0.00107705
0.000622838
0.000468779
0.000346998
0.000274857
0.00016539
9.39608e-05
6.02521e-05
4.41742e-05
3.47886e-05
3.02667e-05
2.81042e-05
2.73301e-05
2.69677e-05
2.67462e-05
2.66131e-05
2.6452e-05
2.63586e-05
2.63102e-05
2.61975e-05
2.61691e-05
2.61784e-05
2.61712e-05
2.61596e-05
2.61267e-05
2.61323e-05
2.61504e-05
2.61072e-05
2.61337e-05
2.61305e-05
2.60892e-05
2.60815e-05
2.6096e-05
2.60919e-05
2.60685e-05
2.60606e-05
2.60774e-05
2.61023e-05
2.60717e-05
2.60601e-05
2.60832e-05
2.60474e-05
2.60752e-05
2.60568e-05
2.60328e-05
2.60716e-05
2.60527e-05
2.60288e-05
2.60224e-05
2.60488e-05
2.60549e-05
2.60573e-05
2.60576e-05
2.60556e-05
2.60509e-05
2.60434e-05
2.60333e-05
2.60186e-05
2.60025e-05
2.60154e-05
2.60487e-05
2.60329e-05
2.59924e-05
2.60066e-05
2.60364e-05
2.60053e-05
2.60045e-05
2.60256e-05
2.5987e-05
2.60303e-05
2.59782e-05
2.603e-05
2.59753e-05
2.60242e-05
2.59781e-05
2.60142e-05
2.59865e-05
2.59966e-05
2.6021e-05
2.59726e-05
2.59987e-05
2.6012e-05
2.59699e-05
2.59885e-05
2.60072e-05
2.59776e-05
2.59591e-05
2.59867e-05
2.59993e-05
2.59841e-05
2.59637e-05
2.59506e-05
2.59757e-05
2.59872e-05
2.59941e-05
2.5992e-05
2.59636e-05
2.59547e-05
2.59475e-05
2.59412e-05
2.59377e-05
2.59612e-05
2.59653e-05
2.59678e-05
2.59692e-05
2.59695e-05
2.59691e-05
2.59679e-05
2.59662e-05
2.59643e-05
2.59615e-05
2.59585e-05
2.59546e-05
2.59497e-05
2.5937e-05
2.59175e-05
2.59209e-05
2.59248e-05
2.59291e-05
2.5935e-05
2.59458e-05
2.59611e-05
2.59533e-05
2.59444e-05
2.59316e-05
2.59077e-05
2.59154e-05
2.59242e-05
2.59549e-05
2.59445e-05
2.59325e-05
2.58975e-05
2.59079e-05
2.59221e-05
2.59423e-05
2.59288e-05
2.58909e-05
2.5903e-05
2.59232e-05
2.59314e-05
2.59115e-05
2.58939e-05
2.59115e-05
2.59273e-05
2.59065e-05
2.589e-05
2.59133e-05
2.59185e-05
2.58759e-05
2.58929e-05
2.59227e-05
2.59028e-05
2.58816e-05
2.59242e-05
2.59048e-05
2.58738e-05
2.59005e-05
2.59029e-05
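On the normalized scale, the loss levels off at roughly 2.6e-5. As an optional cross-check that is not part of the original experiment, you can fit a straight line by ordinary least squares to the same normalized data, using `np.polyfit` with degree 1. This gives a baseline to compare against. The sketch also shows how to map a normalized prediction back to the original y units by inverting `normalize`. The original code does not do this step, and the helper `denormalize` is hypothetical.

```python
import numpy as np

x = np.array([1, 4, 6, 3, 4, 5, 7, 10, 11, 24, 45, 50, 55, 60, 70, 80,
              85, 90, 95, 100, 200, 1000], dtype=float)
y = np.array([2, 5, 11, 6, 7, 12, 13, 21, 23, 50, 89, 101, 111, 123, 139,
              164, 171, 192, 190, 199, 401, 2000], dtype=float)

def normalize(values):
    """Same z-score normalization as in the code above (NumPy std, ddof=0)."""
    return (values - values.mean()) / values.std()

def denormalize(values_norm, reference):
    """Hypothetical helper: undo normalize() using the reference data's mean and std."""
    return values_norm * reference.std() + reference.mean()

x_norm, y_norm = normalize(x), normalize(y)

# Degree-1 least-squares fit on the normalized data.
slope, intercept = np.polyfit(x_norm, y_norm, 1)
fitted = slope * x_norm + intercept
mse = np.mean((y_norm - fitted) ** 2)   # same scale as the network's loss
print(slope, intercept, mse)

# Predict y for a new x value given in the original units.
x_new = 1000.0
y_new_norm = slope * (x_new - x.mean()) / x.std() + intercept
print(denormalize(y_new_norm, y))       # close to 2000
```

Because both columns are standardized with the same kind of standard deviation, the fitted slope equals the correlation between x and y. A slope very close to 1 confirms that the relationship is almost exactly linear.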