Supervised Fine-tuning Trainer (SFT) 入門

自然言語処理

Supervised Fine-tuning Trainer (SFT) 入門

Supervised Fine-tuning (SFT) は、Reinforcement Learning from Human Feedback (RLHF) における重要なステップです。TRLでは、簡単に使えるAPIを提供しており、数行のコードであなたのデータセットでSFTモデルを作成し、学習することができます。

クイックスタート

🤗 Hubでホストされているデータセットがある場合、TRLのSFTTrainerを使って簡単にSFTモデルをファインチューニングできます。例えば、データセットがimdbで、予測したいテキストがデータセットのtextフィールドにあり、facebook/opt-350mモデルをファインチューニングしたいとします。

from datasets import load_dataset
from trl import SFTTrainer

# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")

# SFTTrainerを初期化
trainer = SFTTrainer(
    "facebook/opt-350m",  # ファインチューニングするモデル
    train_dataset=dataset,  # 学習データセット
    dataset_text_field="text",  # テキストフィールド名
    max_seq_length=512,  # 最大シーケンス長
)

# 学習を開始
trainer.train()

高度な使用法

生成されたプロンプトのみでの学習

DataCollatorForCompletionOnlyLMを使用して、生成されたプロンプトのみでモデルを学習できます。これはpacking=Falseの場合にのみ機能することに注意してください。

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# CodeAlpacaデータセットをロード
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

# モデルとトークナイザーをロード
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# プロンプトをフォーマットする関数を定義
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

# レスポンステンプレートを定義
response_template = " ### Answer:"

# DataCollatorForCompletionOnlyLMを初期化
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

# SFTTrainerを初期化
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

# 学習を開始
trainer.train()

チャット形式用の特殊トークンの追加

言語モデルに特殊トークンを追加することは、チャットモデルの学習において重要です。これらのトークンは、ユーザー、アシスタント、システムなどの会話の異なる役割の間に追加され、モデルが会話の構造と流れを認識するのに役立ちます。

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

# モデルとトークナイザーをロード
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") 

# チャット形式用の特殊トークンを追加
model, tokenizer = setup_chat_format(model, tokenizer)

データセットのパッキング

SFTTrainerは、複数の短い例を同じ入力シーケンスにパックすることで、学習効率を高めるサンプルパッキングをサポートしています。

from datasets import load_dataset
from trl import SFTTrainer

# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")

# SFTTrainerを初期化(パッキングを有効化)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    packing=True
)

# 学習を開始
trainer.train()

プロンプトのカスタマイズ

データセットに複数のフィールドがあり、それらを組み合わせたい場合は、それを処理するフォーマット関数をトレーナーに渡すことができます。

def formatting_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    packing=True,
    formatting_func=formatting_func
)

trainer.train()

事前学習済みモデルの制御

from_pretrained()メソッドのkwargsを直接SFTTrainerに渡すことができます。

trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    model_init_kwargs={
        "torch_dtype": torch.bfloat16,
    },
)

trainer.train()

アダプターの学習

🤗 PEFTライブラリとの緊密な統合もサポートしているため、ユーザーは簡単にアダプターを学習し、モデル全体を学習する代わりにHubで共有できます。

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")

# LoraConfigを定義
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# SFTTrainerを初期化(peft_configを指定)
trainer = SFTTrainer(
    "EleutherAI/gpt-neo-125m",  
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=peft_config
)

# 学習を開始
trainer.train()

NEFTuneを使用したモデルのパフォーマンス向上

NEFTuneは、学習中にノイズを埋め込みベクトルに追加することで、チャットモデルのパフォーマンスを向上させるテクニックです。

from datasets import load_dataset
from trl import SFTTrainer

# imdbデータセットをロード
dataset = load_dataset("imdb", split="train")

# SFTTrainerを初期化(neftune_noise_alphaを指定)
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset, 
    dataset_text_field="text",
    max_seq_length=512,
    neftune_noise_alpha=5,
)

# 学習を開始
trainer.train()

以上が、コードブロックを多用し、実行可能なセルとコメントを付与した、Supervised Fine-tuning Trainer (SFT)の初心者向けの解説です。SFTを使えば、簡単かつ効率的に言語モデルをファインチューニングできます。ぜひ活用してみてください!

ノートブック

Google Colaboratory

参考サイト

Supervised Fine-tuning Trainer
We’re on a journey to advance and democratize artificial intelligence through open source and open science.

コメント

タイトルとURLをコピーしました