はじめに
Keras 3.0がリリースされ、JAX、TensorFlow、PyTorchのいずれかをバックエンドとして選択できるようになりました。これにより、目的に応じて最適なフレームワークを使い分けることが可能になります。また、Kerasを低レベルのクロスフレームワーク言語として活用し、レイヤー、モデル、メトリクスなどのカスタムコンポーネントを開発することもできます。
こちらの記事もおすすめ
kagglehub を使った大規模言語モデル gemma のファインチューニングとモデル共有
はじめにこんにちは。この記事では、Kaggle の新機能である Kaggle Models を使って、大規模言語モデル gemma をファインチューニングし、コミュニティで共有する方法を初心者向けに解説します。Kaggle Models で...
LLama 3のSFTTrainer+Weights & Biasesでファインチューニング
はじめにLLama 3は、Meta社が開発した大規模言語モデルです。高性能でありながら、一般的なGPUでも扱えるサイズのモデルが提供されています。このモデルをファインチューニングすることで、様々なタスクに適用できます。本記事では、Huggi...
JAXとは
JAXは、NumPyの構文を使って機械学習モデルを記述でき、自動微分、JIT コンパイル、モデル並列化などの機能を提供するフレームワークです。GPUやTPUを利用した高速な計算が可能で、特に大規模モデルの学習に適しています。
Keras 3.0のメリット
Keras 3.0を使うことで、以下のようなメリットが得られます。
- 最高のパフォーマンスを常に得られる: バックエンドを動的に選択することで、モデルに応じて最適なパフォーマンスを発揮できます。
- エコシステムの選択肢が広がる: Keras 3モデルは、PyTorch、TensorFlow、JAXのエコシステムパッケージと組み合わせて使用できます。
- JAXによる大規模なモデル並列化とデータ並列化:
keras.distribution
を使って、モデル並列化とデータ並列化を簡単に実現できます。 - オープンソースモデルのリリースの影響力を最大化: Keras 3で実装されたモデルは、フレームワークに関係なく誰でも使用できます。
- 任意のデータパイプラインを使用可能:
tf.data.Dataset
、PyTorch DataLoader
、NumPy配列、Pandasデータフレームなどに対応しています。
JAXをバックエンドに使用する
Keras 3.0では、以下のようにJAXをバックエンドとして指定できます。
import os
os.environ["KERAS_BACKEND"] = "jax" # JAXをバックエンドとして指定
# JAXバックエンドでのメモリ断片化を回避
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
import keras
import keras_nlp
モデルの準備
Keras 3.0では、以下のようにモデルを準備します。
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")
gemma_lm.summary()
gemma_lm.backbone.enable_lora(rank=64)
# 入力シーケンス長を512に制限(メモリ使用量のコントロール)
gemma_lm.preprocessor.sequence_length = 512
# AdamWオプティマイザを使用
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# レイヤーノームとバイアス項を減衰の対象から除外
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
- LoRA (Low-Rank Adaptation)を有効化し、メモリ使用量を削減
- 入力シーケンス長を512に制限して、メモリ使用量をコントロール
- 最適化アルゴリズムにAdamWを使用し、Layer NormalizationとBiasの重みを減衰の対象から除外
ファインチューニング
データとモデルの準備が完了したら、以下のようにモデルをファインチューニングします。
gemma_lm.fit(Math_data, epochs=60, batch_size=1)
予測
ファインチューニング前後のモデルを使って、以下のように予測を行います。
pre_finetune_predictions = []
post_finetune_predictions = []
for index, row in df1.iterrows():
question = row['problem']
prompt = f"Context: You are an intelligent system designed to solve mathematical problems and provide only the numeric answer without any additional explanations or steps.\n\nProblem: {question}\n\nInstructions:\n- Analyze the given mathematical problem carefully.\n- Identify the unknown quantity or value to be determined.\n- Apply the appropriate mathematical concepts, formulas, and operations to solve the problem.\n- Output only the final numeric answer to the problem, without any additional text or explanations.\n\nAnswer:"
# ファインチューニング前の予測
pre_answer = gemma_lm.generate(prompt, max_length=256)
pre_numeric_answer = ''.join(filter(str.isdigit, pre_answer.split(':')[-1]))
pre_finetune_predictions.append(int(pre_numeric_answer))
# ファインチューニング後の予測
post_answer = gemma_lm.generate(prompt, max_length=256)
post_numeric_answer = ''.join(filter(str.isdigit, post_answer.split(':')[-1]))
post_finetune_predictions.append(int(post_numeric_answer))
df1['pre_finetune_prediction'] = pre_finetune_predictions
df1['pre_finetune_match'] = df1.answer == df1.pre_finetune_prediction
df1['post_finetune_prediction'] = post_finetune_predictions
df1['post_finetune_match'] = df1.answer == df1.post_finetune_prediction
まとめ
Keras 3.0とJAXを組み合わせることで、高速かつ効率的な機械学習を実現できます。
ファインチューニングにより、モデルの予測精度が大幅に向上することが確認できました。
今後は更なるハイパーパラメータの調整やアンサンブル手法の導入などにより、予測精度の向上が期待できるでしょう。
ノートブック
《JP》AIMO Gemma Instruct 2B Finetuning with JAX¶
Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources
参考サイト
Keras: Deep Learning for humans
Keras Core documentation
Compatibility Issue with Loading Model in Keras 3 · Issue #28 · unicode-org/lstm_word_segmentation
when attempting to load pre-trained model in Keras 3, encountering issue: ValueError: File format not supported: filepath=/home/srvk/Desktop/lstm_w/lstm_word_se...
load saved model in keras with tensorflow backend! · Issue #18709 · keras-team/keras
keras.models.load_model('my_model') keras.model.load_weights('my_model') With tf.keras it works but in new keras, ValueError: File format not supported: filepat...
Google Colaboratory
Save and Load a Keras NLP Tuned model · Issue #271 · google/generative-ai-docs
Hi, I was following this tutorial Link : Could please tell me as to how to save a f...
What's the best practice to save and load trained/fine-tuned models e.g. `BertClassifier`? · keras-team/keras-nlp · Discussion #1039
Hi @jbischof! I was wondering what's the best way to save and load any ...Classifier model e.g. BertClassifier that has been trained/fine-tuned via .fit on cust...
Keras documentation: KerasNLP
Keras documentation
コメント