Flaxを使用したRecurrentGemma2Bグリフィンモデルのファインチューニングチュートリアル(Kaggle、GoogleColabノート付)

AI

こんにちは!今回は、Flaxを使用して2Bグリフィンモデルをシンプルな翻訳タスクにファインチューニングする方法を学びます。グリフィンモデルは強力な言語モデルで、ファインチューニングによって特定のタスクに適応させることができます。

初心者の方にも分かりやすいよう、コードの説明を丁寧に行いながら、ステップバイステップでチュートリアルを進めていきます。それでは、早速始めていきましょう!

RecurrentGemma-9b: 革新的な自然言語処理モデルの登場
はじめに近年、自然言語処理(NLP)の分野では、大規模な言語モデルが目覚ましい進歩を遂げています。そんな中、Googleが開発したRecurrentGemmaモデルが注目を集めています。RecurrentGemmaは、従来のGemmaモデル...

セットアップ

まずは必要なライブラリをインストールし、環境を整えていきます。

!pip list --format=freeze > requirements.kaggle.txt
!pip list

pip listコマンドでインストール済みのライブラリを一覧表示し、requirements.kaggle.txtファイルに出力しています。これは現在の環境を再現するために必要な情報です。

次に、RecurrentGemmaライブラリをインストールします。このライブラリはグリフィンモデルを扱うために必要です。

!pip install git+https://github.com/google-deepmind/recurrentgemma.git

これでRecurrentGemmaがインストールできました。続いて、必要なライブラリをインポートしていきます。

import pathlib
from typing import Any
import enum
import functools

# JAXと関連パッケージをインポート
import chex
import jax
import jax.numpy as jnp
import optax

# データセットを扱うためにTensorFlowをインポート
import tensorflow as tf
import tensorflow_datasets as tfds

# RecurrentGemmaをインポート
import sentencepiece as spm
from recurrentgemma import jax as recurrentgemma

pathlib, typing, enumなどの標準ライブラリに加え、JAX, TensorFlow, RecurrentGemmaに関連するライブラリをインポートしました。

  • JAX: 機械学習のための高性能な数値計算ライブラリ
  • TensorFlow: 機械学習用のオープンソースプラットフォーム
  • RecurrentGemma: グリフィンモデルを扱うためのライブラリ

これらのライブラリを使って、グリフィンモデルのファインチューニングを行っていきます。

ファインチューニングするチェックポイントの選択

次に、ファインチューニングに使用するチェックポイントを選択します。2Bモデルはファインチューニングにメモリ内に収まるサイズです。

VARIANT = '2b-it' # @param ['2b', '2b-it', '9b', '9b-it'] {type:"string"}
# weights_dir = kagglehub.model_download(f'google/recurrentgemma/Flax/{VARIANT}')

weights_dir = pathlib.Path(f"/kaggle/input/recurrentgemma/flax/{VARIANT}/1")
ckpt_path = weights_dir / VARIANT
vocab_path = weights_dir / 'tokenizer.model'
preset = recurrentgemma.Preset.RECURRENT_GEMMA_2B_V1 if '2b' in VARIANT else recurrentgemma.Preset.RECURRENT_GEMMA_9B_V1
ckpt_path
!ls /kaggle/input/recurrentgemma/flax/2b-it/1/2b-it

ここでは、2b-itバリアントを選択しています。weights_dirにチェックポイントのパスを、ckpt_pathにモデルのパスを、vocab_pathに語彙ファイルのパスを指定しています。

presetには、モデルのサイズに応じたプリセット設定を選択しています。2Bモデルの場合はRECURRENT_GEMMA_2B_V1、9Bモデルの場合はRECURRENT_GEMMA_9B_V1が使用されます。

ckpt_pathとlsコマンドの出力で、選択したチェックポイントのパスを確認しています。

# パラメータのロード
params =  recurrentgemma.load_parameters(ckpt_path, "single_device")
model_config = recurrentgemma.GriffinConfig.from_flax_params_or_variables(params, preset=preset)
model = recurrentgemma.Griffin(model_config)

ここでは、選択したチェックポイントからパラメータをロードし、モデルの設定を行っています。

  • recurrentgemma.load_parameters関数でチェックポイントからパラメータをロード
  • recurrentgemma.GriffinConfig.from_flax_params_or_variables関数でパラメータとプリセットからモデルの設定を作成
  • recurrentgemma.Griffinクラスでモデルのインスタンスを生成

これで、グリフィンモデルの準備が整いました。次は、ファインチューニングに使用するデータセットの準備に移ります。

Step 1: データセットの準備

MTNTデータセット

このチュートリアルでは、MTNTデータセットを使用します。このデータセットは、論文「MTNT: A Testbed for Machine Translation of Noisy Text」で提案されたもので、ノイズの多いテキストの機械翻訳のためのテストベッドとなっています。

MTNTデータセットは、TensorFlow Datasetsで直接利用可能です。ここでは、特に英語からフランス語への翻訳に焦点を当てます。

まずは、データセット自体を見てみましょう。

ds = tfds.load("mtnt/en-fr", split="train")
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val.decode("utf-8")}')
  print()

データセットの各サンプルには、以下の2つのエントリが含まれています。

  • 'src': 原文の英語の文章
  • 'dst': 対応するフランス語の翻訳

データセットからランダムに2つのサンプルを取得し、原文と翻訳を表示しています。

トークナイザー

次に、語彙ベースのトークナイザーをロードしていきます。ここでは、SentencePieceライブラリを使用してトークナイザーを構築します。

vocab = spm.SentencePieceProcessor()
vocab.Load(str(vocab_path))

vocab_pathで指定したパスから語彙ファイルをロードし、SentencePieceProcessorオブジェクトを作成しています。

英語からフランス語への翻訳タスクに合わせて、SentencePieceProcessorをカスタマイズしていきます。英語のみのグリフィン2Bモデルをファインチューニングするため、いくつかの調整が必要です。

  • 入力プレフィックス: 各入力に共通のプレフィックスを追加し、翻訳タスクであることを示します。例えば、'Translate this into French: [INPUT_SENTENCE]' のようなプロンプトを使用できます。

  • 翻訳開始サフィックス: プロンプトの末尾に翻訳開始位置を示すサフィックスを追加します。改行文字で十分でしょう。

  • LMトークン: グリフィンモデルは、各シーケンスの先頭にシーケンス開始トークンを、末尾にシーケンス終了トークンを期待します。

これらを踏まえて、カスタムトークナイザーを実装していきます。

class GriffinTokenizer:
  """SentencePieceProcessorのTensorFlow用カスタムラッパー"""

  def __init__(self, spm_processor: spm.SentencePieceProcessor):
    self._spm_processor = spm_processor

  @property
  def pad_id(self) -> int:
    """パッドトークンのIDを返す"""
    return self._spm_processor.pad_id()

  def tokenize(
      self,
      example: str | bytes,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> jax.Array:
    """トークン化関数

    Args:
      example: トークン化する入力文字列
      prefix:  入力文字列の先頭に追加するプレフィックス
      suffix:  入力文字列の末尾に追加するサフィックス
      add_eos: Trueの場合、出力シーケンスの末尾にEOSトークンを追加

    Returns:
      入力文字列に対応するトークン列
    """
    int_list = [self._spm_processor.bos_id()]
    int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
    if add_eos:
      int_list.append(self._spm_processor.eos_id())

    return jnp.array(int_list, dtype=jnp.int32)

  def tokenize_tf_op(
      self,
      str_tensor: tf.Tensor,
      prefix: str = '',
      suffix: str = '',
      add_eos: bool = True,
  ) -> tf.Tensor:
    """tokenize関数のTensorFlow演算子版"""
    encoded = tf.numpy_function(
        self.tokenize,
        [str_tensor, prefix, suffix, add_eos],
        tf.int32)
    encoded.set_shape([None])
    return encoded

  def to_string(self, tokens: jax.Array) -> str:
    """トークン列を文字列に変換"""
    return self._spm_processor.EncodeIds(tokens.tolist())

GriffinTokenizerクラスは、SentencePieceProcessorのラッパークラスです。主要なメソッドは以下の通りです。

  • __init__: コンストラクタ。SentencePieceProcessorオブジェクトを受け取ります。
  • pad_id: パッドトークンのIDを返すプロパティ。
  • tokenize: 文字列をトークン化するメソッド。プレフィックスとサフィックスの追加、EOSトークンの追加が可能。
  • tokenize_tf_op: tokenizeメソッドのTensorFlow演算子版。
  • to_string: トークン列を文字列に変換するメソッド。

これで、MTNTデータセット用のカスタムトークナイザーが完成しました。実際にデータセットに適用してみましょう。

def tokenize_source(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(
      example,
      prefix='Translate this into French:\n',
      suffix='\n',
      add_eos=False
  )

def tokenize_destination(tokenizer, example: tf.Tensor):
  return tokenizer.tokenize_tf_op(example, add_eos=True)

tokenizer = GriffinTokenizer(vocab)
ds = tfds.load("mtnt/en-fr",split="train")
ds = ds.take(2)
ds = ds.map(lambda x: {
    'src': tokenize_source(tokenizer, x['src']),
    'dst': tokenize_destination(tokenizer, x['dst'])
  })
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

まず、原文(src)とフランス語訳(dst)をトークン化する関数を定義しています。

  • tokenize_source: 原文をトークン化。プレフィックスとサフィックスを追加し、EOSトークンは追加しない。
  • tokenize_destination: フランス語訳をトークン化。EOSトークンを追加。

次に、GriffinTokenizerオブジェクトを作成し、データセットの各サンプルに適用しています。

  • tfds.loadでMTNTデータセットをロード
  • ds.takeで先頭の2サンプルを取得
  • ds.mapで各サンプルのsrcとdstをトークン化
  • ds.as_numpy_iteratorでNumPyイテレータに変換

最後に、トークン化された各サンプルの情報を表示しています。原文とフランス語訳がトークンIDのリストに変換されていることが確認できます。

データローダー

これで、データセットの準備ができました。あとは、これらをまとめてデータローダーを構築するだけです。

@chex.dataclass(frozen=True)
class TrainingInput:
  # モデルに入力されるトークン
  input_tokens: jax.Array

  # 損失計算の対象となるトークンを決定するマスク
  target_mask: jax.Array

class DatasetSplit(enum.Enum):
  TRAIN = 'train'
  VALIDATION = 'valid'

class MTNTDatasetBuilder:
  """MTNTデータセット用のデータローダー"""

  N_ITEMS = {DatasetSplit.TRAIN: 35_692, DatasetSplit.VALIDATION: 811}

  BUFFER_SIZE_SHUFFLE = 10_000
  TRANSLATION_PREFIX = 'Translate this into French:\n'
  TRANSLATION_SUFFIX = '\n'

  def __init__(self,
               tokenizer : GriffinTokenizer,
               max_seq_len: int):
    """コンストラクタ

    Args:
      tokenizer: 使用するトークナイザー
      max_seq_len: バッチ内の各シーケンスの最大長
    """
    self._tokenizer = tokenizer
    self._base_data = {
        DatasetSplit.TRAIN: tfds.load("mtnt/en-fr",split="train"),
        DatasetSplit.VALIDATION: tfds.load("mtnt/en-fr",split="valid"),
    }
    self._max_seq_len = max_seq_len

  def _tokenize_source(self, example: tf.Tensor):
    """原文のトークン化関数"""
    return self._tokenizer.tokenize_tf_op(
        example, prefix=self.TRANSLATION_PREFIX, suffix=self.TRANSLATION_SUFFIX,
        add_eos=False
    )

  def _tokenize_destination(self, example: tf.Tensor):
    """フランス語訳のトークン化関数"""
    return self._tokenizer.tokenize_tf_op(example, add_eos=True)

  def _pad_up_to_max_len(self,
                         input_tensor: tf.Tensor,
                         pad_value: int | bool,
                         ) -> tf.Tensor:
    """バッチ内のシーケンス長に合わせてテンソルをパディング"""
    seq_len = tf.shape(input_tensor)[0]
    to_pad = tf.maximum(self._max_seq_len - seq_len, 0)
    return tf.pad(
        input_tensor, [[0, to_pad]], mode='CONSTANT', constant_values=pad_value,
    )

  def _to_training_input(
      self,
      src_tokens: jax.Array,
      dst_tokens: jax.Array,
  ) -> TrainingInput:
    """原文とフランス語訳のトークン列からTrainingInputを作成"""

    # モデルに入力するシーケンスは、原文とフランス語訳を連結したもの
    tokens = tf.concat([src_tokens, dst_tokens], axis=0)

    # 原文(入力)トークンに基づいてモデルを更新しないようにするため、
    # 各入力にターゲットマスクを追加
    q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)
    a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)
    mask = tf.concat([q_mask, a_mask], axis=0)

    # 出力トークン列がターゲットシーケンス長より短い場合は、
    # パッドトークンを追加
    tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)

    # パッドトークンに対してはバックワードを行わない
    mask = self._pad_up_to_max_len(mask, False)

    return TrainingInput(input_tokens=tokens, target_mask=mask)

  def get_train_dataset(self, batch_size: int, num_epochs: int):
    """学習データセットの作成"""

    # 各サンプルをトークン化
    ds = self._base_data[DatasetSplit.TRAIN].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )

    # TrainingInputに変換
    ds = ds.map(lambda x, y: self._to_training_input(x, y))

    # 長すぎるサンプルを除外
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)

    # データセットをシャッフル
    ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)

    # 必要に応じて繰り返し
    ds = ds.repeat(num_epochs)

    # バッチ作成
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

  def get_validation_dataset(self, batch_size: int):
    """検証データセットの作成"""

    # 学習データセットと同様だが、シャッフルと繰り返しは行わない
    ds = self._base_data[DatasetSplit.VALIDATION].map(
        lambda x : (self._tokenize_source(x['src']),
                    self._tokenize_destination(x['dst']))
    )
    ds = ds.map(lambda x, y: self._to_training_input(x, y))
    ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

MTNTDatasetBuilderクラスは、MTNTデータセットを読み込み、モデルの学習に適した形式に変換するためのクラスです。主要なメソッドは以下の通りです。

  • __init__: コンストラクタ。トークナイザーとシーケンスの最大長を受け取ります。
  • _tokenize_source: 原文をトークン化するメソッド。
  • _tokenize_destination: フランス語訳をトークン化するメソッド。
  • _pad_up_to_max_len: シーケンスをバッチ内の最大長にパディングするメソッド。
  • _to_training_input: 原文とフランス語訳のトークン列からTrainingInputオブジェクトを作成するメソッド。
  • get_train_dataset: 学習データセットを作成するメソッド。
  • get_validation_dataset: 検証データセットを作成するメソッド。

_to_training_inputメソッドでは、以下の処理を行っています。

  1. 原文とフランス語訳のトークン列を連結し、モデルへの入力シーケンスを作成。
  2. 原文トークンにはゼロマスク、フランス語訳トークンにはマスクを作成し、連結。
  3. シーケンスをバッチ内の最大長にパディング。
  4. パッドトークンに対してはマスクを適用。

get_train_datasetとget_validation_datasetメソッドでは、データセットを読み込み、トークン化、TrainingInputへの変換、シャッフル、バッチ化などの前処理を行っています。

それでは、実際にデータローダーを使ってみましょう。

dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)
ds = dataset_builder.get_train_dataset(3, 1)
ds = ds.take(2)
ds = ds.as_numpy_iterator()

for idx, example in enumerate(ds):
  print(f'Example {idx}:')
  for key, val in example.items():
    print(f'{key}: {val}')
  print()

MTNTDatasetBuilderオブジェクトを作成し、シーケンスの最大長を20に設定しています。
get_train_datasetメソッドで学習データセットを作成し、バッチサイズを3、エポック数を1に設定しています。

ds.takeで先頭の2バッチを取得し、NumPyイテレータに変換後、各バッチの情報を表示しています。

input_tokensには原文とフランス語訳が連結されたトークン列が、target_maskにはフランス語訳部分のマスクが格納されています。これらがモデルの入力となります。

Step 2 : グリフィンモデルのファインチューニング

準備

まずはモデルをロードしましょう。recurrentgemma.GriffinConfig.from_flax_params_or_variables関数を使って、チェックポイントから自動的に正しい設定をロードできます。
※ここでのモデルの語彙数は、今回のリリースで未使用のトークンがあるため、入力の埋め込みの数よりも小さいことに注意してください。

では、このモデルはフランス語を翻訳できるでしょうか?試してみましょう!

sampler = recurrentgemma.Sampler(model=model, vocab=vocab, params=params)
output = sampler(
  ["Translate this into French:\nHello, my name is Morgane.\n"],
  # 生成時のステップ数
  total_generation_steps=30,
)
print(output.text[0])

ファインチューニング前のモデルに英語の文章を入力し、フランス語への翻訳を試してみました。
まだ十分な翻訳ができていないようですね。ファインチューニングによって翻訳性能を向上させていきましょう。

順伝播と損失関数

recurrentgemma.Griffinクラスは、flax.linen.Moduleを継承しています。このクラスには、以下の2つの重要なメソッドがあります。

  • init: モデルのパラメータを初期化します。
  • apply: 与えられたパラメータを使ってモデルの__call__関数を実行します。

今回は事前学習済みの重みを使用するので、initは使用しません。

applyを使って、順伝播と損失計算を行うforward_and_loss_fn関数を作成しましょう。

def forward_and_loss_fn(
    params,
    *,
    model: recurrentgemma.Griffin,
    input_tokens: jax.Array,            # 形状 [B, L]
    input_mask: jax.Array,              # 形状 [B, L]
    positions: jax.Array,               # 形状 [B, L]
) -> jax.Array:
  """順伝播と損失関数

  Args:
    params: モデルの入力パラメータ
    model: 呼び出すGriffinモデル
    input_tokens: 入力トークン列, 形状 [B, L]
    input_mask: 損失計算で無視するトークン, 形状 [B, L]
    positions: 各トークンの相対位置, 形状 [B, L]

  Returns:
    次のトークン予測タスクのソフトマックス交差エントロピー損失
  """

  # 入力データに対して順伝播
  # ここではattention cacheは不要
  logits, _ = model.apply(
        {"params": params},
        input_tokens,
        positions,
        None,              # attention cacheはNone
    )

  # 最後のステップはターゲットに含まれないため除外
  logits = logits[0, :-1]

  # 同様に、最初のトークンは予測できないため除外
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # ターゲットラベルをone-hot表現に変換
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # 不要なトークンは更新しない
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]

  # 正規化係数
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # 負の対数尤度損失を返す
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

forward_and_loss_fn関数では以下の処理を行っています。

  1. model.applyでinput_tokensを入力として順伝播を行い、予測logitsを取得。
  2. 最後のステップとターゲットの最初のトークンを除外。
  3. ターゲットラベルをone-hot表現に変換。
  4. 損失計算に含めないトークンのone-hotベクトルをゼロにマスク。
  5. 正規化係数を計算。
  6. ソフトマックス交差エントロピー損失を計算して返す。

次に、backward_passを行い、モデルのパラメータを更新するtrain_step関数を作成します。

Params = dict[str, Any]

def get_positions(example: jax.Array, pad_id : int) -> jax.Array:
  """トークンから位置ベクトルを作成"""
  pad_mask = example != pad_id
  positions = jnp.cumsum(pad_mask, axis=-1)
  # 最初の有効な位置以降の全てのポジションから1を引く(0-indexed)
  positions = positions - (positions >= 1)
  return positions

@functools.partial(
    jax.jit,
    static_argnames=['model', 'optimizer'],
    donate_argnames=['params', 'opt_state'],
)
def train_step(
    model: recurrentgemma.Griffin,
    params: Params,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
    pad_id: int,
    example: TrainingInput,
) -> tuple[jax.Array, Params, optax.OptState]:
  """学習ステップ

  Args:
    model: Griffinモデル
    params: モデルの入力パラメータ
    optimizer: optaxオプティマイザ
    opt_state: オプティマイザの状態
    pad_id: パッドトークンのID
    example: 入力バッチ

  Returns:
    学習損失, 更新されたパラメータ, 更新されたオプティマイザの状態
  """

  positions = get_positions(example.input_tokens, pad_id)

  # 順伝播と逆伝播
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=positions,
  )
  # パラメータの更新
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

train_step関数では以下の処理を行っています。

  1. get_positions関数で入力トークン列から位置ベクトルを作成。
  2. jax.value_and_gradでforward_and_loss_fnの順伝播と逆伝播を行い、損失と勾配を計算。
  3. オプティマイザのupdateメソッドで勾配を適用し、パラメータを更新。
  4. 学習損失、更新されたパラメータ、更新されたオプティマイザの状態を返す。

同様に、逆伝播を行わないvalidation_step関数を作成します。

@functools.partial(jax.jit, static_argnames=['model'])
def validation_step(
    model: recurrentgemma.Griffin,
    params: Params,
    pad_id: int,
    example: TrainingInput,
) -> jax.Array:
  return forward_and_loss_fn(
      params,
      model=model,
      input_tokens=example.input_tokens,
      input_mask=example.target_mask,
      positions=get_positions(example.input_tokens, pad_id),
  )

validation_step関数は、train_stepから逆伝播の部分を取り除いたものになっています。

いよいよ、学習ループ自体を実装します。

@chex.dataclass(frozen=True)
class TrainingConfig:
  optimizer: str
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  weight_decay: float = 0.0
  b2: float = 0.99
  eps: float = 1e-8
  max_steps: int | None = None

def griffin_weight_decay_mask(params_like: optax.Params) -> Any:
  # RGLRUとembeddings、biasesにはweight decayを適用しない
  def enable_weight_decay(path: list[str], _: Any) -> bool:
    # LRUとembedderのパラメータ
    path = [dict_key.key for dict_key in path]
    if 'rg_lru' in path or 'embedder' in path:
      return False
    # 全てのbiasesとscales
    if path[-1] in ('b', 'scale'):
      return False
    return True

  return jax.tree_util.tree_map_with_path(enable_weight_decay, params_like)

def train_loop(
    model: recurrentgemma.Griffin,
    params: Params,
    dataset_builder: MTNTDatasetBuilder,
    training_cfg: TrainingConfig,
):
  if training_cfg.optimizer == 'adamw':
    # より良い最適化のためにAdam-Wを使用
    optimizer = optax.adamw(
        learning_rate=training_cfg.learning_rate,
        b2=training_cfg.b2,
        eps=training_cfg.eps,
        weight_decay=training_cfg.weight_decay,
        mask=griffin_weight_decay_mask,
    )
  else:
    # メモリ節約のため、SGDオプティマイザを使用
    optimizer = optax.sgd(learning_rate=training_cfg.learning_rate)

  opt_state = jax.jit(optimizer.init)(params)

  # 学習データセットの作成
  train_ds = dataset_builder.get_train_dataset(
      batch_size=training_cfg.batch_size, num_epochs=training_cfg.num_epochs
  )

  train_ds = train_ds.as_numpy_iterator()

  # 検証データセットの作成(このデモでは少数のサンプルに制限)
  validation_ds = dataset_builder.get_validation_dataset(
      batch_size=training_cfg.batch_size
  )
  validation_ds = validation_ds.take(50)

  n_steps = 0
  avg_loss = 0

  # 最初の検証損失
  n_steps_eval = 0
  eval_loss = 0
  val_iterator = validation_ds.as_numpy_iterator()
  for val_example in val_iterator:
    eval_loss += validation_step(
        model, params, dataset_builder._tokenizer.pad_id, val_example
    )
    n_steps_eval += 1
  print(f"Start, validation loss: {eval_loss/n_steps_eval}")

  for train_example in train_ds:
    train_loss, params, opt_state = train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=dataset_builder._tokenizer.pad_id,
        example=train_example,
    )

    n_steps += 1
    avg_loss += train_loss
    if n_steps % training_cfg.eval_every_n == 0:
      eval_loss = 0

      n_steps_eval = 0
      val_iterator = validation_ds.as_numpy_iterator()
      for val_example in val_iterator:
        eval_loss += validation_step(
            model,
            params,
            dataset_builder._tokenizer.pad_id,
            val_example,
        )
        n_steps_eval +=1
      avg_loss /= training_cfg.eval_every_n
      eval_loss /= n_steps_eval
      print(f"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}")
      avg_loss=0
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

train_loop関数の処理の流れは以下の通りです。

  1. TrainingConfigに従って、オプティマイザを設定。
    • Adam-Wの場合は、griffin_weight_decay_mask関数でweight decayを適用するパラメータを指定。
    • SGDの場合は、learning_rateのみ指定。
  2. オプティマイザの状態を初期化。
  3. 学習データセットと検証データセットを作成。
  4. 最初の検証損失を計算して表示。
  5. 学習データセットをイテレーションし、各バッチに対して以下を実行。
    • train_step関数で順伝播、逆伝播、パラメータ更新を行う。
    • 一定ステップ(eval_every_n)ごとに検証損失を計算して表示。
    • max_stepsに達したら学習を終了。
  6. 学習後のパラメータを返す。

少数のステップで実際にモデルをファインチューニングしてみましょう。

# メモリに収まるようにシーケンス長を小さく設定
SEQ_SIZE = 25
tokenizer = GriffinTokenizer(vocab)
dataset_builder= MTNTDatasetBuilder(tokenizer, SEQ_SIZE)
training_cfg = TrainingConfig(
    optimizer='sgd',
    learning_rate=2e-3,
    b2=0.96,
    num_epochs=1,
    eval_every_n=20,
    batch_size=1,
    max_steps=100,
)

trained_params = train_loop(
    model=model,
    params=params,
    dataset_builder=dataset_builder,
    training_cfg=training_cfg,
)

学習損失と検証損失の両方が下がっています。うまくいっているようですね。
では、先ほどと同じ例文で再度試してみましょう。
学習時の入力形式に合わせるため、プレフィックスの「Translate this into French:\n」と、末尾の改行文字を忘れずに付けてください。
これらは翻訳の開始をモデルに知らせるための合図となります。

sampler.params = trained_params
output = sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=30,
)
print(output.text[0])

ファインチューニング後のモデルは、かなりまともなフランス語訳を生成できるようになりました!
わずか100ステップの学習でこれだけの効果が得られるのは素晴らしいですね。

もちろん、本格的な翻訳タスクには、より大規模なデータセットと長い学習が必要になります。
しかし、事前学習済みのグリフィンモデルをベースにファインチューニングすることで、限られたリソースでも十分な性能が得られることが分かります。

以上で、FlaxとグリフィンモデルによるJAXベースの翻訳ファインチューニングチュートリアルは終了です。
JAXとFlaxの柔軟性と、グリフィンのような強力な事前学習済みモデルを組み合わせることで、効率的かつ高性能な学習が実現できることを実感いただけたと思います。

今回学んだテクニックを応用して、皆さんのプロジェクトをさらに発展させてください!

Kaggleノートブック

Flaxを使用した2Bグリフィンモデルのファインチューニングチュートリアル
Explore and run machine learning code with Kaggle Notebooks | Using data from RecurrentGemma

GoogleColabノートブック

Hugging Face 版

Google Colab

Kagglehub版

Google Colab

参考サイト

GitHub - Sunwood-ai-labs/recurrentgemma: Open weights language model from Google DeepMind, based on Griffin.
Open weights language model from Google DeepMind, based on Griffin. - Sunwood-ai-labs/recurrentgemma

コメント

タイトルとURLをコピーしました