Keras 3.0とJAXを使ったgemmaのファインチューニング

AI・機械学習

はじめに

Keras 3.0がリリースされ、JAX、TensorFlow、PyTorchのいずれかをバックエンドとして選択できるようになりました。これにより、目的に応じて最適なフレームワークを使い分けることが可能になります。また、Kerasを低レベルのクロスフレームワーク言語として活用し、レイヤー、モデル、メトリクスなどのカスタムコンポーネントを開発することもできます。


こちらの記事もおすすめ

kagglehub を使った大規模言語モデル gemma のファインチューニングとモデル共有
はじめにこんにちは。この記事では、Kaggle の新機能である Kaggle Models を使って、大規模言語モデル gemma をファインチューニングし、コミュニティで共有する方法を初心者向けに解説します。Kaggle Models で...
LLama 3のSFTTrainer+Weights & Biasesでファインチューニング
はじめにLLama 3は、Meta社が開発した大規模言語モデルです。高性能でありながら、一般的なGPUでも扱えるサイズのモデルが提供されています。このモデルをファインチューニングすることで、様々なタスクに適用できます。本記事では、Huggi...

JAXとは

JAXは、NumPyの構文を使って機械学習モデルを記述でき、自動微分、JIT コンパイル、モデル並列化などの機能を提供するフレームワークです。GPUやTPUを利用した高速な計算が可能で、特に大規模モデルの学習に適しています。

Keras 3.0のメリット

Keras 3.0を使うことで、以下のようなメリットが得られます。

  • 最高のパフォーマンスを常に得られる: バックエンドを動的に選択することで、モデルに応じて最適なパフォーマンスを発揮できます。
  • エコシステムの選択肢が広がる: Keras 3モデルは、PyTorch、TensorFlow、JAXのエコシステムパッケージと組み合わせて使用できます。
  • JAXによる大規模なモデル並列化とデータ並列化: keras.distributionを使って、モデル並列化とデータ並列化を簡単に実現できます。
  • オープンソースモデルのリリースの影響力を最大化: Keras 3で実装されたモデルは、フレームワークに関係なく誰でも使用できます。
  • 任意のデータパイプラインを使用可能: tf.data.DatasetPyTorch DataLoader、NumPy配列、Pandasデータフレームなどに対応しています。

JAXをバックエンドに使用する

Keras 3.0では、以下のようにJAXをバックエンドとして指定できます。

import os

os.environ["KERAS_BACKEND"] = "jax"  # JAXをバックエンドとして指定
# JAXバックエンドでのメモリ断片化を回避
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

import keras
import keras_nlp

モデルの準備

Keras 3.0では、以下のようにモデルを準備します。

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()

gemma_lm.backbone.enable_lora(rank=64)

# 入力シーケンス長を512に制限(メモリ使用量のコントロール)
gemma_lm.preprocessor.sequence_length = 512
# AdamWオプティマイザを使用
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,  
)
# レイヤーノームとバイアス項を減衰の対象から除外
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],  
)
  • LoRA (Low-Rank Adaptation)を有効化し、メモリ使用量を削減
  • 入力シーケンス長を512に制限して、メモリ使用量をコントロール
  • 最適化アルゴリズムにAdamWを使用し、Layer NormalizationとBiasの重みを減衰の対象から除外

ファインチューニング

データとモデルの準備が完了したら、以下のようにモデルをファインチューニングします。

gemma_lm.fit(Math_data, epochs=60, batch_size=1)  

予測

ファインチューニング前後のモデルを使って、以下のように予測を行います。

pre_finetune_predictions = []
post_finetune_predictions = []

for index, row in df1.iterrows():
    question = row['problem']
    prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"  

    # ファインチューニング前の予測
    pre_answer = gemma_lm.generate(prompt, max_length=256)
    pre_numeric_answer = ''.join(filter(str.isdigit, pre_answer.split(':')[-1]))
    pre_finetune_predictions.append(int(pre_numeric_answer))

    # ファインチューニング後の予測  
    post_answer = gemma_lm.generate(prompt, max_length=256)
    post_numeric_answer = ''.join(filter(str.isdigit, post_answer.split(':')[-1]))
    post_finetune_predictions.append(int(post_numeric_answer))

df1['pre_finetune_prediction'] = pre_finetune_predictions
df1['pre_finetune_match'] = df1.answer == df1.pre_finetune_prediction
df1['post_finetune_prediction'] = post_finetune_predictions  
df1['post_finetune_match'] = df1.answer == df1.post_finetune_prediction

まとめ

Keras 3.0とJAXを組み合わせることで、高速かつ効率的な機械学習を実現できます。
ファインチューニングにより、モデルの予測精度が大幅に向上することが確認できました。
今後は更なるハイパーパラメータの調整やアンサンブル手法の導入などにより、予測精度の向上が期待できるでしょう。

ノートブック

《JP》AIMO Gemma Instruct 2B Finetuning with JAX¶
Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

参考サイト

Keras: Deep Learning for humans
Keras Core documentation
Compatibility Issue with Loading Model in Keras 3 · Issue #28 · unicode-org/lstm_word_segmentation
when attempting to load pre-trained model in Keras 3, encountering issue: ValueError: File format not supported: filepath=/home/srvk/Desktop/lstm_w/lstm_word_se...
load saved model in keras with tensorflow backend! · Issue #18709 · keras-team/keras
keras.models.load_model('my_model') keras.model.load_weights('my_model') With tf.keras it works but in new keras, ValueError: File format not supported: filepat...
Google Colaboratory
Save and Load a Keras NLP Tuned model · Issue #271 · google/generative-ai-docs
Hi, I was following this tutorial Link : Could please tell me as to how to save a f...
What's the best practice to save and load trained/fine-tuned models e.g. `BertClassifier`? · keras-team/keras-nlp · Discussion #1039
Hi @jbischof! I was wondering what's the best way to save and load any ...Classifier model e.g. BertClassifier that has been trained/fine-tuned via .fit on cust...
Keras documentation: KerasNLP
Keras documentation

コメント

タイトルとURLをコピーしました