一:自定义数据处理函数
# txt: name of a text file under /data/ that lists one image file name per line
def read_images(txt, batch_size):
    """Build a shuffled, batched TF1 queue pipeline of cover/stego images.

    The list file ``'/data/' + txt`` is read line by line. Every non-blank
    line is taken as an image file name and contributes two samples: the file
    of that name under ``cover_path`` (label 0) and the file of that name
    under ``stego_path`` (label 1).

    NOTE(review): ``cover_path``, ``stego_path``, ``CHANNELS``, ``IMG_HEIGHT``
    and ``IMG_WIDTH`` are not defined in this block; they are assumed to be
    module-level names defined elsewhere -- confirm.

    Args:
        txt: Name of the list file, relative to ``/data/``.
        batch_size: Number of samples per batch.

    Returns:
        A tuple ``(X, Y)`` of batched tensors: the decoded, resized and
        normalised images, and their int32 labels.

    Raises:
        ValueError: If the list file contains no image names.
    """
    imagepaths, labels = [], []  # parallel lists: image path and its label
    with open('/data/' + txt, 'r') as list_file:
        for line in list_file:
            # Drop the line terminator only; '\r' is removed as well so that
            # CRLF list files do not leave a stray carriage return in the path.
            name = line.rstrip('\r\n')
            if not name.strip():
                # A blank line would otherwise turn into the bare directory
                # path and only fail later inside a queue-runner thread.
                continue
            imagepaths.append(os.path.join(cover_path, name))
            labels.append(0)
            imagepaths.append(os.path.join(stego_path, name))
            labels.append(1)

    if not imagepaths:
        raise ValueError('no image names found in /data/' + txt)

    # Convert to Tensor
    imagepaths = tf.convert_to_tensor(imagepaths, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, dtype=tf.int32)

    # Build a TF queue that yields one shuffled (path, label) pair at a time
    image, label = tf.train.slice_input_producer([imagepaths, labels],
                                                 shuffle=True)

    # Read the image from disk and decode it
    image = tf.read_file(image)
    image = tf.image.decode_jpeg(image, channels=CHANNELS)

    # Resize images to a common size
    image = tf.image.resize_images(image, [IMG_HEIGHT, IMG_WIDTH])

    # Normalize: maps pixel values 0..255 to the range [-1, 1]
    image = image * 1.0 / 127.5 - 1.0

    # Create batches
    X, Y = tf.train.batch([image, label], batch_size=batch_size,
                          capacity=batch_size * 8,
                          num_threads=4)
    return X, Y
二:训练模型
# -*- coding: utf-8 -*-
import scipy.misc
import os
import tensorflow as tf
import numpy as np
# import pandas as pd
import cv2
from random import shuffle
from data_utils import *
from utils import read_images
from os import walk
import os
# Process-wide environment setup: enumerate GPUs by PCI bus id, expose only
# device "6" to this process, and set TensorFlow's C++ log level to 3.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "6",
    'TF_CPP_MIN_LOG_LEVEL': '3',
})
# Only report errors from TensorFlow's Python-side logger.
tf.logging.set_verbosity(tf.logging.ERROR)

# Batch size passed to every read_images() call made by train().
batch_size = 16
def _batch_accuracy(logits, labels):
    """Return a scalar op: fraction of rows whose argmax equals the label."""
    correct = tf.equal(tf.argmax(logits, 1), tf.cast(labels, tf.int64))
    return tf.reduce_mean(tf.cast(correct, tf.float32))


def _mean_over_batches(sess, metric_op, num_steps):
    """Run ``metric_op`` ``num_steps`` times and return the mean of the results."""
    # Plain Python floats keep the mean in float64, as the original
    # float64 accumulator array did.
    values = [float(sess.run(metric_op)) for _ in range(num_steps)]
    return np.mean(values)


def train(save_model_path):
    """Train the 'xuNet' model and checkpoint it when test accuracy improves.

    Input queues are built from 'train.txt', 'val.txt' and 'test.txt' via
    ``read_images``. The network is built three times under the shared
    variable scope 'xuNet' (training mode for the training queue, inference
    mode for validation and test). Training runs for 200 epochs with Adam;
    validation accuracy is reported every 5 epochs and test accuracy every
    10 epochs. A checkpoint is written whenever the test accuracy exceeds the
    best value seen so far.

    NOTE(review): ``discriminator`` is not defined in this block; it is
    presumably provided by the ``data_utils`` star import -- confirm.

    Args:
        save_model_path: Directory in which checkpoints are written.
    """
    train_x, train_y = read_images('train.txt', batch_size)
    val_x, val_y = read_images('val.txt', batch_size)
    test_x, test_y = read_images('test.txt', batch_size)

    # Model setup: one set of weights, shared by the three graphs.
    with tf.variable_scope('xuNet'):
        logits_train = discriminator(train_x, is_training=True)
    with tf.variable_scope('xuNet', reuse=tf.AUTO_REUSE):
        logits_val = discriminator(val_x, is_training=False)
    with tf.variable_scope('xuNet', reuse=tf.AUTO_REUSE):
        logits_test = discriminator(test_x, is_training=False)

    # Loss and optimiser
    loss_op = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits_train, labels=train_y))
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.5)
    train_op = optimizer.minimize(loss_op)

    # Per-batch accuracy ops
    train_acc = _batch_accuracy(logits_train, train_y)
    val_acc = _batch_accuracy(logits_val, val_y)
    test_acc = _batch_accuracy(logits_test, test_y)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver(max_to_keep=50)  # keep at most the 50 latest checkpoints

    # Batches per pass. The "* 2" presumably reflects that every listed name
    # yields one cover and one stego sample -- confirm the list sizes.
    train_steps = round(18000 * 2 / batch_size)
    val_steps = round(9000 * 2 / batch_size)
    test_steps = round(3000 * 2 / batch_size)

    with tf.Session() as sess:
        # To resume from a checkpoint, restore it with saver.restore(sess, path)
        # instead of running the initialiser.
        sess.run(init)

        # Start the data queues under a coordinator so that the reader
        # threads are stopped and joined when training ends or fails.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            best_test_acc = 0
            for epoch in range(200):
                train_accs = []
                for step in range(train_steps):
                    _, train_accuracy, train_loss = sess.run(
                        [train_op, train_acc, loss_op])
                    train_accs.append(float(train_accuracy))
                    if step % 200 == 0:
                        print('epoch {}, step {}, train mean Accuracy= {}, train loss {}'.format(
                            epoch, step, np.mean(train_accs), train_loss))

                if epoch % 5 == 0:
                    val_mean_acc = _mean_over_batches(sess, val_acc, val_steps)
                    print('epoch {}, val mean accuracy {}'.format(epoch, val_mean_acc))

                if epoch % 10 == 0:
                    test_mean_acc = _mean_over_batches(sess, test_acc, test_steps)
                    print('epoch {}, test mean accuracy {}'.format(epoch, test_mean_acc))
                    # Checkpoint only on a strict improvement of test accuracy.
                    if test_mean_acc > best_test_acc:
                        saver.save(sess, save_model_path + '/xu_model_' + str(test_mean_acc) + '.ckpt')
                        best_test_acc = test_mean_acc
        finally:
            coord.request_stop()
            coord.join(threads)
if __name__ == '__main__':
    # Script entry point: make sure the checkpoint directory exists, then train.
    save_model_path = '/model'
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # os.path.exists() + os.mkdir() and also creates missing parent directories.
    os.makedirs(save_model_path, exist_ok=True)
    train(save_model_path)
版权声明:本文为zzldm原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。