Hands-on Kaggle Competition (1): Leaf Classification

Competition page: https://www.kaggle.com/c/leaf-classification/rules
Full code: https://github.com/SPECTRELWF/kaggle_competition
Personal homepage: liuweifeng.top:8090

Task: classify leaf images into one of 99 species. Some example leaves:
[Figure: example leaf images]
I honestly couldn't tell the species apart myself; in any case there are 99 classes in total. The downloaded dataset looks like this after unzipping:
[Figure: contents of the extracted dataset]
The images folder holds all the leaf images. train.csv contains the ids and classes of the training samples, followed by a set of pre-extracted features that I didn't use; the competition is already over and I'm only using it to practice CNNs. test.csv lists the ids of the test samples, and sample_submission.csv is the submission template, which looks like this:
[Figure: sample_submission.csv]
The first column is the id, and the following 99 columns are the predicted probabilities for each class, so applying softmax to the classifier output is all that's needed.
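
To see the exact layout of the template, you can load it with pandas (a quick sketch; the path assumes the unzipped leaf-classification folder):

import pandas as pd

# the submission template: one row per test id, columns = 'id' plus the 99 class names
sub = pd.read_csv('leaf-classification/sample_submission.csv')
print(sub.shape)              # (number of test images, 100)
print(list(sub.columns[:3]))  # 'id' followed by the first class names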

Approach:

Fine-tune a resnet101 pretrained on ImageNet.

Preprocessing

Write the id and label of each training sample to one txt file, and the test ids to another:

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:27 AM

import os
import pandas as pd
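# the 99 class names; the position of each name is the integer label written to train.txt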
classes = ['Acer_Capillipes', 'Acer_Circinatum', 'Acer_Mono', 'Acer_Opalus', 'Acer_Palmatum', 'Acer_Pictum', 'Acer_Platanoids', 'Acer_Rubrum', 'Acer_Rufinerve', 'Acer_Saccharinum', 'Alnus_Cordata', 'Alnus_Maximowiczii', 'Alnus_Rubra', 'Alnus_Sieboldiana', 'Alnus_Viridis', 'Arundinaria_Simonii', 'Betula_Austrosinensis', 'Betula_Pendula', 'Callicarpa_Bodinieri', 'Castanea_Sativa', 'Celtis_Koraiensis', 'Cercis_Siliquastrum', 'Cornus_Chinensis', 'Cornus_Controversa', 'Cornus_Macrophylla', 'Cotinus_Coggygria', 'Crataegus_Monogyna', 'Cytisus_Battandieri', 'Eucalyptus_Glaucescens', 'Eucalyptus_Neglecta', 'Eucalyptus_Urnigera', 'Fagus_Sylvatica', 'Ginkgo_Biloba', 'Ilex_Aquifolium', 'Ilex_Cornuta', 'Liquidambar_Styraciflua', 'Liriodendron_Tulipifera', 'Lithocarpus_Cleistocarpus', 'Lithocarpus_Edulis', 'Magnolia_Heptapeta', 'Magnolia_Salicifolia', 'Morus_Nigra', 'Olea_Europaea', 'Phildelphus', 'Populus_Adenopoda', 'Populus_Grandidentata', 'Populus_Nigra', 'Prunus_Avium', 'Prunus_X_Shmittii', 'Pterocarya_Stenoptera', 'Quercus_Afares', 'Quercus_Agrifolia', 'Quercus_Alnifolia', 'Quercus_Brantii', 'Quercus_Canariensis', 'Quercus_Castaneifolia', 'Quercus_Cerris', 'Quercus_Chrysolepis', 'Quercus_Coccifera', 'Quercus_Coccinea', 'Quercus_Crassifolia', 'Quercus_Crassipes', 'Quercus_Dolicholepis', 'Quercus_Ellipsoidalis', 'Quercus_Greggii', 'Quercus_Hartwissiana', 'Quercus_Ilex', 'Quercus_Imbricaria', 'Quercus_Infectoria_sub', 'Quercus_Kewensis', 'Quercus_Nigra', 'Quercus_Palustris', 'Quercus_Phellos', 'Quercus_Phillyraeoides', 'Quercus_Pontica', 'Quercus_Pubescens', 'Quercus_Pyrenaica', 'Quercus_Rhysophylla', 'Quercus_Rubra', 'Quercus_Semecarpifolia', 'Quercus_Shumardii', 'Quercus_Suber', 'Quercus_Texana', 'Quercus_Trojana', 'Quercus_Variabilis', 'Quercus_Vulcanica', 'Quercus_x_Hispanica', 'Quercus_x_Turneri', 'Rhododendron_x_Russellianum', 'Salix_Fragilis', 'Salix_Intergra', 'Sorbus_Aria', 'Tilia_Oliveri', 'Tilia_Platyphyllos', 'Tilia_Tomentosa', 'Ulmus_Bergmanniana', 'Viburnum_Tinus', 'Viburnum_x_Rhytidophylloides', 'Zelkova_Serrata']

train_txt = open('train.txt','w')
train_csv = pd.read_csv(r'leaf-classification/train.csv')
ids = train_csv['id']
species = train_csv['species']

# write one "<image id> <class index>" line per training sample
for i in range(len(ids)):
    train_txt.write(str(ids[i]))
    train_txt.write(' ')
    train_txt.write(str(classes.index(str(species[i]))))
    train_txt.write('\n')
train_txt.close()

test_txt = open('test.txt','w')
test_csv = pd.read_csv(r'leaf-classification/test.csv')
ids = test_csv['id']
for i in range(len(ids)):
    test_txt.write(str(ids[i]))
    test_txt.write('\n')
test_txt.close()


Model: resnet101

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:24 AM
import torch
import torchvision.models
import torch.nn as nn

class resnet101(nn.Module):
    def __init__(self, num_classes=1000):
        super(resnet101, self).__init__()
        self.num_classes = num_classes
        # ImageNet-pretrained resnet101 as the feature extractor; its original
        # 1000-way fc output feeds the MLP head below
        self.feature_extract = torchvision.models.resnet101(pretrained=True)
        # MLP head: 1000 -> 512 -> 256 -> num_classes
        self.net = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.feature_extract(x)
        x = self.net(x)
        return x


# x = torch.randn((1,3,224,224))
# net = resnet101(num_classes=99)
# print(net)
# print(net(x).shape)
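
The training script later in this post updates all parameters (its comment notes that nothing is frozen). If you instead wanted to train only the new head and keep the pretrained backbone fixed, a minimal sketch (my own variant, not what the post does) would be:

import torch

# freeze the pretrained backbone; only the MLP head stays trainable
net = resnet101(num_classes=99)
for p in net.feature_extract.parameters():
    p.requires_grad = False
opt = torch.optim.Adam((p for p in net.parameters() if p.requires_grad), lr=1e-3)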

Dataloader

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:24 AM
import numpy as np
import torch.utils.data as data
import torch
import torchvision.transforms as transforms
from PIL import Image
data_root = r'leaf-classification/images/'

class leaf_Dataset(data.Dataset):
    def __init__(self,is_train=True,transform=None):
        self.is_train = is_train
        self.transform = transform
        self.images = []
        self.labels = []
        if is_train:
            # train.txt: one "<image id> <class index>" pair per line
            file = open('train.txt', 'r')
            lines = file.readlines()
            for line in lines:
                res = line.strip()
                image = res.split(' ')[0]
                label = int(res.split(' ')[1])
                self.images.append(image)
                self.labels.append(label)
            # print(self.images)  # debug
            # print(self.labels)  # debug
        # (the test split is handled by the separate inference script below)
    def __len__(self):
        return len(self.images)
    def __getitem__(self, index):
        image_name = self.images[index] + '.jpg'
        image_path = data_root + image_name
        img = Image.open(image_path).convert('RGB')
        # print(img)
        img = self.transform(img)
        label = self.labels[index]
        label = torch.from_numpy(np.array(label))
        return img, label

# default transform (the training script below defines its own, with augmentation)
transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
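
A quick sanity check of the dataset class (assuming train.txt and the images folder from above are in place):

# each sample should be a 3x224x224 tensor plus an integer class label
ds = leaf_Dataset(is_train=True, transform=transforms)
img, label = ds[0]
print(len(ds), img.shape, label)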

Training

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 10:25 AM


"""
Fine-tune the ImageNet-pretrained resnet101 on the leaf dataset.
"""
import torch
import torchvision.transforms as transforms
from dataset import leaf_Dataset
import torch.utils.data as data
import torch.optim as optim
import torch.nn as nn
from resnet import resnet101
# Train the network with the Adam optimizer; no parameters are frozen

# hyperparameters

epochs = 200
lr = 1e-3
b1 = 0.9
b2 = 0.999
device = torch.device('cuda:0')
train_loss = []
# initialize the network
net = resnet101(num_classes=99)
net.to(device)

# load data
transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])
train_data = leaf_Dataset(is_train=True, transform=transforms)
dataloader = torch.utils.data.DataLoader(train_data,
                                         batch_size=64,
                                         shuffle=True)

loss_func = nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(),lr=lr,betas=(b1,b2))


for epoch in range(1, epochs + 1):
    for i, (x,y) in enumerate(dataloader):
        x = x.to(device)
        y = y.to(device)
        pred = net(x)
        loss = loss_func(pred,y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_loss.append(loss.item())
        print("epoch: %d   batch_idx:%d   loss:%.3f" %(epoch,i,loss.item()))
    # save a checkpoint after every epoch (the model/ directory must already exist)
    torch.save(net.state_dict(), 'model/epoch:%d' % epoch + '.pth')
from utils import plot_curve
plot_curve(train_loss)
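
utils.plot_curve is not shown in the post; a minimal matplotlib version (an assumed implementation, not the author's) could look like this:

# utils.py -- assumed minimal implementation
import matplotlib.pyplot as plt

def plot_curve(values):
    # plot the per-batch training loss against the batch index
    plt.plot(range(len(values)), values)
    plt.xlabel('batch')
    plt.ylabel('loss')
    plt.show()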

Loss

[Figure: training loss curve]

Writing the predictions to the submission file

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/12/8 5:42 PM
import torch
import torchvision.transforms as transforms
import numpy as np
import os
from PIL import Image
from resnet import resnet101
import torch.nn.functional as F

image_path = r'leaf-classification/images'
f = open('test.txt','r')
tmp = f.readlines()
test_file = []
for i in tmp:
    i = i.strip()  # drop the trailing newline
    test_file.append(i + '.jpg')
print(test_file)

device = torch.device('cuda:0')
net = resnet101(num_classes=99)
print('load weight........')
net.load_state_dict(torch.load('model/epoch:200.pth'))
net.to(device)
net.eval()
transformss = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
res = []
with torch.no_grad():
    for image in test_file:
        img = Image.open(os.path.join(image_path,image)).convert('RGB')
        img = transformss(img)
        img = torch.unsqueeze(img,dim=0)
        img = img.to(device)
        # print(img.shape)
        pred = net(img)
        pred = F.softmax(pred, dim=1).flatten()  # probabilities over the 99 classes
        pred = pred.cpu().numpy()
        print(pred)
        res.append(pred)

# write all probability rows once after inference finishes
np.savetxt("result.csv", res, delimiter=',')
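
Note that result.csv only holds the 99 probability columns. To get a file Kaggle will accept, the test ids and the class-name header still have to be attached; a sketch (assuming the same classes list as in the preprocessing script is available here):

import pandas as pd

# attach test ids and class-name headers to the raw probability rows
ids = [int(line.strip()) for line in open('test.txt')]
sub = pd.DataFrame(np.array(res), columns=classes)  # classes: the list from the preprocessing script
sub.insert(0, 'id', ids)
sample = pd.read_csv('leaf-classification/sample_submission.csv')
sub = sub[sample.columns]                           # match the template's column order
sub.to_csv('submission.csv', index=False)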


Copyright notice: this is an original article by lwf1881, released under the CC 4.0 BY-SA license. Please include a link to the original article and this notice when reposting.