import tensorflow as tf
from tensorflow.keras import layers, datasets, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
#读取数据
base_dir = '../data/cats_and_dogs_filtered'
#指定数据集路径
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
#指定训练集目录
train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')
#指定验证集目录
validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
#构建神经网络
model=tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16,(3,3),activation="relu",input_shape=(150,150,3)),#卷积层设置
tf.keras.layers.MaxPool2D(2,2),#最大池化层
tf.keras.layers.Conv2D(32,(3,3),activation="relu"),#卷积层设置
tf.keras.layers.MaxPool2D(2,2),#最大池化层
tf.keras.layers.Conv2D(64,(3,3),activation="relu"),#卷积层设置
tf.keras.layers.MaxPool2D(2,2),#最大池化层
tf.keras.layers.Flatten(),#将三维数组拉平为一维数组输入全连接层
tf.keras.layers.Dense(128,activation="relu"),#全连接层
tf.keras.layers.Dense(1,activation="sigmoid"),#输出层
])
#训练配置器
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='binary_crossentropy',
metrics = ['acc']
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#标准化到[0,1]
train_datagen = ImageDataGenerator( rescale = 1.0/255. )
test_datagen = ImageDataGenerator( rescale = 1.0/255. )
#批量生成20个大小为大小为 150x150 的图像及其标签用于训练
train_generator = train_datagen.flow_from_directory(train_dir,
batch_size=10,
class_mode='binary',
target_size=(150, 150))
#批量生成20个大小为大小为 150x150 的图像及其标签用于验证
validation_generator = test_datagen.flow_from_directory(validation_dir,
batch_size=10,
class_mode = 'binary',
target_size = (150, 150))
#训练配置器
model.compile(
optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
loss='binary_crossentropy',
metrics = ['acc']
)
history = model.fit(train_generator,
validation_data=validation_generator,
steps_per_epoch=200,
epochs=15,
validation_steps=200,
verbose=2)
希望每一轮都可以正常训练
可能是数据加载器出错了导致训练无效吧。