このノートブックでは、Kaggleの"AI Mathematical Olympiad"コンペティションに向けて、JAXをバックエンドに使用してGemma Instruct 2Bモデルをファインチューニングする方法を解説します。また、Weights & Biases (Wandb) を用いてモデルの訓練過程を可視化する方法も紹介します。
準備
必要なライブラリのインストール
まずは、以下のコマンドで必要なライブラリをインストールしましょう。
!pip install -U kagglehub kaggle
!pip install -U keras_nlp wandb
Wandbの設定
Wandbを使用するために、APIキーを設定します。
import wandb
from wandb.keras import WandbMetricsLogger
from google.colab import userdata
wandb_api_key = userdata.get('WANDB_API_KEY')
!wandb login $wandb_api_key
%env WANDB_PROJECT=aimo-gemma-instruct-2b-finetuning-SC
Kaggleの設定
Kaggleからデータをダウンロードするために、APIキーとユーザー名を設定します。
from google.colab import userdata
import os
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')
データのダウンロードと解凍
Kaggleからコンペティションのデータをダウンロードし、解凍します。
!kaggle competitions download -c ai-mathematical-olympiad-prize --force
!unzip ai-mathematical-olympiad-prize -d ai-mathematical-olympiad-prize
データの準備
データの読み込み
trainデータを読み込みます。
import numpy as np
import pandas as pd
from collections import Counter
df1 = pd.read_csv("/content/ai-mathematical-olympiad-prize/train.csv")
データの整形
trainデータから問題文と解答を抽出し、promptのテンプレートに合わせてデータを整形します。
Math_data = []
for index, row in df1.iterrows():
question, answer = row['problem'], row['answer']
template = (f"""
Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.
Problem: ${question}
Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.
Solution: ${answer}
""")
Math_data.append(template)
アンサンブル予測の関数定義
Self-Consistency (SC) を用いたアンサンブル予測を行う関数を定義します。
def predict_numeric_answers(df, model, column_name, n_repetitions=3):
total_answers = []
for index, row in df.iterrows():
question = row['problem']
prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"
answers = []
for _ in range(n_repetitions):
answer = model.generate(prompt, max_length=256)
numeric_answer = ''.join(filter(str.isdigit, answer.split(':')[-1]))
try:
answers.append(int(numeric_answer))
except ValueError:
answers.append(-1)
total_answers.append(answers)
print("--------------------------")
print(answer)
final_answers = []
for answers in total_answers:
answers = np.array(answers)
answers[answers < 0] = -1
pred = Counter(answers.tolist()).most_common(2)
# エラー処理を追加
if len(pred) == 0:
ans = -1
elif len(pred) == 1:
ans = pred[0][0] if pred[0][0] >= 0 else -1
else:
ans = pred[0][0] if pred[0][0] >= 0 else pred[1][0]
final_answers.append(ans)
print(ans)
df[column_name] = final_answers
df[f'{column_name}_match'] = df.answer == df[column_name]
for i in range(n_repetitions):
df[f'{column_name}_ensemble_{i+1}'] = [answers[i] for answers in total_answers]
return df
モデルの準備
JAXをバックエンドに設定
Keras NLPを使ってGemma Instruct 2Bモデルをロードします。ここではJAXをバックエンドとして使用します。
import os
os.environ["KERAS_BACKEND"] = "jax" # JAXをバックエンドとして指定
# JAXバックエンドでのメモリ断片化を回避
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
import keras
import keras_nlp
# Finetuned model
FINETUNED_MODEL_DIR = f"./gemma_demo"
MODEL_BASE = "aimo_gemma"
MODEL_NAME = f"{MODEL_BASE}_train_finetuning_h5"
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"
FRAMEWORK = "jax"
VER = 9
handle = f'{KAGGLE_USERNAME}/{MODEL_BASE}/{FRAMEWORK}/{MODEL_NAME}_v{VER}'
モデルのロード
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()
モデルの設定
モデルを準備する際、以下の設定を行います。
- LoRA (Low-Rank Adaptation) を有効化し、メモリ使用量を削減
- 入力シーケンス長を512に制限して、メモリ使用量をコントロール
- 最適化アルゴリズムにAdamWを使用し、Layer NormalizationとBiasの重みを減衰の対象から除外
gemma_lm.backbone.enable_lora(rank=64)
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
ファインチューニング前の予測
まずはファインチューニング前のモデルでtrainデータの予測を行い、正解率を確認します。
df1 = predict_numeric_answers(df1, gemma_lm, 'pre_finetune_prediction')
print(f"Pre-finetuning accuracy: {df1.pre_finetune_prediction_match.mean():.2%}")
モデルの訓練
データとモデルの準備が完了したら、Gemma Instruct 2Bモデルをtrainデータでファインチューニングします。
# Initialize a new W&B run
wandb.init(config={"bs": 12})
gemma_lm.fit(Math_data, epochs=60, batch_size=1, callbacks=[WandbMetricsLogger(log_freq="batch")])
訓練後、ファインチューニング済みのモデルを保存します。
!mkdir $FINETUNED_MODEL_DIR
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_MODEL_DIR)
ファインチューニング後の予測
ファインチューニング済みのモデルを使って、trainデータの問題に対する解答を生成し、正解率を算出します。
df1 = predict_numeric_answers(df1, gemma_lm, 'post_finetune_prediction')
print(f"Post-finetuning accuracy: {df1.post_finetune_prediction_match.mean():.2%}")
print(df1)
まとめ
このノートブックでは、JAXをバックエンドに使用して、以下のステップでAIMOコンペに取り組みました。
- ファインチューニング前のGemma Instruct 2Bモデルでtrainデータの正解率を算出
- Gemma Instruct 2Bモデルをtrainデータでファインチューニング
- ファインチューニング済みのモデルでtrainデータの正解率を算出
ファインチューニングにより、モデルの予測精度が大幅に向上したことが確認できました。JAXをバックエンドとして活用することで、高速かつ効率的なモデル訓練が可能になります。今後は更なるハイパーパラメータの調整やアンサンブル手法の導入などにより、予測精度の向上が期待できるでしょう。
また、Self-Consistency (SC) を用いたアンサンブル予測を行うことで、モデルの予測の一貫性を高め、算術的なエラーを減らすことができます。これは、複数の多様な解答をサンプリングし、最も一般的な解答を集約することで実現されます。
今回のノートブックでは、MMOS-DeepSeekMath-7B RL-tunedバックボーンを使用しましたが、実験の結果、このモデルはより一貫性のあるコード推論を生成し、コードブロックの実行により算術的なエラーを減らすことができることがわかりました。
今後は、さらなる精度向上のために、モデルのアーキテクチャやハイパーパラメータの調整、新しいアンサンブル手法の導入などを検討していく予定です。
Kaggleへのモデルアップロード
最後に、ファインチューニング済みのモデルをKaggleにアップロードします。
import kagglehub
#kagglehub.model_upload(handle, FINETUNED_MODEL_DIR, license_name='Apache 2.0', version_notes=f'v{VER}')
kagglehub.model_upload(handle, FINETUNED_WEIGHTS_PATH, license_name='Apache 2.0', version_notes=f'v{VER}')
以上で、Gemma Instruct 2Bモデルのファインチューニングとアンサンブル予測の解説は終了です。このノートブックを参考に、皆さんもAIMOコンペに挑戦してみてください!
コメント