Supervised Fine-tuning Trainer (SFT) 入門
Supervised Fine-tuning (SFT) は、Reinforcement Learning from Human Feedback (RLHF) における重要なステップです。TRLでは、簡単に使えるAPIを提供しており、数行のコードであなたのデータセットでSFTモデルを作成し、学習することができます。
クイックスタート
🤗 Hubでホストされているデータセットがある場合、TRLのSFTTrainerを使って簡単にSFTモデルをファインチューニングできます。例えば、データセットがimdbで、予測したいテキストがデータセットのtextフィールドにあり、facebook/opt-350mモデルをファインチューニングしたいとします。
from datasets import load_dataset
from trl import SFTTrainer
# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")
# SFTTrainerを初期化
trainer = SFTTrainer(
"facebook/opt-350m", # ファインチューニングするモデル
train_dataset=dataset, # 学習データセット
dataset_text_field="text", # テキストフィールド名
max_seq_length=512, # 最大シーケンス長
)
# 学習を開始
trainer.train()
高度な使用法
生成されたプロンプトのみでの学習
DataCollatorForCompletionOnlyLMを使用して、生成されたプロンプトのみでモデルを学習できます。これはpacking=Falseの場合にのみ機能することに注意してください。
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
# CodeAlpacaデータセットをロード
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
# モデルとトークナイザーをロード
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# プロンプトをフォーマットする関数を定義
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['instruction'])):
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
output_texts.append(text)
return output_texts
# レスポンステンプレートを定義
response_template = " ### Answer:"
# DataCollatorForCompletionOnlyLMを初期化
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
# SFTTrainerを初期化
trainer = SFTTrainer(
model,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
data_collator=collator,
)
# 学習を開始
trainer.train()
チャット形式用の特殊トークンの追加
言語モデルに特殊トークンを追加することは、チャットモデルの学習において重要です。これらのトークンは、ユーザー、アシスタント、システムなどの会話の異なる役割の間に追加され、モデルが会話の構造と流れを認識するのに役立ちます。
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# モデルとトークナイザーをロード
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# チャット形式用の特殊トークンを追加
model, tokenizer = setup_chat_format(model, tokenizer)
データセットのパッキング
SFTTrainerは、複数の短い例を同じ入力シーケンスにパックすることで、学習効率を高めるサンプルパッキングをサポートしています。
from datasets import load_dataset
from trl import SFTTrainer
# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")
# SFTTrainerを初期化(パッキングを有効化)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
packing=True
)
# 学習を開始
trainer.train()
プロンプトのカスタマイズ
データセットに複数のフィールドがあり、それらを組み合わせたい場合は、それを処理するフォーマット関数をトレーナーに渡すことができます。
def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
packing=True,
formatting_func=formatting_func
)
trainer.train()
事前学習済みモデルの制御
from_pretrained()メソッドのkwargsを直接SFTTrainerに渡すことができます。
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
model_init_kwargs={
"torch_dtype": torch.bfloat16,
},
)
trainer.train()
アダプターの学習
🤗 PEFTライブラリとの緊密な統合もサポートしているため、ユーザーは簡単にアダプターを学習し、モデル全体を学習する代わりにHubで共有できます。
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")
# LoraConfigを定義
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
# SFTTrainerを初期化(peft_configを指定)
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
dataset_text_field="text",
peft_config=peft_config
)
# 学習を開始
trainer.train()
NEFTuneを使用したモデルのパフォーマンス向上
NEFTuneは、学習中にノイズを埋め込みベクトルに追加することで、チャットモデルのパフォーマンスを向上させるテクニックです。
from datasets import load_dataset
from trl import SFTTrainer
# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")
# SFTTrainerを初期化(neftune_noise_alphaを指定)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
neftune_noise_alpha=5,
)
# 学習を開始
trainer.train()
以上が、コードブロックを多用し、実行可能なセルとコメントを付与した、Supervised Fine-tuning Trainer (SFT)の初心者向けの解説です。SFTを使えば、簡単かつ効率的に言語モデルをファインチューニングできます。ぜひ活用してみてください!
コメント