KerasとJAXを使ってGemmaモデルをTPU分散学習する方法

AI・機械学習

概要

Gemmaは、軽量でありながら最先端の性能を持つオープンモデルのファミリーで、Googleの研究と技術を基に構築されています。Gemmaは特定のニーズに合わせてさらにファインチューニングすることができます。しかし、Gemmaのような大規模な言語モデルは、そのサイズゆえに単一のアクセラレータではファインチューニングできない場合があります。その場合、一般的に2つのアプローチがあります。

  1. パラメータ効率の良いファインチューニング(PEFT): モデルサイズを犠牲にして効果的にモデルサイズを縮小する手法。LoRAがこのカテゴリに入ります。
  2. モデル並列化を用いたフルパラメータファインチューニング: モデルの重みを複数のデバイスに分散させ、水平スケーリングを可能にします。

このチュートリアルでは、KerasとJAXバックエンドを使って、GoogleのTensor Processing Unit(TPU)上でLoRAとモデル並列化による分散学習を用いてGemma 7Bモデルをファインチューニングする方法を説明します。


こちらの記事もおすすめ

Keras 3.0とJAXを使ったgemmaのファインチューニング
はじめにKeras 3.0がリリースされ、JAX、TensorFlow、PyTorchのいずれかをバックエンドとして選択できるようになりました。これにより、目的に応じて最適なフレームワークを使い分けることが可能になります。また、Kerasを...
kagglehub を使った大規模言語モデル gemma のファインチューニングとモデル共有
はじめにこんにちは。この記事では、Kaggle の新機能である Kaggle Models を使って、大規模言語モデル gemma をファインチューニングし、コミュニティで共有する方法を初心者向けに解説します。Kaggle Models で...

始める前に

Kaggleの認証情報

Gemmaモデルは Kaggle でホストされています。Gemmaを使用するには、Kaggleでアクセス権を要求します。

次に、Kaggle APIを使用するために、APIトークンを作成します。

  • Kaggleの設定を開く
  • "Create New Token"を選択
  • kaggle.jsonファイルがダウンロードされます。これにはKaggleの認証情報が含まれています

以下のセルを実行します。

import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME') 
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

インストール

KerasとKerasNLPをGemmaモデルと一緒にインストールします。

!pip install tensorflow-cpu~=2.16.0 keras-nlp==0.8.2 tensorflow-hub==0.16.1 keras==3.0.5 tensorflow-text==2.16.1

Keras JAXバックエンドのセットアップ

JAXをインポートし、TPUで動作確認を行います。ColabはTPUv2-8デバイスを提供しており、各8GBの高帯域幅メモリを持つ8つのTPUコアがあります。

import jax

jax.devices()

import os

# Keras 3の分散APIは、現在JAXバックエンドでのみ実装されています 
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  

モデルのロード

import keras
import keras_nlp

JAX環境では、混合精度 (keras.config.set_floatx("bfloat16")) を使用することで、学習の品質に最小限の影響を与えながら、メモリ使用量を節約できます。

keras.config.set_floatx("bfloat16")

TPU上で重みとテンソルを分散してモデルをロードするには、まず新しいDeviceMeshを作成します。 DeviceMeshは、分散計算用に構成されたハードウェアデバイスの集合を表し、Keras 3の統一分散APIの一部として導入されました。

device_mesh = keras.distribution.DeviceMesh(
    (1, 8),  
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMapは、重みとテンソルをどのようにシャーディングまたはレプリケーションするかを指定します。

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

layout_map["token_embedding/embeddings"] = (None, model_dim)

layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    None, model_dim, None)
layout_map["decoder_block.*attention_output.*kernel"] = (
    None, None, model_dim) 
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)

ModelParallelを使用すると、DeviceMesh上のすべてのデバイスにモデルの重みまたは活性化テンソルをシャーディングできます。この場合、Gemma 7Bモデルの一部の重みは、上記で定義したlayout_mapに従って8つのTPUコアに分散されます。

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_7b_en")  

モデルが正しくパーティション分割されていることを確認しましょう。例としてdecoder_block_1を見てみます。

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

ファインチューニング前の推論

gemma_lm.generate("Best comedy movies: ", max_length=64)  

モデルは90年代の素晴らしいコメディ映画のリストを生成します。 次に、Gemmaモデルをファインチューニングして、出力スタイルを変更します。

IMDbデータセットを用いたファインチューニング

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",  
    split="train",
    as_supervised=True,
    batch_size=2,
)
# ラベルをドロップ
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()

# 高速学習のためにデータセットのサブセットを使用
imdb_train = imdb_train.take(2000)  

Low Rank Adaptation (LoRA)を使用してファインチューニングを行います。 LoRAは、モデルの全重みを凍結し、少数の新しい学習可能な重みをモデルに挿入することで、下流タスクの学習可能なパラメータ数を大幅に削減するファインチューニング手法です。

# モデルに対してLoRAを有効にし、LoRAランクを4に設定
gemma_lm.backbone.enable_lora(rank=4)

# 入力シーケンス長を128に制限してメモリ使用量を制御
gemma_lm.preprocessor.sequence_length = 128

# AdamWオプティマイザを使用
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,  
    weight_decay=0.01,
)
# レイヤーノームとバイアス項を減衰から除外
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,  
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)

LoRAを有効にすると、学習可能なパラメータ数が70億から1100万に大幅に減少することに注意してください。

ファインチューニング後の推論

gemma_lm.generate("Best comedy movies: ", max_length=256)

ファインチューニング後、モデルは映画レビューのスタイルを学習し、90年代のコメディ映画の文脈でそのスタイルで出力を生成するようになりました。

Google Colabノートブック

Google Colaboratory

Kaggle ノートブック

KerasとJAXを使ってGemmaモデルをTPU分散学習する方法
Explore and run machine learning code with Kaggle Notebooks | Using data from Gemma

次のステップ

このチュートリアルでは、KerasNLPのJAXバックエンドを使用して、強力なTPU上でIMDbデータセットを用いてGemmaモデルを分散学習する方法を学びました。さらに学ぶべきことのいくつかの提案を以下に示します。

以上が、提供された情報を基にした初心者向けの日本語記事です。章立てや箇条書き、実行可能なコードブロックを活用し、わかりやすく説明しました。ご確認ください。

コメント

タイトルとURLをコピーしました