概要
Gemmaは、軽量でありながら最先端の性能を持つオープンモデルのファミリーで、Googleの研究と技術を基に構築されています。Gemmaは特定のニーズに合わせてさらにファインチューニングすることができます。しかし、Gemmaのような大規模な言語モデルは、そのサイズゆえに単一のアクセラレータではファインチューニングできない場合があります。その場合、一般的に2つのアプローチがあります。
- パラメータ効率の良いファインチューニング(PEFT): モデルサイズを犠牲にして効果的にモデルサイズを縮小する手法。LoRAがこのカテゴリに入ります。
- モデル並列化を用いたフルパラメータファインチューニング: モデルの重みを複数のデバイスに分散させ、水平スケーリングを可能にします。
このチュートリアルでは、KerasとJAXバックエンドを使って、GoogleのTensor Processing Unit(TPU)上でLoRAとモデル並列化による分散学習を用いてGemma 7Bモデルをファインチューニングする方法を説明します。
こちらの記事もおすすめ
始める前に
Kaggleの認証情報
Gemmaモデルは Kaggle でホストされています。Gemmaを使用するには、Kaggleでアクセス権を要求します。
- kaggle.com にサインインまたは登録
- Gemmaモデルカードを開き、"Request Access"を選択
- 同意書に記入し、利用規約に同意
次に、Kaggle APIを使用するために、APIトークンを作成します。
- Kaggleの設定を開く
- "Create New Token"を選択
kaggle.json
ファイルがダウンロードされます。これにはKaggleの認証情報が含まれています
以下のセルを実行します。
import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
インストール
KerasとKerasNLPをGemmaモデルと一緒にインストールします。
!pip install tensorflow-cpu~=2.16.0 keras-nlp==0.8.2 tensorflow-hub==0.16.1 keras==3.0.5 tensorflow-text==2.16.1
Keras JAXバックエンドのセットアップ
JAXをインポートし、TPUで動作確認を行います。ColabはTPUv2-8デバイスを提供しており、各8GBの高帯域幅メモリを持つ8つのTPUコアがあります。
import jax
jax.devices()
import os
# Keras 3の分散APIは、現在JAXバックエンドでのみ実装されています
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
モデルのロード
import keras
import keras_nlp
JAX環境では、混合精度 (keras.config.set_floatx("bfloat16")
) を使用することで、学習の品質に最小限の影響を与えながら、メモリ使用量を節約できます。
keras.config.set_floatx("bfloat16")
TPU上で重みとテンソルを分散してモデルをロードするには、まず新しいDeviceMesh
を作成します。 DeviceMesh
は、分散計算用に構成されたハードウェアデバイスの集合を表し、Keras 3の統一分散APIの一部として導入されました。
device_mesh = keras.distribution.DeviceMesh(
(1, 8),
["batch", "model"],
devices=keras.distribution.list_devices())
LayoutMap
は、重みとテンソルをどのようにシャーディングまたはレプリケーションするかを指定します。
model_dim = "model"
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, model_dim)
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
None, model_dim, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
ModelParallel
を使用すると、DeviceMesh
上のすべてのデバイスにモデルの重みまたは活性化テンソルをシャーディングできます。この場合、Gemma 7Bモデルの一部の重みは、上記で定義したlayout_map
に従って8つのTPUコアに分散されます。
model_parallel = keras.distribution.ModelParallel(
device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_7b_en")
モデルが正しくパーティション分割されていることを確認しましょう。例としてdecoder_block_1
を見てみます。
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
ファインチューニング前の推論
gemma_lm.generate("Best comedy movies: ", max_length=64)
モデルは90年代の素晴らしいコメディ映画のリストを生成します。 次に、Gemmaモデルをファインチューニングして、出力スタイルを変更します。
IMDbデータセットを用いたファインチューニング
import tensorflow_datasets as tfds
imdb_train = tfds.load(
"imdb_reviews",
split="train",
as_supervised=True,
batch_size=2,
)
# ラベルをドロップ
imdb_train = imdb_train.map(lambda x, y: x)
imdb_train.unbatch().take(1).get_single_element().numpy()
# 高速学習のためにデータセットのサブセットを使用
imdb_train = imdb_train.take(2000)
Low Rank Adaptation (LoRA)を使用してファインチューニングを行います。 LoRAは、モデルの全重みを凍結し、少数の新しい学習可能な重みをモデルに挿入することで、下流タスクの学習可能なパラメータ数を大幅に削減するファインチューニング手法です。
# モデルに対してLoRAを有効にし、LoRAランクを4に設定
gemma_lm.backbone.enable_lora(rank=4)
# 入力シーケンス長を128に制限してメモリ使用量を制御
gemma_lm.preprocessor.sequence_length = 128
# AdamWオプティマイザを使用
optimizer = keras.optimizers.AdamW(
learning_rate=5e-5,
weight_decay=0.01,
)
# レイヤーノームとバイアス項を減衰から除外
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=optimizer,
weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
LoRAを有効にすると、学習可能なパラメータ数が70億から1100万に大幅に減少することに注意してください。
ファインチューニング後の推論
gemma_lm.generate("Best comedy movies: ", max_length=256)
ファインチューニング後、モデルは映画レビューのスタイルを学習し、90年代のコメディ映画の文脈でそのスタイルで出力を生成するようになりました。
Google Colabノートブック
Kaggle ノートブック
次のステップ
このチュートリアルでは、KerasNLPのJAXバックエンドを使用して、強力なTPU上でIMDbデータセットを用いてGemmaモデルを分散学習する方法を学びました。さらに学ぶべきことのいくつかの提案を以下に示します。
以上が、提供された情報を基にした初心者向けの日本語記事です。章立てや箇条書き、実行可能なコードブロックを活用し、わかりやすく説明しました。ご確認ください。
コメント