基于TensorFlow实现线性回归的模型训练预测一(一个样本一个特征深度学习,单层网络)

例子为:100个样本,每个样本1个特征

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

if __name__ == '__main__':
    with tf.Graph().as_default():
        # 一、执行图的构建
        # a. 定义占位符
        input_x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='x')  # [None, 1],None表示不知道有几个样本,1表示一个样本里面有一个特征
        input_y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')

        # b. 定义模型参数
        w = tf.get_variable(name='w', shape=[1, 1], dtype=tf.float32,
                            initializer=tf.random_normal_initializer(mean=0.0, stddev=1.0))
        b = tf.get_variable(name='b', shape=[1], dtype=tf.float32,
                            initializer=tf.zeros_initializer())

        # c. 模型预测的构建(获取预测值)
        y_ = tf.matmul(input_x, w) + b

        # d. 损失函数构建(平方和损失函数)
        loss = tf.reduce_mean(tf.square(input_y - y_))

        # e. 定义优化器(优化器的意思:求解让损失函数最小的模型参数<变量>的方式)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
        # f. 定义一个训练操作对象
        train_op = optimizer.minimize(loss=loss)

        # 二、执行图的训练运行
        with tf.Session() as sess:
            # a. 变量的初始化操作
            sess.run(tf.global_variables_initializer())

            # b. 训练数据的产生/获取(基于numpy随机产生<可以先考虑一个固定的数据集>)
            N = 100
            x = np.linspace(0, 6, N) + np.random.normal(0, 2.0, N)
            y = 14 * x + 7 + np.random.normal(0, 5.0, N)
            x.shape = -1, 1
            y.shape = -1, 1
            print((np.shape(x), np.shape(y)))

            # c. 模型训练
            for step in range(100):
                # 1. 触发模型训练操作
                _, loss_ = sess.run([train_op, loss], feed_dict={
                    input_x: x,
                    input_y: y
                })
                print("第{}次训练后模型的损失函数为:{}".format(step, loss_))

            # d. 构造一个测试数据或者使用训练数据,得到该数据对应的预测值,做一个可视化的操作
            predict = sess.run(y_, feed_dict={input_x: x})
            plt.plot(x, y, 'ro')
            plt.plot(x, predict, 'g-')
            plt.show()


版权声明:本文为shitoucoming原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。