上篇文章我们介绍了Llama 2的量化和部署,本篇文章将介绍使用PEFT库和QLoRa方法对Llama 27b预训练模型进行微调。我们将使用自定义数据集来构建情感分析模型。只有可以对数据进行微调我们才可以将这种大模型进行符合我们数据集的定制化。
一些前置的知识
如果熟悉Google Colab、Weights & Biases (W&B)、HF库,可以跳过这一节。
虽然Google Colab(托管的Jupyter笔记本环境)不是真正的先决条件,但我们建议使用它来访问GPU并进行快速实验。如果是付费的用户,则可以使用高级GPU访问,比如A100这样的GPU。
W&B帐户的作用是记录进度和训练指标,这个如果不需要也可以用tensorboard替代,但是我们是演示Google Colab环境所以直接用它。
然后就是需要一个HF帐户。然后转到settings,创建至少具有读权限的API令牌。因为在训练脚本时将使用它下载预训练的Llama 2模型和数据集。
最后就是请求访问Llama 2模型。等待Meta AI和HF的邮件。这可能要1-2天。
准备数据集
指令微调是一种常用技术,用于为特定的下游用例微调基本LLM。训练示例如下:
Below is an instruction that describes a sentiment analysis task...
### Instruction:
Analyze the following comment and classify the tone as...
### Input:
I love reading your articles...
### Response:
friendly & constructive
我们建议使用json,因为这样比较灵活。比如为每个示例创建一个JSON对象,其中只有一个文本字段。像这样:
{ "text": "Below is an instruction ... ### Instruction: Analyze the... ### Input: I love... ### Response: friendly" },
{ "text": "Below is an instruction ... ### Instruction: ..." }
有很多很多方法可以提取原始数据、处理和创建训练数据集作为json文件。下面是一个简单的脚本:
with open('train.jsonl', 'a') as outfile:
for example in raw_data:
text = '<process_example>'
# now append entry to the jsonl file.
outfile.write('{"text": "' + text + '"}')
outfile.write('\n')
如HF的Datasets库也是一个选择,但是我个人觉得他不好用。
在我们开始训练之前,我们要将文件作为数据集存储库推送到HF。可以直接使用huggingface-cli上传数据集。
训练
Parameter-Efficient Fine-Tuning(PEFT)可以用于在不触及LLM的所有参数的情况下对LLM进行有效的微调。PEFT支持QLoRa方法,通过4位量化对LLM参数的一小部分进行微调。
Transformer Reinforcement Learning (TRL)是一个使用强化学习来训练语言模型的库。TRL也提供的监督微调(SFT)训练器API可以让我们快速的微调模型。
!pip install -q huggingface_hub
!pip install -q -U trl transformers accelerate peft
!pip install -q -U datasets bitsandbytes einops wandb
# Uncomment to install new features that support latest models like Llama 2
# !pip install git+https://github.com/huggingface/peft.git
# !pip install git+https://github.com/huggingface/transformers.git
# When prompted, paste the HF access token you created earlier.
from huggingface_hub import notebook_login
notebook_login()
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
dataset_name = "<your_hf_dataset>"
dataset = load_dataset(dataset_name, split="train")
base_model_name = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
device_map = {"": 0}
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
quantization_config=bnb_config,
device_map=device_map,
trust_remote_code=True,
use_auth_token=True
)
base_model.config.use_cache = False
# More info: https://github.com/huggingface/transformers/pull/24906
base_model.config.pretraining_tp = 1
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
output_dir = "./results"
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
logging_steps=10,
max_steps=500
)
max_seq_length = 512
trainer = SFTTrainer(
model=base_model,
train_dataset=dataset,
peft_config=peft_config,
dataset_text_field="text",
max_seq_length=max_seq_length,
tokenizer=tokenizer,
args=training_args,
)
trainer.train()
import os
output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)
上面的脚本就是一个微调的简单代码,这里可以添加命令行参数解析器模块,如HfArgumentParser,这样就不必硬编码这些值
测试
下面时一个简单的加载模型并进行完整性测试的快速方法。
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map=device_map, torch_dtype=torch.bfloat16)
text = "..."
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), attention_mask=inputs["attention_mask"], max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
这样就能够查看我们的结果了。
本文作者:UD
原文地址:
https://avoid.overfit.cn/post/e2b178db4f9344c2a659925689c1f049
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。