深度学习之Bert中文分类学习
BERT实验
预训练结果分析
tfhub_handle_preprocess = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_preprocess/3"
bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)
text_test = ['我真是个天才啊!!']
text_preprocessed = bert_preprocess_model(text_test)
print(f'Keys : {list(text_preprocessed.keys())}')
print(f'Shape : {text_preprocessed["input_word_ids"].shape}')
print(f'Word Ids : {text_preprocessed["input_word_ids"][0, :12]}')
print(f'Input Mask : {text_preprocessed["input_mask"][0, :12]}')
print(f'Type Ids : {text_preprocessed["input_type_ids"][0, :12]}')
打印结果
Keys : ['input_mask', 'input_type_ids', 'input_word_ids']
Shape : (1, 128)
Word Ids : [ 101 2769 4696 3221 702 1921 2798 1557 8013 8013 102 0]
Input Mask : [1 1 1 1 1 1 1 1 1 1 1 0]
Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]
Shape中的128猜想应该是 最大长度。
Keys对应了三个属性,但其实bert应该是有7个特征属性。为什么另外四个属性在这里没有,目前不是很清楚,但我觉得Hub里面的Advance topics估计是在讲这个事情。但由于现在要做的任务是文本分类,以下的四个特征是不需要的。
- input_ids: 输入的token对应的id
- input_mask: 输入的mask,1代表是正常输入,0代表的是padding的输入
- segment_ids: 输入的0:代表句子A或者padding句子,1代表句子B
- masked_lm_positions:我们mask的token的位置
- masked_lm_ids:我们mask的token的对应id
- masked_lm_weights:我们mask的token的权重,1代表是真实mask的,0代表的是padding的mask
- next_sentence_labels:句子A和B是否是上下句
接下来看下输出怎么理解。为了保证自己的猜想正确,我又丢了一句真是个天才啊!!
进去看结果。把两个结果合并起来是这样的。
Word Ids : [ 101 2769 4696 3221 702 1921 2798 1557 8013 8013 102 0]
Word Ids : [ 101 4696 3221 702 1921 2798 1557 8013 8013 102 0 0]
Input Mask : [1 1 1 1 1 1 1 1 1 1 1 0]
Input Mask : [1 1 1 1 1 1 1 1 1 1 0 0]
Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]
Type Ids : [0 0 0 0 0 0 0 0 0 0 0 0]
有意思吧。bert目前没有分词组的概念,它是一个一个词分开的。根据资料查看bert的预训练分成两步,这个显然是第一步。
我真是个天才啊啊!
被分割成了[CLS] 我 真 是 个 天 才 啊 啊 ![SEP]
,所以这个跟Word Ids是能对得上的。
TypeIds应该跟segment_ids是对等的,也是对得上的,都是0。
Input Mask也是对得上的。它的算法如下
mask是1表示是"真正"的Token,0则是Padding出来的。
在后面的Attention时会通过tricky的技巧让模型不能attend to这些padding出来的Token上。input_mask = [1] * len(input_ids)
结果输出分析
tfhub_handle_encoder = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/4"
bert_model = hub.KerasLayer(tfhub_handle_encoder)
bert_results = bert_model(text_preprocessed)
print(f'Keys : {list(bert_results.keys())}')
print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')
- sequence_output:维度【batch_size, seq_length, hidden_size】,这是训练后每个token的词向量。
- pooled_output:维度是【batch_size, hidden_size】,每个sequence第一个位置CLS的向量输出,用于分类任务。
BERT实战中文多分类
参考官方的BERT分类,官方用的是情感分析的场景,其实跟中文多分类场景类似。
要做的任务就是换BERT的pre-process和encoder,再把最后一层输出换成多分类即可。
基本可以照搬。
环境准备
pip install -q tensorflow-text
pip install -q tf-models-official
import tensorflow as tf
import os
import shutil
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization # to create AdamW optimizer
import matplotlib.pyplot as plt
tfhub_handle_preprocess = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_preprocess/3"
tfhub_handle_encoder = "https://hub.tensorflow.google.cn/tensorflow/bert_zh_L-12_H-768_A-12/4"
bert_preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)
bert_model = hub.KerasLayer(tfhub_handle_encoder)
准备数据集
因为是在Colab做的实验,上传貌似只能上传文件的形式。所以在本地压缩成一个zip包,上传到Colab后做解压。
import zipfile
from pathlib import Path
zFile = zipfile.ZipFile("path.zip","r")
for fileM in zFile.namelist():
zFile.extract(fileM, "path")
zFile.close();
文件格式的要求跟官方示例一致即可。
文件目录
分类1
分类1的标题1
分类1的标题2
...
分类2
分类3
...
拆分数据集
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32
seed = 42
raw_train_ds = tf.keras.preprocessing.text_dataset_from_directory(
'train',
batch_size=batch_size,
validation_split=0.2,
subset='training',
seed=seed)
class_names = raw_train_ds.class_names
print(class_names)
train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = tf.keras.preprocessing.text_dataset_from_directory(
'train',
batch_size=batch_size,
validation_split=0.2,
subset='validation',
seed=seed)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
看一下数据集
for text_batch, label_batch in train_ds.take(1):
print(text_batch)
for i in range(3):
print(f'Review: {text_batch.numpy()[i]}')
label = label_batch.numpy()[i]
print(f'Label : {label} ({class_names[label]})')
定义模型
def build_classifier_model():
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
encoder_inputs = preprocessing_layer(text_input)
encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
outputs = encoder(encoder_inputs)
net = outputs['pooled_output']
net = tf.keras.layers.Dropout(0.5)(net)
net = tf.keras.layers.Dense(6, activation='softmax', name='classifier')(net)
return tf.keras.Model(text_input, net)
classifier_model = build_classifier_model()
epochs = 10
steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
num_train_steps = steps_per_epoch * epochs
print(num_train_steps)
num_warmup_steps = int(0.1*num_train_steps)
init_lr = 3e-5
optimizer = optimization.create_optimizer(init_lr=init_lr, num_train_steps=num_train_steps, num_warmup_steps=num_warmup_steps,
optimizer_type='adamw')
loss = tf.keras.losses.SparseCategoricalCrossentropy()
metrics = tf.metrics.SparseCategoricalAccuracy()
classifier_model.compile(optimizer=optimizer,
loss=loss,
metrics=metrics)
训练模型
history = classifier_model.fit(x=train_ds,
validation_data=val_ds,
epochs=epochs)
预测数据
发现虽然我只用了300条样本,train样本240条,val样本60条,但是准确率也能到90%。BERT果然叼。
import numpy as np
def print_my_examples(inputs, results):
result_for_printing = \
[f'input: {inputs[i]:<30} : class: {class_names[np.argmax(results[i])] }'
for i in range(len(inputs))]
print(*result_for_printing, sep='\n')
print()
examples = ['德国米技(MIJI)蒸汽喷淋式煮茶壶 全自动保温泡茶壶HK-K018',
'公牛插座收纳盒装办公接线板家用创意电源',
'生活元素烧水壶办公室家用多功能小型烧水壶保温一体煮茶器煮茶壶',
'信安智囊老人防摔防跌倒术后保护重复使用智能气囊马甲服神器送礼',
'七度空间卫生巾优雅系列日夜超值组合',
' 【贝拉家】贝拉爸爸自己家工厂生产的科技布狗窝']
example_result = classifier_model(tf.constant(examples))
print_my_examples(examples,example_result)
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。