JAXとWandbとSelf-ConsistencyとWeaveを使ったGemma Instruct 2Bモデルのファインチューニング入門

LLM

Gemma Instruct 2Bモデルは、自然言語処理タスクに優れたパフォーマンスを発揮する大規模言語モデルです。このモデルをファインチューニングすることで、特定のタスクにおける性能をさらに向上させることができます。

本記事では、JAX、Wandb、Self-Consistency、Weaveを活用して、Gemma Instruct 2Bモデルをファインチューニングする方法を初心者向けに解説します。章立てやコードブロックを多用し、可読性の高い記事を目指します。

環境設定

まずは必要なライブラリをインストールします。

!pip install -U kagglehub kaggle
!pip install -U keras_nlp wandb weave

次に、WandbとKaggleの設定を行います。APIキーとユーザー名を設定してください。

import wandb
from wandb.keras import WandbMetricsLogger

from google.colab import userdata
wandb_api_key = userdata.get('WANDB_API_KEY')
!wandb login $wandb_api_key
import os

os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY') 
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')

データの準備

Kaggleからデータをダウンロードし、trainデータを読み込みます。

!kaggle competitions download -c ai-mathematical-olympiad-prize --force
!unzip ai-mathematical-olympiad-prize -d ai-mathematical-olympiad-prize

import pandas as pd

df1 = pd.read_csv("/content/ai-mathematical-olympiad-prize/train.csv")

trainデータからプロンプトを生成します。

Math_data = []

for index, row in df1.iterrows():
    question, answer = row['problem'], row['answer']
    template = (f"""
    Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.

Problem: {question}

Instructions:
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step. 
- If relevant, provide additional explanations, examples, or visualizations to aid understanding.
- Finally, state the final answer to the problem clearly and concisely.

Solution: {answer}

    """)
    Math_data.append(template)

モデルの準備

JAXをバックエンドに設定し、Keras NLPを使ってGemma Instruct 2Bモデルをロードします。

import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

import keras
import keras_nlp

MODEL_NAME = "gemma_instruct_2b_olympiad"
FINETUNED_WEIGHTS_PATH = f"{MODEL_NAME}.weights.h5"
FINETUNED_VOCAB_PATH = f"vocabulary.spm"

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")

LoRAを有効化し、最適化アルゴリズムにAdamWを使用します。

gemma_lm.backbone.enable_lora(rank=64)

gemma_lm.preprocessor.sequence_length = 512

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

Weaveでのモデル管理

Weave APIを使って、ファインチューニング前後のGemma Instruct 2Bモデルの推論結果の記録、実験、評価を行います。

まずはモデルを定義します。

import weave

class MathOlympiadModel(weave.Model):
    model: keras_nlp.models.GemmaCausalLM

    @weave.op()
    def predict(self, problem: str) -> dict:
        prompt = (f"""
        Context: You are an intelligent math tutor tasked with solving mathematical problems and explaining the solution steps in a clear and concise manner.

Problem: {problem}

Instructions: 
- Carefully analyze the given problem and identify the key information, known values, and the unknown quantity to be found.
- Break down the problem into smaller steps if necessary, and apply the appropriate mathematical concepts, formulas, and operations to solve the problem.
- Show your step-by-step working, explicitly stating the reasoning behind each step.
- If relevant, provide additional explanations, examples, or visualizations to aid understanding. 
- Finally, state the final answer to the problem clearly and concisely.

Solution:
        """)

        result = self.model.generate(prompt, max_length=1024)
        numeric_answer = ''.join(filter(str.isdigit, result.split(':')[-1]))
        try:
            answer = int(numeric_answer)
        except:
            answer = -1

        return {"answer_raw": result, "answer": answer}

ファインチューニング前後のモデルのインスタンスを作成します。

weave.init('math-olympiad-project')

# ファインチューニング前モデル
pre_finetuned_model = MathOlympiadModel(model=gemma_lm)

# ファインチューニング後モデル(後で設定)
post_finetuned_model = None

ファインチューニング前のSelf-Consistencyによるアンサンブル

ファインチューニング前のモデルに対してSelf-Consistencyを用いてアンサンブル予測を行います。予測結果をWeaveで管理できるように拡張します。

import numpy as np
from collections import Counter

@weave.op()
def predict_numeric_answers(df, model: MathOlympiadModel, column_name, n_repetitions=3):
    total_answers = []

    for index, row in df.iterrows():
        question = row['problem'] 
        prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"

        answers = []
        for _ in range(n_repetitions):
            result = model.predict(problem=question)
            answers.append(result['answer'])

        total_answers.append(answers)

    final_answers = []
    for answers in total_answers:
        answers = np.array(answers)
        answers[answers < 0] = -1

        pred = Counter(answers.tolist()).most_common(2)

        if len(pred) == 0:
            ans = -1
        elif len(pred) == 1:
            ans = pred[0][0] if pred[0][0] >= 0 else -1
        else:
            ans = pred[0][0] if pred[0][0] >= 0 else pred[1][0]

        final_answers.append(ans)

    df[column_name] = final_answers
    df[f'{column_name}_match'] = df.answer == df[column_name]

    for i in range(n_repetitions):
        df[f'{column_name}_ensemble_{i+1}'] = [answers[i] for answers in total_answers]

    return df

df1 = predict_numeric_answers(df1, pre_finetuned_model, 'pre_finetune_prediction')
print(f"Pre-finetuning accuracy: {df1.pre_finetune_prediction_match.mean():.2%}")

ファインチューニングの実行

Wandbの設定を行い、モデルをファインチューニングします。

wandb.init(config={"bs": 12})

gemma_lm.fit(Math_data, epochs=60, batch_size=1, callbacks=[WandbMetricsLogger(log_freq="batch")])

ファインチューニングしたモデルを保存します。

gemma_lm.save_weights(FINETUNED_WEIGHTS_PATH)
gemma_lm.preprocessor.tokenizer.save_assets(FINETUNED_VOCAB_PATH)

ファインチューニング後のSelf-Consistencyによるアンサンブル

ファインチューニング後のモデルに対してSelf-Consistencyを用いてアンサンブル予測を行います。

# ファインチューニング後モデルの設定
gemma_lm.load_weights(FINETUNED_WEIGHTS_PATH)
post_finetuned_model = MathOlympiadModel(model=gemma_lm)

df1 = predict_numeric_answers(df1, post_finetuned_model, 'post_finetune_prediction')
print(f"Post-finetuning accuracy: {df1.post_finetune_prediction_match.mean():.2%}")

Weaveでのモデル評価

ファインチューニング前後のモデルの性能をWeave APIで比較・評価します。

まず評価関数を定義します。

from weave.flow.scorer import MultiTaskBinaryClassificationF1

@weave.op()
def answer_match_score(target: dict, model_output: dict) -> dict:
    return {'correct': target['answer'] == model_output['answer']}

evaluation = weave.Evaluation(
    dataset=df1[['problem', 'answer']].to_dict(orient='records'), 
    scorers=[MultiTaskBinaryClassificationF1(class_names=['answer']), answer_match_score],
)

ファインチューニング前後のモデルを評価します。

print("Pre-finetuned model evaluation:")
print(evaluation.evaluate(pre_finetuned_model))

print("Post-finetuned model evaluation:")  
print(evaluation.evaluate(post_finetuned_model))

file

まとめ

本記事では、JAX、Wandb、Self-Consistency、Weaveを活用して、Gemma Instruct 2Bモデルをファインチューニングする方法を解説しました。

  1. JAXをバックエンドに設定し、Keras NLPでモデルをロード
  2. Weave APIでファインチューニング前後のモデルを管理
  3. ファインチューニング前にSelf-Consistencyでアンサンブル予測を実施し、Weaveで管理
  4. Wandbを使ってファインチューニングを実行
  5. ファインチューニング後にSelf-Consistencyでアンサンブル予測を実施し、Weaveで管理
  6. Weave APIでファインチューニング前後のモデルの性能を比較・評価

これらのツールを組み合わせることで、効率的にモデルの開発と改善を進めることができます。
特に、Weave APIを使うことで、ファインチューニングの各段階でモデルの性能を容易に比較・評価できるようになります。
ぜひ、皆さんも挑戦してみてください。

ノートブック

Google Colaboratory

参考サイト

LLMアプリケーションの記録・実験・評価のプラットフォーム Weave を試す|npaka
LLMアプリケーションの記録・実験・評価のプラットフォーム「Weave」がリリースされたので、試してみました。 この入門記事は、「Weights & Biases」のご支援により提供されています。Weights & Biases JapanのNoteでは他にも多くの有用な記事が掲載されていますので是非ご覧ください。...
Tutorial: Build an Evaluation pipeline | W&B Weave
To iterate on an application, we need a way to evaluate if it's improving. To do so, a common practice is to test it against the same set of examples when there...
Introduction | W&B Weave
Weave is a lightweight toolkit for tracking and evaluating LLM applications, built by Weights & Biases.

コメント

タイトルとURLをコピーしました