tensorflow.Keras 使用Resnet50 实现猫狗识别

最近几天做点小东西,因为懒所以不想用 TensorFlow 或者 Slim 再在底层写 layer,就直接使用了 TensorFlow 里面自带的模型,处理一下数据就直接用了。后面想想还是比较有意思的,就把这个东西分享一下。

 

首先发个效果

                   

 

接着直接上代码:

首先是工具文件:

import os,sys
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions


def GetLabel(path):
    """Parse a label file where each line is '<filename> <label>'.

    Args:
        path: Path to the text file to read.

    Returns:
        Tuple (filenames, labels): a list of filename strings and a list
        of integer labels, in file order.
    """
    filenames = []
    labels = []
    # 'with' guarantees the file handle is closed (the original leaked it).
    with open(path) as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            # maxsplit=1 keeps filenames containing spaces intact.
            filename, label = line.split(' ', 1)
            filenames.append(filename)
            labels.append(int(label))
    return filenames, labels


def getDecodes(labels):
    """One-hot encode a list of integer labels.

    The encoding width is max(labels) + 1, so each label indexes its own
    column in the returned (len(labels), width) float array.
    """
    width = max(labels) + 1
    encoded = np.zeros((len(labels), width))
    for row, value in enumerate(labels):
        encoded[row, value] = 1
    return encoded


def getImgRect(dirname, filenames):
    """Load every file in *filenames* (relative to *dirname*) as one batch.

    Each image is encoded via ImageEncode (resized to 224x224, 3 channels)
    and stacked into a single (N, 224, 224, 3) array.
    """
    batch = np.zeros((len(filenames), 224, 224, 3))
    for idx, name in enumerate(filenames):
        encoded = ImageEncode(os.path.join(dirname, name))
        batch[idx] = encoded[0]
    return batch


def transClasses(filename):
    """Parse a class-mapping file of '<label> <class-name>' lines.

    Args:
        filename: Path to the mapping file.

    Returns:
        dict mapping int label -> class-name string (stripped of whitespace).
    """
    kinds = {}
    # 'with' closes the file even on error (the original leaked the handle).
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            # maxsplit=1 keeps multi-word class names intact.
            label, kind = line.split(' ', 1)
            kinds[int(label)] = kind.strip()
    return kinds

def ImageEncode(img_path):
    """Load one image file and encode it for ResNet50 input.

    Resizes to 224x224, converts to an array, prepends a batch axis, and
    applies the ResNet50-specific preprocessing.

    Returns:
        Array of shape (1, 224, 224, 3), ready for model.predict.
    """
    loaded = image.load_img(img_path, target_size=(224, 224))
    batch = np.expand_dims(image.img_to_array(loaded), axis=0)
    return preprocess_input(batch)

if __name__ == '__main__':
    # Smoke test: encode one sample image and print the resulting shape,
    # which should be (1, 224, 224, 3).
    x = ImageEncode('D:\\guochuang\\Images\\kaggle猫狗大战数据集\\train\\train\\cat.0.jpg')
    print(x.shape)

接下来是我们通过猫狗大战的数据集生成训练数据的代码:

import os
import utils
import numpy as np
import cv2

path = 'D:\\guochuang\\Images\\kaggle猫狗大战数据集\\train\\train'


def getpictures(mount):
    """Load the first *mount* cat and *mount* dog images as a training batch.

    Reads 'cat.0.jpg'..'cat.{mount-1}.jpg' followed by
    'dog.0.jpg'..'dog.{mount-1}.jpg' from the module-level `path` directory.

    Args:
        mount: Number of images to load per class.

    Returns:
        (images, labels): images has shape (2*mount, 224, 224, 3); labels
        has shape (2*mount, 2) with one-hot rows, column 0 = cat,
        column 1 = dog.
    """
    images = np.zeros((mount * 2, 224, 224, 3))
    labels = np.zeros((2 * mount, 2))
    for i in range(mount):
        full_path = os.path.join(path, 'cat.' + str(i) + '.jpg')
        images[i] = utils.ImageEncode(full_path)[0]
        labels[i, 0] = 1  # one-hot: cat
    for i in range(mount, 2 * mount):
        # BUG FIX: the original built 'dog.{i}' for i in [mount, 2*mount),
        # skipping the first `mount` dog images (and breaking for large
        # mount); use i - mount so the dog files mirror the cat loop.
        full_path = os.path.join(path, 'dog.' + str(i - mount) + '.jpg')
        images[i] = utils.ImageEncode(full_path)[0]
        labels[i, 1] = 1  # one-hot: dog
    return images, labels


if __name__ == '__main__':
    # Visual sanity check: load 2000 images per class, display sample #3
    # (cast back to uint8 for cv2.imshow) and print its one-hot label.
    image, label = getpictures(2000)
    cv2.imshow('mat', np.array(image[3, :, :, :], dtype=np.uint8))
    cv2.waitKey(0)
    print(label[3, :])

接下来是重头戏,可能也是大家想要的,训练和识别的代码:

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
import cv2
from tensorflow.keras.models import load_model
import utils
import DogsCat as DC

font = cv2.FONT_HERSHEY_SIMPLEX

# Build a ResNet50 with randomly initialized weights (weights=None) and a
# 2-way softmax head (cat vs dog). As written it is only used by the
# (commented-out) training phase; inference below rebinds `model` from a
# saved file instead.
model = ResNet50(
    weights=None,
    classes=2
)

if __name__ == '__main__':

    # --- Training phase (run once, then leave commented out) ----------
    # Uncomment to train the module-level ResNet50 from scratch on the
    # cat/dog dataset and save the weights to disk.
    # Images, labels = DC.getpictures(3000)
    #
    # model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
    # model.fit(
    #     x=Images,
    #     y=labels,
    #     epochs=10,
    #     batch_size=5
    # )
    #
    #
    # model.save('my_model.h5')

    # --- Inference phase ----------------------------------------------
    imgpath = "cat.jpg"

    # Load previously trained weights (rebinds the untrained module-level model).
    model = load_model('dogcat.h5')

    # Encode the image the same way as during training, predict, and take
    # the argmax over the 2-way softmax: index 0 = cat, index 1 = dog
    # (matches the one-hot layout produced by DC.getpictures).
    code = utils.ImageEncode(imgpath)
    ret = model.predict(code)
    res1 = np.argmax(ret[0, :])

    # Overlay the predicted class name on the original image and show it.
    img = cv2.imread(imgpath)
    if res1:
        cv2.putText(img, 'dog', (50, 100), font, 2, (255, 255, 255), 7)
        cv2.imshow('mat', img)
    else:
        cv2.putText(img, 'cat', (50, 100), font, 2, (255, 255, 255), 7)
        cv2.imshow('mat', img)

    cv2.waitKey(0)

难度基本是小学级别的。

 

这个是网上某大哥提供的数据集:

https://download.csdn.net/download/qq_38210185/10227930?utm_source=bbsseo

然后代码和训练模型已经上传到了github上了,下附地址

代码:https://github.com/1695973632/DogAndcat

模型:链接:https://pan.baidu.com/s/1QF4Vu1rSRscFEH5GpWGbPg  密码:x3xm


版权声明:本文为qq_31622541原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。