import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
input_size = 784
#数字从0-9
no_classes = 10
batch_size = 100
total_batches = 200
x_input = tf.placeholder(tf.float32, shape=[None, input_size])
y_input = tf.placeholder(tf.float32, shape=[None, no_classes])
# 权重
weights = tf.Variable(tf.random_normal([input_size, no_classes]))
#偏置量
bias = tf.Variable(tf.random_normal([no_classes]))
#对输入数据进行加权并且偏重处理
logits = tf.matmul(x_input, weights) + bias
#tf.nn是交叉熵和softmax
softmax_cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_input, logits=logits)
#求交叉熵的平均值来计算损失
loss_operation = tf.reduce_mean(softmax_cross_entropy)
#优化器获取损失并将其优化,tf.train 中有很多优化器,这里使用的是vanilla梯度下降
optimiser = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(loss_operation)
#启动会话
session = tf.Session()
# 使用全局变量启动初始化器初始化变量
session.run(tf.global_variables_initializer())
#准备好接收数据并进行训练
for batch_no in range(total_batches):
mnist_batch = mnist_data.train.next_batch(batch_size)
train_images, train_labels = mnist_batch[0], mnist_batch[1]
#optimiser用来接收输出值,loss_operation是传递损失值,有的化说明在被训练,
_, loss_value = session.run([optimiser, loss_operation], feed_dict={x_input: train_images,
y_input: train_labels})
print(loss_value)
#计算准确率来评估模型的工作情况
predictions = tf.argmax(logits, 1)
correct_predictions = tf.equal(predictions, tf.argmax(y_input, 1))
#求正确预测的平均值
accuracy_operation = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
#用测试集来评估数据的准确性
test_images, test_labels = mnist_data.test.images, mnist_data.test.labels
accuracy_value = session.run(accuracy_operation, feed_dict={x_input: test_images,
y_input: test_labels})
print('Accuracy : ', accuracy_value)
session.close()
在运行程序的过程中会出现程序连接中断的故障,是因为数据集mnist第一次从网上下载的
from tensorflow.examples.tutorials.mnist import input_data
多运行几次代码就好了,如果实在不行就自己下载数据集就行了
下载完之后我的文件下就出现了
运行结果如下:
14.193853
11.976693
11.152556
9.816551
10.062865
9.1732855
8.452591
7.39629
7.261834
6.4758296
5.095989
5.3675117
4.346563
5.8800707
4.050165
4.988338
4.207674
4.223306
3.542412
4.6152387
4.2903047
3.1502352
2.654767
3.2581136
2.6391566
2.582851
4.013533
2.93053
2.9647357
3.1097214
3.2510238
3.087494
2.1127334
2.7375786
2.6396368
2.2358656
3.0328584
2.8446057
1.9014186
2.5864
2.4557014
2.0604048
2.678323
2.0324626
2.1053376
1.7908027
2.071164
2.7318442
2.3347359
2.0183458
2.1548445
1.5502448
1.5956453
2.2312446
2.4022517
1.9966538
1.6618974
2.2258823
2.3394287
1.9388988
2.1634982
2.6243417
1.7405152
1.6083919
1.3865634
1.7334721
2.0415554
1.6849861
1.3436967
1.2848032
1.5731153
1.8387886
1.5595846
1.3502707
1.7222123
1.6371703
1.281432
2.1001666
1.5937349
1.7854013
1.7249304
1.595706
1.2963427
1.2165885
1.7543554
1.2786142
1.737834
2.1750572
1.5531785
1.7206784
1.7020774
1.8592744
1.0420597
1.2558794
2.2464373
1.091059
1.428548
1.1194586
1.6839582
1.7219964
1.4865652
1.3515755
1.5324861
1.6667664
1.3281902
1.1460389
1.8825967
1.5745894
1.2487916
1.9258285
1.1637989
1.002647
0.91422826
1.023291
1.5730728
1.214298
1.8587601
0.8384979
1.1753948
1.808498
1.1910837
1.3949873
1.4556371
1.4928048
1.5076202
1.6343887
1.1755352
1.093307
1.3246533
1.6141297
1.4151738
1.1250764
0.9521868
1.2051206
1.0615922
1.3972435
1.3090788
1.0103723
1.4599731
1.1549947
1.0185461
1.4114245
1.31415
1.6407756
0.7462804
0.6913569
1.4364597
1.4817989
1.3700738
1.1351871
1.1323878
1.2835964
0.5260574
0.9555662
0.93056464
0.9683941
0.64839345
0.8722134
1.0275012
1.6065514
1.0029927
1.4628772
1.1451436
1.1110401
0.9737751
1.0064018
1.4781717
1.3430713
1.3098766
1.2825389
1.094347
1.2620716
1.2226298
1.4171237
0.95748407
1.2118229
0.94896215
1.2674838
1.098785
0.81294525
1.061944
1.1593318
1.0882632
0.8745768
1.0962759
1.0665091
1.0902841
1.887527
1.2829604
0.81598336
1.8372782
1.1302813
1.3531325
1.481201
0.86651915
1.0126982
1.1199208
1.252277
0.8720088
0.62692046
Accuracy : 0.8031
准确率是0.8301
如果再次运行会不会结果不一样呢?

可以看到准确率变高了
再次运行
参考
《计算机视觉之深度学习-使用TensorFlow和Keras训练高级神经网络》-英拉贾林加帕.尚穆加马尼