Gemma Instruct 2Bモデルは、自然言語処理タスクに優れたパフォーマンスを発揮する大規模言語モデルです。このモデルをファインチューニングすることで、特定のタスクにおける性能をさらに向上させることができます。
本記事では、JAX、Wandb、Self-Consistency、Weaveを活用して、Gemma Instruct 2Bモデルをファインチューニングする方法を初心者向けに解説します。章立てやコードブロックを多用し、可読性の高い記事を目指します。
環境設定
まずは必要なライブラリをインストールします。
!pip install -U kagglehub kaggle
!pip install -U keras_nlp wandb weave
次に、WandbとKaggleの設定を行います。APIキーとユーザー名を設定してください。
import wandb
from wandb.keras import WandbMetricsLogger
from google.colab import userdata
wandb_api_key = userdata.get('WANDB_API_KEY')
!wandb login $wandb_api_key
import os
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')
データの準備
Kaggleからデータをダウンロードし、trainデータを読み込みます。
!kaggle competitions download -c ai-mathematical-olympiad-prize --force
!unzip ai-mathematical-olympiad-prize -d ai-mathematical-olympiad-prize
import pandas as pd
df1 = pd.read_csv("/content/ai-mathematical-olympiad-prize/train.csv")
trainデータからプロンプトを生成します。
Math_data = []
for index, row in df1.iterrows():
question, answer = row['problem'], row['answer']
template = (f"""
Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.
Problem: {question}
Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.
Solution: {answer}
""")
Math_data.append(template)
モデルの準備
JAXをバックエンドに設定し、Keras NLPを使ってGemma Instruct 2Bモデルをロードします。
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
import keras
import keras_nlp
MODEL_NAME = "gemma_instruct_2b_olympiad"
FINETUNED_WEIGHTS_PATH = f"{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"vocabulary.spm"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
LoRAを有効化し、最適化アルゴリズムにAdamWを使用します。
gemma_lm.backbone.enable_lora(rank=64)
gemma_lm.preprocessor.sequence_length = 512
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
Weaveでのモデル管理
Weave APIを使って、ファインチューニング前後のGemma Instruct 2Bモデルの推論結果の記録、実験、評価を行います。
まずはモデルを定義します。
import weave
class MathOlympiadModel(weave.Model):
model: keras_nlp.models.GemmaCausalLM
@weave.op()
def predict(self, problem: str) -> dict:
prompt = (f"""
Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.
Problem: {problem}
Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.
Solution:
""")
result = self.model.generate(prompt, max_length=1024)
numeric_answer = ''.join(filter(str.isdigit, result.split(':')[-1]))
try:
answer = int(numeric_answer)
except:
answer = -1
return {"answer_raw": result, "answer": answer}
ファインチューニング前後のモデルのインスタンスを作成します。
weave.init('math-olympiad-project')
# ファインチューニング前モデル
pre_finetuned_model = MathOlympiadModel(model=gemma_lm)
# ファインチューニング後モデル(後で設定)
post_finetuned_model = None
ファインチューニング前のSelf-Consistencyによるアンサンブル
ファインチューニング前のモデルに対してSelf-Consistencyを用いてアンサンブル予測を行います。予測結果をWeaveで管理できるように拡張します。
import numpy as np
from collections import Counter
@weave.op()
def predict_numeric_answers(df, model: MathOlympiadModel, column_name, n_repetitions=3):
total_answers = []
for index, row in df.iterrows():
question = row['problem']
prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"
answers = []
for _ in range(n_repetitions):
result = model.predict(problem=question)
answers.append(result['answer'])
total_answers.append(answers)
final_answers = []
for answers in total_answers:
answers = np.array(answers)
answers[answers < 0] = -1
pred = Counter(answers.tolist()).most_common(2)
if len(pred) == 0:
ans = -1
elif len(pred) == 1:
ans = pred[0][0] if pred[0][0] >= 0 else -1
else:
ans = pred[0][0] if pred[0][0] >= 0 else pred[1][0]
final_answers.append(ans)
df[column_name] = final_answers
df[f'{column_name}_match'] = df.answer == df[column_name]
for i in range(n_repetitions):
df[f'{column_name}_ensemble_{i+1}'] = [answers[i] for answers in total_answers]
return df
df1 = predict_numeric_answers(df1, pre_finetuned_model, 'pre_finetune_prediction')
print(f"Pre-finetuning accuracy: {df1.pre_finetune_prediction_match.mean():.2%}")
ファインチューニングの実行
Wandbの設定を行い、モデルをファインチューニングします。
wandb.init(config={"bs": 12})
gemma_lm.fit(Math_data, epochs=60, batch_size=1, callbacks=[WandbMetricsLogger(log_freq="batch")])
ファインチューニングしたモデルを保存します。
gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_VOCAB_PATH)
ファインチューニング後のSelf-Consistencyによるアンサンブル
ファインチューニング後のモデルに対してSelf-Consistencyを用いてアンサンブル予測を行います。
# ファインチューニング後モデルの設定
gemma_lm.load_weights(FINETUNED_WEIGHTS_PATH)
post_finetuned_model = MathOlympiadModel(model=gemma_lm)
df1 = predict_numeric_answers(df1, post_finetuned_model, 'post_finetune_prediction')
print(f"Post-finetuning accuracy: {df1.post_finetune_prediction_match.mean():.2%}")
Weaveでのモデル評価
ファインチューニング前後のモデルの性能をWeave APIで比較・評価します。
まず評価関数を定義します。
from weave.flow.scorer import MultiTaskBinaryClassificationF1
@weave.op()
def answer_match_score(target: dict, model_output: dict) -> dict:
return {'correct': target['answer'] == model_output['answer']}
evaluation = weave.Evaluation(
dataset=df1[['problem', 'answer']].to_dict(orient='records'),
scorers=[MultiTaskBinaryClassificationF1(class_names=['answer']), answer_match_score],
)
ファインチューニング前後のモデルを評価します。
print("Pre-finetuned model evaluation:")
print(evaluation.evaluate(pre_finetuned_model))
print("Post-finetuned model evaluation:")
print(evaluation.evaluate(post_finetuned_model))
まとめ
本記事では、JAX、Wandb、Self-Consistency、Weaveを活用して、Gemma Instruct 2Bモデルをファインチューニングする方法を解説しました。
- JAXをバックエンドに設定し、Keras NLPでモデルをロード
- Weave APIでファインチューニング前後のモデルを管理
- ファインチューニング前にSelf-Consistencyでアンサンブル予測を実施し、Weaveで管理
- Wandbを使ってファインチューニングを実行
- ファインチューニング後にSelf-Consistencyでアンサンブル予測を実施し、Weaveで管理
- Weave APIでファインチューニング前後のモデルの性能を比較・評価
これらのツールを組み合わせることで、効率的にモデルの開発と改善を進めることができます。
特に、Weave APIを使うことで、ファインチューニングの各段階でモデルの性能を容易に比較・評価できるようになります。
ぜひ、皆さんも挑戦してみてください。
コメント