TensorFlow 训练模型时,使用自定义数据处理函数,方便快速地加载数据以训练模型

一、自定义数据处理函数(utils.py)

def read_images(txt, batch_size, data_dir='/data'):
    """Build a shuffled, batched TF1 queue pipeline of cover/stego images.

    ``txt`` is a list file inside ``data_dir`` that holds one image file name
    per line. Every name produces two samples:
    ``os.path.join(cover_path, name)`` with label 0 and
    ``os.path.join(stego_path, name)`` with label 1.
    ``cover_path``, ``stego_path``, ``CHANNELS``, ``IMG_HEIGHT`` and
    ``IMG_WIDTH`` are module-level settings defined outside this function
    (not visible here).

    Each image is JPEG-decoded with ``CHANNELS`` channels, resized to
    ``IMG_HEIGHT`` x ``IMG_WIDTH`` and rescaled with ``x / 127.5 - 1``.

    Args:
        txt: file name of the image list, resolved relative to ``data_dir``.
        batch_size: number of examples per batch produced by ``tf.train.batch``.
        data_dir: directory containing ``txt``. Defaults to ``/data``, the
            previously hard-coded location.

    Returns:
        ``(X, Y)``: the batched image tensor and int32 label tensor. Both are
        fed by queue runners, so ``tf.train.start_queue_runners`` must be
        running in the session before they are evaluated.

    Raises:
        ValueError: if the list file contains no image names.
    """
    list_path = os.path.join(data_dir, txt)
    with open(list_path, 'r') as list_file:
        raw_names = [line.rstrip('\n') for line in list_file]
    # Skip blank / whitespace-only lines (e.g. a trailing empty line at EOF).
    # Otherwise they would become bare directory paths that tf.read_file
    # cannot read.
    image_names = [name for name in raw_names if name.strip()]
    if not image_names:
        raise ValueError('no image names found in ' + list_path)

    # Interleave each cover image (label 0) with its stego counterpart (label 1).
    imagepaths, labels = [], []
    for name in image_names:
        imagepaths.extend([os.path.join(cover_path, name),
                           os.path.join(stego_path, name)])
        labels.extend([0, 1])

    imagepaths = tf.convert_to_tensor(imagepaths, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)
    # Queue that yields one shuffled (path, label) pair at a time.
    image, label = tf.train.slice_input_producer([imagepaths, labels],
                                                 shuffle=True)

    # Read the file from disk, decode it and bring it to a common size.
    image = tf.read_file(image)
    image = tf.image.decode_jpeg(image, channels=CHANNELS)
    image = tf.image.resize_images(image, [IMG_HEIGHT, IMG_WIDTH])

    # Rescale pixel values: 0 -> -1.0, 255 -> +1.0.
    image = image * 1.0 / 127.5 - 1.0

    X, Y = tf.train.batch([image, label], batch_size=batch_size,
                          capacity=batch_size * 8,
                          num_threads=4)
    return X, Y

二、训练模型


# -*- coding: utf-8 -*-

import scipy.misc
import os

import tensorflow as tf
import numpy as np
# import pandas as pd
import cv2
from random import shuffle
from data_utils import *
from utils import read_images
from os import walk
import os

# Process-wide runtime configuration, applied when this module is imported.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # number GPUs by PCI bus id
    "CUDA_VISIBLE_DEVICES": "6",        # expose only GPU 6 (in that order)
    'TF_CPP_MIN_LOG_LEVEL': '3',        # raise TF's C++ log threshold to the quietest level
})
# NOTE(review): tensorflow is imported above this point. If its native logger
# reads TF_CPP_MIN_LOG_LEVEL when the library loads, messages emitted during
# import are not suppressed. Consider moving these settings before the import.
tf.logging.set_verbosity(tf.logging.ERROR)  # Python-side TF logging: errors only

# Examples per batch, used for every read_images() pipeline in train().
batch_size = 16


def _accuracy_op(logits, labels):
    """Return a scalar float32 tensor: the fraction of rows where argmax(logits, 1) == labels."""
    correct = tf.equal(tf.argmax(logits, 1), tf.cast(labels, tf.int64))
    return tf.reduce_mean(tf.cast(correct, tf.float32))


def _mean_over_steps(sess, fetch, num_steps):
    """Run ``fetch`` ``num_steps`` times in ``sess`` and return the float64 mean of the results.

    When ``fetch`` depends on a queue-fed batch, every run consumes a new batch.
    """
    values = [sess.run(fetch) for _ in range(num_steps)]
    return np.mean(values, dtype=np.float64)


def train(save_model_path, *, num_epochs=200, num_train=18000, num_val=9000,
          num_test=3000):
    """Train the xuNet discriminator and checkpoint the best model by test accuracy.

    Training prints progress every 200 steps.
    Validation accuracy is printed every 5 epochs.
    Test accuracy is computed every 10 epochs. A checkpoint
    ``<save_model_path>/xu_model_<test accuracy>.ckpt`` is written whenever
    the test accuracy beats the best seen so far.

    Args:
        save_model_path: existing directory that receives the checkpoints.
        num_epochs: number of training epochs.
        num_train, num_val, num_test: number of image names in train.txt,
            val.txt and test.txt. read_images emits a cover and a stego sample
            per name, so steps per pass = round(n * 2 / batch_size).
    """
    train_x, train_y = read_images('train.txt', batch_size)
    val_x, val_y = read_images('val.txt', batch_size)
    test_x, test_y = read_images('test.txt', batch_size)

    # One set of weights under 'xuNet', shared by the three input pipelines.
    # The val/test graphs are built in inference mode.
    with tf.variable_scope('xuNet'):
        logits_train = discriminator(train_x, is_training=True)
    with tf.variable_scope('xuNet', reuse=tf.AUTO_REUSE):
        logits_val = discriminator(val_x, is_training=False)
    with tf.variable_scope('xuNet', reuse=tf.AUTO_REUSE):
        logits_test = discriminator(test_x, is_training=False)

    loss_op = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_train, labels=train_y))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
    # If discriminator uses batch normalization, its moving-average updates are
    # registered in UPDATE_OPS and must run with every training step.
    # Otherwise the is_training=False graphs evaluate with stale statistics.
    # When the collection is empty this is a no-op.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss_op)

    train_acc = _accuracy_op(logits_train, train_y)
    val_acc = _accuracy_op(logits_val, val_y)
    test_acc = _accuracy_op(logits_test, test_y)

    train_steps = round(num_train * 2 / batch_size)
    val_steps = round(num_val * 2 / batch_size)
    test_steps = round(num_test * 2 / batch_size)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=50)  # keep at most the 50 latest checkpoints
    with tf.Session() as sess:
        sess.run(init)
        # Run the input-queue threads under a coordinator so they are stopped
        # and joined before the session closes.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            best_acc = 0
            for epoch in range(num_epochs):
                train_accs = []
                for step in range(train_steps):
                    _, train_accuracy, train_loss = sess.run([train_op, train_acc, loss_op])
                    train_accs.append(train_accuracy)
                    if step % 200 == 0:
                        print("epoch " + str(epoch) + ", step " + str(step)
                              + ", train mean Accuracy= "
                              + str(np.mean(train_accs, dtype=np.float64))
                              + ", train loss " + str(train_loss))

                if epoch % 5 == 0:
                    val_mean_acc = _mean_over_steps(sess, val_acc, val_steps)
                    print('epoch {}, val mean accuracy {}'.format(epoch, val_mean_acc))

                if epoch % 10 == 0:
                    test_mean_acc = _mean_over_steps(sess, test_acc, test_steps)
                    print('epoch {}, test mean accuracy {}'.format(epoch, test_mean_acc))
                    # Track the best accuracy so far; only improvements are saved.
                    # NOTE(review): selecting checkpoints on the test set leaks it
                    # into model selection. The validation accuracy is the usual
                    # criterion.
                    if test_mean_acc > best_acc:
                        best_acc = test_mean_acc
                        saver.save(sess, os.path.join(
                            save_model_path, 'xu_model_' + str(test_mean_acc) + '.ckpt'))
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    save_model_path = '/model'
    # makedirs(exist_ok=True) also creates missing parent directories. It
    # avoids the check-then-create race of os.path.exists() + os.mkdir().
    os.makedirs(save_model_path, exist_ok=True)
    train(save_model_path)

reference


版权声明:本文为zzldm原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。