pytorch 人脸识别

import torch
import os
import numpy as np
import torch.nn as nn
import  matplotlib.pyplot as plt
import time
import torchvision
from torchvision import  transforms,models,datasets
import torch.optim as optim

#训练集在train文件夹下,每种类别的人脸都位于同一个子目录下。验证集数据类似
data_dir="F:/muct人脸数据库_项目"
train_dir=data_dir+"/train"
valid_dir=data_dir+"/valid"


data_transform=transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ]
    )

#加载数据
image_datasets={
    x:torchvision.datasets.ImageFolder(os.path.join(data_dir,x),
    transform=data_transform)for x in ["train","valid"]
}

dataLoaders={x:torch.utils.data.DataLoader(image_datasets[x],
             batch_size=4,shuffle=True)for x in ["train","valid"]
}

#使用GPU
device=torch.device("cuda")

#加载模型
model=torchvision.models.resnet152(True)
#冻住模型中的参数
for param in model.parameters():
    param.requires_grad=False

#修改最后的全连接层,使其适应咱们的项目-276分类
#in_features是全连接层中的输入的维数
num_fts=model.fc.in_features
model.fc=nn.Linear(num_fts,276)


#将模型加载到GPU
model=model.to(device)

#设置优化器
optimizer=optim.Adam(model.fc.parameters(),lr=1e-2)

#损失函数
criterion=nn.CrossEntropyLoss()

for epoch in range(5):
    print("Epoch:", epoch)
    print("---" * 5)
    for phase in ["train","valid"]:
        rightnumber = 0
        rightacc = 0
        if phase =="train":
            model.train()
        else:
            model.eval()

         # 把数据都取个遍
        for inputs, labels in dataLoaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 清零
            optimizer.zero_grad()
            # 只有训练的时候计算和更新梯度
            with torch.set_grad_enabled(phase == 'train' or phase == 'valid'):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # 计算损失
                rightnumber+= torch.sum(preds == labels.data)

            rightacc = rightnumber.double() / len(dataLoaders[phase].dataset)
        print(rightacc.item())

版权声明:本文为qq_43071209原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。