import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from skimage import io, transform
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from torch.utils.data import Dataset, DataLoader
import cv2
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
def _to_rgb(pil_image):
    """Return *pil_image* converted to 3-channel RGB.

    Greyscale files typically come out of ``skimage.io.imread`` as 2-D arrays
    and images with alpha as 4-channel arrays; after ``ToTensor`` those would
    be 1- or 4-channel tensors, which the 3-channel ``Normalize`` step cannot
    broadcast against. For images that are already RGB this is a no-op copy.
    A named module-level function (not a lambda) keeps the pipeline picklable
    if the DataLoader is ever run with worker processes.
    """
    return pil_image.convert('RGB')


# Preprocessing pipeline applied to every image read by ``skimage.io.imread``.
# NOTE: this rebinds the name ``transform`` imported from ``skimage`` at the
# top of the file; ``TestDataset`` uses this pipeline as its default.
transform = transforms.Compose(
    [
        # io.imread returns a NumPy array; the steps below expect a PIL image.
        transforms.ToPILImage(),
        _to_rgb,
        # NOTE(review): 255x255 is an unusual size (256 or 224 is the common
        # choice) -- confirm this is intended.
        transforms.Resize((255, 255)),
        # PIL image -> float tensor shaped (C, H, W); 8-bit images are scaled
        # into [0, 1].
        transforms.ToTensor(),
        # Standardise each channel with the ImageNet statistics:
        # out = (in - mean) / std, mean=(0.485, 0.456, 0.406),
        # std=(0.229, 0.224, 0.225).
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
def picCount(root_dir):
    """Return the number of entries directly inside *root_dir*.

    Every name reported by ``os.listdir`` is counted, including
    sub-directories and hidden files (e.g. macOS ``.DS_Store``).

    Raises:
        FileNotFoundError: if *root_dir* does not exist.
    """
    return len(os.listdir(root_dir))
# 数据集类
class TestDataset(Dataset):
    """Unlabelled dataset over the image files found directly in a directory.

    Each item is a dict ``{'image': image, 'name': file_name}`` where
    ``image`` is the array read by ``skimage.io.imread`` (passed through
    ``transform`` when one is set) and ``file_name`` is the entry's name
    inside ``root_dir``.
    """

    def __init__(self, root_dir, transform=transform):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on each image. Defaults to the module-level ``transform``
                pipeline; pass ``None`` to get the raw arrays.
        """
        self.root_dir = root_dir
        self.transform = transform
        # List the directory once, up front. Re-listing it on every
        # __getitem__ call costs O(n) per item and lets indices drift if the
        # directory changes during iteration. Sorting gives a stable order
        # (os.listdir's order is arbitrary). Hidden entries (e.g. macOS
        # '.DS_Store') and sub-directories are skipped because io.imread
        # cannot read them as images.
        self.file_names = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith('.')
            and os.path.isfile(os.path.join(root_dir, entry))
        )

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load the idx-th image, apply the transform if set, and return it with its name."""
        file_name = self.file_names[idx]
        image = io.imread(os.path.join(self.root_dir, file_name))
        if self.transform is not None:
            image = self.transform(image)
        return {'image': image, 'name': file_name}
# Location of the images to load, relative to the working directory.
DATA_ROOT = './data/someTry/in'

# Serve the images five per batch, in dataset index order (no shuffling),
# loading them in the main process (num_workers=0).
train_data = TestDataset(DATA_ROOT, transform)
train_loader = DataLoader(
    train_data,
    batch_size=5,
    shuffle=False,
    num_workers=0,
)
def imshow(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalised image tensor with matplotlib.

    Inverts ``transforms.Normalize(mean, std)`` channel-wise before plotting,
    so the colours shown match the original image.

    Args:
        img: image tensor shaped (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``.
        mean: per-channel means used when normalising. The default matches
            the ``transforms.Normalize`` call in this file's pipeline.
        std: per-channel standard deviations used when normalising. The
            default matches the same ``transforms.Normalize`` call. Pass
            ``mean=std=(0.5, 0.5, 0.5)`` for images normalised to [-1, 1].
    """
    img = img.detach().cpu()
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Undo out = (in - mean) / std, then clamp into the [0, 1] range that
    # matplotlib expects for float RGB data.
    img = (img * std + mean).clamp(0, 1)
    npimg = img.numpy()
    # (C, H, W) -> (H, W, C) for plt.imshow.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Fetch a single batch so the images and their file names come from the
# same batch. Use the built-in next(); recent DataLoader iterators no longer
# provide a .next() method.
dataiter = iter(train_loader)
batch = next(dataiter)
print(batch)
image = batch['image']
name = batch['name']
# Show the whole batch as one image grid.
imshow(torchvision.utils.make_grid(image))
# Print the file name of every image in the batch. The last batch may hold
# fewer than batch_size images, so iterate over the names actually present.
print(' '.join('%5s' % n for n in name))
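# Notes from the original post on reading images:
# - Reading an image with skimage.io and reading it with cv2 give different
#   results because the channel order changes: cv2.imread returns BGR,
#   while skimage.io.imread returns RGB.
# - io.imread returns a NumPy array. Some transforms need a PIL image
#   instead, so transforms.ToPILImage() must come first in the pipeline.
# Copyright notice: this is an original article by qq_45912513, licensed
# under CC 4.0 BY-SA. When reposting, include a link to the original source
# and this notice.
# 版权声明:本文为qq_45912513原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。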