从CIFAR10 读取图像并显示

# -*- coding: utf-8 -*-

import pickle
import cv2
import torch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os
from resnet import ResNet18
import numpy as np
import matplotlib.pyplot as plt


print("Waiting Test!")
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# 准备数据集并预处理
transform_test = transforms.Compose([
    transforms.ToTensor(),
])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=2)
# Cifar-10的标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 模型定义-ResNet
net = ResNet18().to(device)
pthfile = r'E:\classficalnet\residual-attention-network-master\residual-attention-network-master\imagenet_model\model\net_010.pth'
net.load_state_dict(torch.load(pthfile))



if __name__ == "__main__":

    for data in testloader:

        images, labels = data
        img = images[0]  # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
        img = img.numpy()  # FloatTensor转为ndarray
        img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后
        img = (img*255).astype(np.uint8)

将读取的图片输入到网络中并进行显示

   images = img
        images = (images / 255).astype(np.float32) # 转换数据类型
        images = np.transpose(images, (2, 0, 1))  # 把channel那一维放到最后

        images = torch.from_numpy(images).unsqueeze(0) #将单个数据图片转化为batchsize的尺寸


        images = images.to(device)
        outputs = net(images)
        # 取得分最高的那个类 (outputs.data的索引号)
        _, predicted = torch.max(outputs.data, 1)
        print(predicted) # 输出结果

版权声明:本文为weixin_43474255原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。