pytorch 搭建cnn resnet50网络进行图片分类 代码详解

数据样式:

 

直接上代码:

import pathlib
import tensorflow as tf
import matplotlib.pyplot as plt
import os, PIL, pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keras
import torch,torchvision
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.models import resnet50
warnings.filterwarnings("ignore")  # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

data_dir_train = "./Chinese Medicine dataes dealed/train"
data_dir_train  = pathlib.Path(data_dir_train )
image_count_train = len(list(data_dir_train.glob('*/*')))
print("训练数据图片总数为:", image_count_train)

dat

版权声明:本文为qq_38735017原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。