JAXとWeights & Biasesを用いたGemma Instruct 2BモデルのFinetuning入門

AI・機械学習

はじめに

このノートブックでは、JAXをバックエンドに使用して、Kaggleの"AI Mathematical Olympiad"コンペティションに向けてGemma Instruct 2Bモデルをfinetuningする方法について解説します。また、Weights & Biases (W&B) を用いて実験結果のトラッキングや可視化を行います。

データの準備

データの準備は機械学習プロジェクトにおいて非常に重要なステップです。この節では、Kaggleの"AI Mathematical Olympiad"コンペティションのtrainデータを取得し、モデルの訓練に適した形式に整形する手順を詳しく説明します。

必要なライブラリのインストールとインポート

まずは、以下のコマンドを実行して必要なライブラリをインストールします。

!pip install -U kagglehub kaggle
!pip install -U keras_nlp wandb

次に、インポートします。

import wandb
from wandb.keras import WandbMetricsLogger

from google.colab import userdata
wandb_api_key = userdata.get('WANDB_API_KEY')
!wandb login $wandb_api_key

from google.colab import userdata
import os

os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')

import numpy as np
import pandas as pd

ここでは、Weights & Biases (wandb)、Keras NLP、Kaggle APIに関連するライブラリをインストールし、インポートしています。また、Weights & BiasesとKaggle APIの認証情報を環境変数から取得しています。

Kaggleからデータをダウンロード

次に、Kaggle APIを使用して"AI Mathematical Olympiad"コンペティションのデータをダウンロードし、解凍します。

!kaggle competitions list | grep mathematical
!kaggle competitions download -c ai-mathematical-olympiad-prize --force
!unzip ai-mathematical-olympiad-prize -d ai-mathematical-olympiad-prize

これにより、ai-mathematical-olympiad-prizeディレクトリ内にtrainデータが展開されます。

trainデータの読み込み

pandas を使用して、trainデータを読み込みます。

df1 = pd.read_csv("/content/ai-mathematical-olympiad-prize/train.csv")

これにより、trainデータがdf1というDataFrameに格納されます。

データの整形

trainデータから問題文と解答を抽出し、promptのテンプレートに合わせてデータを整形します。

Math_data = []

for index, row in df1.iterrows():
    question, answer = row['problem'], row['answer']
    template = (f"""
    Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.

Problem: ${question}

Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.

Solution: ${answer}

    """)
    Math_data.append(template)

ここでは、DataFrameの各行をイテレートし、問題文(question)と解答(answer)を取得しています。そして、これらの情報をプロンプトのテンプレートに埋め込み、Math_dataリストに追加しています。このテンプレートは、モデルが問題を解決し、解答の手順を明確かつ簡潔に説明するための指示を含んでいます。

また、predict_numeric_answers関数を定義しています。この関数は、与えられたDataFrameとモデルを使用して、数値解答を予測し、予測結果と実際の解答が一致するかどうかを示す新しい列を追加します。

def predict_numeric_answers(df, model, column_name):
    predictions = []

    for index, row in df.iterrows():
        question = row['problem']
        prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"
        answer = model.generate(prompt, max_length=256)
        numeric_answer = ''.join(filter(str.isdigit, answer.split(':')[-1]))
        try:
            predictions.append(int(numeric_answer))
        except ValueError:
            predictions.append(-1)
        print("--------------------------")
        print(answer)

    df[column_name] = predictions
    df[f'{column_name}_match'] = df.answer == df[column_name]
    return df

この関数は、モデルのファインチューニング前後の性能を評価するために使用されます。

以上が、データの準備に関する詳細な説明です。これらの手順により、trainデータを取得し、モデルの訓練に適した形式に整形することができます。

モデルの準備

次に、Keras NLPを使ってGemma Instruct 2Bモデルをロードします。ここではJAXをバックエンドとして使用します。

import os

os.environ["KERAS_BACKEND"] = "jax"  # JAXをバックエンドとして指定
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"  # JAXバックエンドでのメモリ断片化を回避

import keras
import keras_nlp

FINETUNED_MODEL_DIR = f"./gemma_demo"

MODEL_BASE = "aimo_gemma"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 9
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}_v{VER}'

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()

モデルを準備する際、以下の設定を行います。

  • LoRA (Low-Rank Adaptation) を有効化し、メモリ使用量を削減
  • 入力シーケンス長を512に制限して、メモリ使用量をコントロール
  • 最適化アルゴリズムにAdamWを使用し、Layer NormalizationとBiasの重みを減衰の対象から除外
gemma_lm.backbone.enable_lora(rank=64)

gemma_lm.preprocessor.sequence_length = 512  # 入力シーケンス長を512に制限
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])  # LayerNormとBiasを減衰対象から除外

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

ファインチューニング前の予測

まずはファインチューニング前のモデルでtrainデータの予測を行い、正解率を確認します。

df1 = predict_numeric_answers(df1, gemma_lm, 'pre_finetune_prediction')

df1
print(f"Pre-finetuning accuracy: {df1.pre_finetune_prediction_match.mean():.2%}")

モデルの訓練

データとモデルの準備が完了したら、Gemma Instruct 2Bモデルをtrainデータでfinetuningします。ここでは、Weights & Biasesを用いて実験結果のトラッキングと可視化を行います。

wandb.init(config={"bs": 12})  # Weights & Biasesの新しいRunを初期化

gemma_lm.fit(Math_data, epochs=60, batch_size=1, callbacks=[WandbMetricsLogger(log_freq="batch")])  # モデルの訓練とW&Bへのログ記録

!mkdir $FINETUNED_MODEL_DIR
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)  # finetuningされた重みを保存
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)  # トークナイザーのアセットを保存

Weights & Biasesを使用することで、以下のようなメリットが得られます。

  • 実験結果のリアルタイムなトラッキングと可視化
  • ハイパーパラメータ、メトリクス、モデルの重みなどの自動ロギング
  • チームメンバー間でのコラボレーションとナレッジ共有の促進

ファインチューニング後の予測

finetuning済みのモデルを使って、trainデータの問題に対する解答を生成し、正解率を算出します。

df1 = predict_numeric_answers(df1, gemma_lm, 'post_finetune_prediction')

print(f"Post-finetuning accuracy: {df1.post_finetune_prediction_match.mean():.2%}")
df1

モデルのアップロード

訓練したモデルをKaggle Hubにアップロードします。

import kagglehub

kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')

まとめ

このノートブックでは、JAXをバックエンドに使用し、Weights & Biasesによる実験管理を行いながら、以下のステップでAIMOコンペに取り組みました。

  1. ファインチューニング前のGemma Instruct 2Bモデルでtrainデータの正解率を算出
  2. Gemma Instruct 2Bモデルをtrainデータでfinetuning
  3. finetuning済みモデルでtrainデータの正解率を算出

ファインチューニングにより、モデルの予測精度が大幅に向上したことが確認できました。
JAXをバックエンドとして活用することで、高速かつ効率的なモデル訓練が可能になります。
Weights & Biasesを用いることで、実験結果のトラッキングや可視化、チームコラボレーションが容易になります。
今後は更なるハイパーパラメータの調整やアンサンブル手法の導入などにより、予測精度の向上が期待できるでしょう。

ノートブック

Google Colaboratory

参考サイト

KerasとJAXを使ってGemmaモデルをTPU分散学習する方法
Explore and run machine learning code with Kaggle Notebooks | Using data from Gemma
kagglehub を使った大規模言語モデル gemma のファインチューニングとモデル共有
Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

コメント

タイトルとURLをコピーしました