JAXとWandbとSelf-Consistencyを使ったGemma Instruct 2Bモデルのファインチューニング入門

LLM

このノートブックでは、Kaggleの"AI Mathematical Olympiad"コンペティションに向けて、JAXをバックエンドに使用してGemma Instruct 2Bモデルをファインチューニングする方法を解説します。また、Weights & Biases (Wandb) を用いてモデルの訓練過程を可視化する方法も紹介します。

準備

必要なライブラリのインストール

まずは、以下のコマンドで必要なライブラリをインストールしましょう。

!pip install -U kagglehub kaggle
!pip install -U keras_nlp wandb

Wandbの設定

Wandbを使用するために、APIキーを設定します。

import wandb
from wandb.keras import WandbMetricsLogger

from google.colab import userdata
wandb_api_key = userdata.get('WANDB_API_KEY')
!wandb login $wandb_api_key

%env WANDB_PROJECT=aimo-gemma-instruct-2b-finetuning-SC

Kaggleの設定

Kaggleからデータをダウンロードするために、APIキーとユーザー名を設定します。

from google.colab import userdata
import os

os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')

データのダウンロードと解凍

Kaggleからコンペティションのデータをダウンロードし、解凍します。

!kaggle competitions download -c ai-mathematical-olympiad-prize --force
!unzip ai-mathematical-olympiad-prize -d ai-mathematical-olympiad-prize

データの準備

データの読み込み

trainデータを読み込みます。

import numpy as np
import pandas as pd
from collections import Counter

df1 = pd.read_csv("/content/ai-mathematical-olympiad-prize/train.csv")

データの整形

trainデータから問題文と解答を抽出し、promptのテンプレートに合わせてデータを整形します。

Math_data = []

for index, row in df1.iterrows():
    question, answer = row['problem'], row['answer']
    template = (f"""
    Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.

Problem: ${question}

Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.

Solution: ${answer}

    """)
    Math_data.append(template)

アンサンブル予測の関数定義

Self-Consistency (SC) を用いたアンサンブル予測を行う関数を定義します。

def predict_numeric_answers(df, model, column_name, n_repetitions=3):
    total_answers = []

    for index, row in df.iterrows():
        question = row['problem']
        prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"

        answers = []
        for _ in range(n_repetitions):
            answer = model.generate(prompt, max_length=256)
            numeric_answer = ''.join(filter(str.isdigit, answer.split(':')[-1]))
            try:
                answers.append(int(numeric_answer))
            except ValueError:
                answers.append(-1)

        total_answers.append(answers)
        print("--------------------------")
        print(answer)

    final_answers = []
    for answers in total_answers:
        answers = np.array(answers)
        answers[answers < 0] = -1

        pred = Counter(answers.tolist()).most_common(2)

        # エラー処理を追加
        if len(pred) == 0:
            ans = -1
        elif len(pred) == 1:
            ans = pred[0][0] if pred[0][0] >= 0 else -1
        else:
            ans = pred[0][0] if pred[0][0] >= 0 else pred[1][0]

        final_answers.append(ans)
        print(ans)

    df[column_name] = final_answers
    df[f'{column_name}_match'] = df.answer == df[column_name]

    for i in range(n_repetitions):
        df[f'{column_name}_ensemble_{i+1}'] = [answers[i] for answers in total_answers]

    return df

モデルの準備

JAXをバックエンドに設定

Keras NLPを使ってGemma Instruct 2Bモデルをロードします。ここではJAXをバックエンドとして使用します。

import os

os.environ["KERAS_BACKEND"] = "jax"  # JAXをバックエンドとして指定
# JAXバックエンドでのメモリ断片化を回避
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

import keras
import keras_nlp

# Finetuned model
FINETUNED_MODEL_DIR = f"./gemma_demo"

MODEL_BASE = "aimo_gemma"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 9
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}_v{VER}'

モデルのロード

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()

モデルの設定

モデルを準備する際、以下の設定を行います。

  • LoRA (Low-Rank Adaptation) を有効化し、メモリ使用量を削減
  • 入力シーケンス長を512に制限して、メモリ使用量をコントロール
  • 最適化アルゴリズムにAdamWを使用し、Layer NormalizationとBiasの重みを減衰の対象から除外
gemma_lm.backbone.enable_lora(rank=64)

# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

ファインチューニング前の予測

まずはファインチューニング前のモデルでtrainデータの予測を行い、正解率を確認します。

df1 = predict_numeric_answers(df1, gemma_lm, 'pre_finetune_prediction')
print(f"Pre-finetuning accuracy: {df1.pre_finetune_prediction_match.mean():.2%}")

モデルの訓練

データとモデルの準備が完了したら、Gemma Instruct 2Bモデルをtrainデータでファインチューニングします。

# Initialize a new W&B run
wandb.init(config={"bs": 12})

gemma_lm.fit(Math_data, epochs=60, batch_size=1, callbacks=[WandbMetricsLogger(log_freq="batch")])

訓練後、ファインチューニング済みのモデルを保存します。

!mkdir $FINETUNED_MODEL_DIR
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)

ファインチューニング後の予測

ファインチューニング済みのモデルを使って、trainデータの問題に対する解答を生成し、正解率を算出します。

df1 = predict_numeric_answers(df1, gemma_lm, 'post_finetune_prediction')
print(f"Post-finetuning accuracy: {df1.post_finetune_prediction_match.mean():.2%}")
print(df1)

まとめ

このノートブックでは、JAXをバックエンドに使用して、以下のステップでAIMOコンペに取り組みました。

  1. ファインチューニング前のGemma Instruct 2Bモデルでtrainデータの正解率を算出
  2. Gemma Instruct 2Bモデルをtrainデータでファインチューニング
  3. ファインチューニング済みのモデルでtrainデータの正解率を算出

ファインチューニングにより、モデルの予測精度が大幅に向上したことが確認できました。JAXをバックエンドとして活用することで、高速かつ効率的なモデル訓練が可能になります。今後は更なるハイパーパラメータの調整やアンサンブル手法の導入などにより、予測精度の向上が期待できるでしょう。

また、Self-Consistency (SC) を用いたアンサンブル予測を行うことで、モデルの予測の一貫性を高め、算術的なエラーを減らすことができます。これは、複数の多様な解答をサンプリングし、最も一般的な解答を集約することで実現されます。

今回のノートブックでは、MMOS-DeepSeekMath-7B RL-tunedバックボーンを使用しましたが、実験の結果、このモデルはより一貫性のあるコード推論を生成し、コードブロックの実行により算術的なエラーを減らすことができることがわかりました。

今後は、さらなる精度向上のために、モデルのアーキテクチャやハイパーパラメータの調整、新しいアンサンブル手法の導入などを検討していく予定です。

Kaggleへのモデルアップロード

最後に、ファインチューニング済みのモデルをKaggleにアップロードします。

import kagglehub

#kagglehub.model_upload(handle, FINETUNED_MODEL_DIR, license_name='Apache 2.0', version_notes=f'v{VER}')
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')

以上で、Gemma Instruct 2Bモデルのファインチューニングとアンサンブル予測の解説は終了です。このノートブックを参考に、皆さんもAIMOコンペに挑戦してみてください!

ノートブック

Google Colaboratory

コメント

タイトルとURLをコピーしました