训练一个识别苹果和香蕉的深度学习网络,需要多大规模的样本?

我在百度图片下载了一些香蕉和苹果的图片

图片.png

香蕉图片 195 张
图片.png

苹果图片 263 张
图片.png

一共 458 张图片

下面是我的训练代码

import torch
import torchvision
import torchvision.transforms as transforms

# 加载数据集并进行数据增强
transform_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ]
)


image_dir = 'resources/images/train'
model_file = 'models/fruit_classifier.pt'

trainset = torchvision.datasets.ImageFolder(
    root=image_dir,
    transform=transform_train
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=20,
    shuffle=True,
    num_workers=0
)

# 加载预训练模型并修改最后一层
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")
model = torchvision.models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)
model = model.to(device)

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        
        
        print(outputs)
        
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 10 == 9:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

# 保存模型
torch.save(model.state_dict(), model_file)
print('Finished Training')

下面是我的推理代码

import os
import shutil
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image


image_dir = 'resources/images/inference/mixed'
model_file = 'models/fruit_classifier.pt'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device('mps')
model = models.resnet50(pretrained=True)
model.fc = torch.nn.Linear(2048, 2)
model = model.to(device)
model.load_state_dict(torch.load(model_file))

data_transforms = transforms.Compose([
    # transforms.ToPILImage(mode='RGB'),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


test_images = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]

for test_image in test_images:
    image = Image.open(test_image).convert('RGB')
    image_tensor = data_transforms(image).float()
    image_tensor = image_tensor.unsqueeze_(0)
    # input = Variable(image_tensor)
    input = image_tensor
    input = input.to(device)
    output = model(input)
    
    print(output)
    
    _, preds = torch.max(output, 1)
    if preds == 0:
        shutil.copy(test_image, 'resources/images/inference/classified/bananas')
    else:
        shutil.copy(test_image, 'resources/images/inference/classified/apples')

推理结果非常的糟糕

所有的图片,都被分类为香蕉,一个苹果都没有
图片.png

为什么?因为训练的样本图片太少了?

阅读 1.9k
1 个回答

把你的图像先输入预训练 VGG16 拿到特征,然后再用特征训练三层的 MLP,我估计几百张就可以了。

撰写回答
你尚未登录,登录后可以
  • 和开发者交流问题的细节
  • 关注并接收问题和回答的更新提醒
  • 参与内容的编辑和改进,让解决方法与时俱进
推荐问题