class Reader(Dataset):
def __init__(self, data_path: str, is_val: bool = False):
super().__init__()
self.data_path = data_path
with open(os.path.join(self.data_path, "label_dict.txt"), "r", encoding="utf-8") as f:
self.info = ast.literal_eval(f.read())
self.img_paths = [os.path.join(self.data_path, img_name) for img_name in self.info]
self.img_paths = self.img_paths[-250:] if is_val else self.img_paths[:-250]
def __getitem__(self, index):
file_path = self.img_paths[index]
file_name = os.path.basename(file_path)
with open(file_path, 'rb') as f:
img = Image.open(f).convert('RGB')
img = np.array(img, dtype="float32") / 255
img = img.reshape((IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W))
label = [CHAR_TO_IDX[char] for char in self.info[file_name]]
label = np.array(label, dtype="int64") # 确保 label 是 int64 类型
label_length = len(label)
input_length = np.array([IMAGE_SHAPE_W], dtype="int64") # 确保 input_length 是 int64 类型
return img, label, label_length, input_length
def __len__(self):
return len(self.img_paths)
...
class CTCLoss(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, ipt, label, label_lengths, input_lengths):
# 转换 ipt 的维度顺序,并确保是 float32 类型
ipt = paddle.transpose(ipt, perm=[1, 0, 2])
ipt = paddle.cast(ipt, 'float32')
# 确保 label, label_lengths, 和 input_lengths 是 int64 类型
label = paddle.to_tensor(label, dtype='int64')
label_lengths = paddle.to_tensor(label_lengths, dtype='int64')
input_lengths = paddle.to_tensor(input_lengths, dtype='int64')
# 计算损失,确保 blank 索引正确
loss = paddle.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=0)
return loss
来回报错:
ValueError: (InvalidArgument) The type of data we are trying to retrieve (int32) does not match the type of data (int64) currently contained in the container.
[Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():9 != phi::CppTypeToDataType<T>::Type():7.] (at ..\paddle\phi\core\dense_tensor.cc:171)
修改完要输入的类型为int32后运行又报错
ValueError: (InvalidArgument) The type of data we are trying to retrieve (int64) does not match the type of data (int32) currently contained in the container.
[Hint: Expected dtype() == phi::CppTypeToDataType<T>::Type(), but received dtype():7 != phi::CppTypeToDataType<T>::Type():9.] (at ..\paddle\phi\core\dense_tensor.cc:185)
我又改回相关参数为int64后又报错
形成了一个闭环,求大佬解决,这是一个通过飞桨训练验证码识别模型的代码