初心者向け解説: CodeGemmaを使ったAI数学オリンピックコンペのベースラインノートブック

AI・機械学習

はじめに

このノートブックは、lua-cgemmaというライブラリを使用しています。lua-cgemmaは、gemma.cppのLuaバインディングを提供するライブラリです。gemma.cppは、Google Highway Libraryを利用してポータブルSIMDを活用したCPU推論を行うため、アクセラレータなしでもスムーズに実行できます。詳細については、README.mdを参照してください。


こちらの記事もおすすめ

JAXとWeights & Biasesを用いたGemma Instruct 2BモデルのFinetuning入門
はじめにこのノートブックでは、JAXをバックエンドに使用して、Kaggleの"AI Mathematical Olympiad"コンペティションに向けてGemma Instruct 2Bモデルをfinetuningする方法について解説します...
JAXとWandbとSelf-Consistencyを使ったGemma Instruct 2Bモデルのファインチューニング入門
このノートブックでは、Kaggleの"AI Mathematical Olympiad"コンペティションに向けて、JAXをバックエンドに使用してGemma Instruct 2Bモデルをファインチューニングする方法を解説します。また、Wei...

lua-cgemma、gemma.cpp、Google Highway Libraryの解説

近年、AI分野での発展は目覚ましく、特に自然言語処理の分野ではTransformerベースのモデルが大きな成功を収めています。そんな中、Googleが開発したGemini(ジェミニ)というモデルが注目を集めています。このGeminiモデルをC++で実装したのが、今回紹介するgemma.cppです。さらに、gemma.cppをLuaから簡単に利用できるようにしたのがlua-cgemmaです。また、これらのライブラリでは、Google Highway Libraryという、SIMDを活用した高速な計算を可能にするライブラリが使われています。

gemma.cpp: GeminiモデルのC++実装

gemma.cppは、Googleが開発したGeminiという大規模言語モデルをC++で実装したライブラリです。Geminiは、Transformerアーキテクチャをベースとした大規模な言語モデルで、自然言語の理解や生成に優れた性能を示しています。

C++で実装されているgemma.cppは、Pythonなどの高級言語に比べて高速に動作します。また、gemma.cppではSIMD(Single Instruction Multiple Data)という技術を活用することで、さらに高速な処理を実現しています。

lua-cgemma: gemma.cppのLuaバインディング

C++で書かれたgemma.cppを、Luaというスクリプト言語から簡単に利用できるようにしたのがlua-cgemmaです。Luaは軽量で高速なスクリプト言語で、ゲーム開発などでよく使われています。

lua-cgemmaを使うことで、Luaのシンプルで分かりやすい文法を使って、gemma.cppの機能を呼び出すことができます。これにより、C++の知識がなくてもGeminiモデルを利用したアプリケーションを開発することが可能になります。

Google Highway Library: SIMDを活用した高速計算ライブラリ

Google Highway Libraryは、SIMDを活用して高速な計算を実現するC++のライブラリです。SIMDとは、Single Instruction Multiple Dataの略で、一つの命令で複数のデータを同時に処理する技術のことを指します。

現代のCPUは、SIMDを利用することで、同じ命令を複数のデータに対して並列に実行できます。これにより、特に行列計算などで大きな速度向上が期待できます。Google Highway Libraryは、このSIMDを簡単に利用できるようにしたライブラリで、gemma.cppやlua-cgemmaでも活用されています。

以上が、lua-cgemma、gemma.cpp、Google Highway Libraryの概要です。これらのライブラリを使うことで、高速かつ高性能な自然言語処理システムを構築することができます。

環境設定を確認

import os

PRIVATE_TEST = True if os.getenv('KAGGLE_IS_COMPETITION_RERUN') else False
PRIVATE_TEST

上記のコードは、環境変数KAGGLE_IS_COMPETITION_RERUNの値を確認し、プライベートテストかどうかを判定しています。プライベートテストの場合はPRIVATE_TESTTrueに、そうでない場合はFalseになります。

インストール手順

lupaとportionのインストール

!pip install lupa --no-index --find-links=file:///kaggle/input/aimo-lua-cgemma-environment/lupa/
!pip install portion --no-index --find-links=file:///kaggle/input/aimo-lua-cgemma-environment/portion/

最初のステップでは、lupaportionというPythonパッケージをインストールします。これらのパッケージは、事前にダウンロードされたファイルから直接インストールされます。

lua-cgemmaとその他の依存関係の解凍

import os

os.chdir("/usr/local")
!tar xzvf /kaggle/input/aimo-lua-cgemma-environment/archive.tar.gz && ldconfig
os.chdir("/kaggle/working/")

次に、コンパイル済みのlua-cgemmaとその他の依存関係を含むarchive.tar.gzファイルを解凍します。解凍されたファイルは/usr/localディレクトリに配置されます。

モジュールのインストール確認

import sys
dlflags = sys.getdlopenflags()
sys.setdlopenflags(258)
from lupa.luajit21 import LuaRuntime
sys.setdlopenflags(dlflags)

lua = LuaRuntime(unpack_returned_tuples=True)
lua.execute('require("cgemma").info()')

最後に、lua-cgemmaモジュールが正しくインストールされているかを確認します。LuaRuntimeを使用して、Luaコードを実行し、cgemmaモジュールの情報を表示します。

事前学習済みモデルの読み込み

グローバルなGemmaインスタンスとチャットセッションを作成し、generatefix_exceptionevaluate関数を定義します。

lua.execute('''
    gemma, err = require("cgemma").new({
      tokenizer = "/kaggle/input/codegemma/gemmacpp/7b-it-sfp/1/tokenizer.spm",
      model = "7b-it",
      weights = "/kaggle/input/codegemma/gemmacpp/7b-it-sfp/1/7b-it-sfp.sbs"
    })
    if not gemma then
      error("Opoos! "..err)
    end
    session, seed = gemma:session({
        max_tokens = 8192,
        temperature = 0.1
    })
    if not session then
      error("Opoos! "..seed)
    end
''')

generate = lua.eval('''
    function(problem)
        local ok, err = session:load("/kaggle/input/aimo-lua-cgemma-environment/dump.bin")
        if not ok then
            error("Opoos! "..err)
        end
        local template = "Problem: %s\\nInstruction: Write a Python program to print the numerical answer to this math problem.\\nNote: The code must show detailed problem-solving step by step, no user input, no description, and no explanation. As shown in these examples above, make effective use of modules such as `sympy`, `scipy`, `portion`, `fractions`, and `ortools` to enhance the readability of the code."
        local reply, err = session(string.format(template, problem))
        if not reply then
          error("Opoos! "..err)
        end
        return reply
    end
''')

fix_exception = lua.eval('''
    function(typ, err)
        local template = "Instruction: Rewrite this program to avoid `%s` exception: %s"
        local reply, err = session(string.format(template, typ, err))
        if not reply then
          error("Opoos! "..err)
        end
        return reply
    end
''')

evaluate = lua.eval('''
    function()
        local reply, err = session("Instruction: Write the output of this code snippet.\\nNote: no description, and no explanation.")
        if not reply then
          error("Opoos! "..err)
        end
        return reply
    end
''')

このコードでは、事前学習済みのCodeGemmaモデルを読み込み、generatefix_exceptionevaluate関数を定義しています。

  • generate関数は、数学の問題を受け取り、Pythonのコードを生成して答えを出力します。
  • fix_exception関数は、例外が発生した場合にプログラムを書き直して例外を回避します。
  • evaluate関数は、生成されたコードスニペットの出力を返します。

ユーティリティ関数の定義

import re
import random
import signal
from io import StringIO
from contextlib import redirect_stdout

def extract_code_lines(code):
    flag = False
    for line in code.splitlines():
        if flag:
            if line.startswith("```"):
                break
            else:
                yield line
        elif line.lower().startswith("```python"):
            flag = True

def extract_result(reply):
    m = re.match("^[^0-9+\\-]*([+\\-]?[0-9]+(\\.[0-9]*)?)", reply)
    try:
        result = int(round(float(m[1])))
    except Exception:
        result = random.randrange(1000)
    return result

def execute(code, timeout=60):
    f = StringIO()
    with redirect_stdout(f):
        try:
            signal.alarm(timeout)
            exec(code, {
                "__name__": "__main__"
            })
        finally:
            signal.alarm(0)
    return f.getvalue()

def sigalrm(signum, frame):
    raise TimeoutError("timeout")

signal.signal(signal.SIGALRM, sigalrm)

ここでは、以下のようなユーティリティ関数を定義しています。

  • extract_code_lines: コードブロックからPythonのコードを抽出します。
  • extract_result: 生成された出力から数値の結果を抽出します。
  • execute: Pythonのコードを実行し、出力を取得します。タイムアウト機能も備えています。
  • sigalrm: タイムアウトが発生した際に呼び出されるシグナルハンドラです。

学習データサンプルでの推論

import itertools
import pandas as pd
from tqdm import tqdm
from IPython.display import display, Markdown

def solve(problem):
    display(Markdown(problem))
    reply = generate(problem)
    output = None
    display(Markdown("## Code"))
    try:
        for i in itertools.count():
            try:
                display(Markdown(reply))
                output = execute("\n".join(extract_code_lines(reply)))
                break
            except TimeoutError:
                raise
            except Exception as e:
                if e.__traceback__.tb_next and i < 2:
                    typ = type(e).__name__
                    err = "{} (line {})".format(str(e), e.__traceback__.tb_next.tb_next.tb_lineno) if e.__traceback__.tb_next.tb_next else str(e)
                    print(f"fix unhandled {typ}: {err}", file=sys.stderr)
                    reply = fix_exception(typ, err)
                    display(Markdown("## Rewritten code " + str(i + 1)))
                else:
                    raise
        answer = int(round(float(eval(output)))) % 1000
        display(Markdown("## Python execution"))
        display(Markdown("**output:** " + output))
        display(Markdown("**answer:** " + str(answer)))
    except Exception as e:
        print("python execution error:", str(e), file=sys.stderr)
        display(Markdown("## Gemma evaluation"))
        reply = evaluate()
        answer = extract_result(reply) % 1000
        display(Markdown("**reply:** " + reply))
        display(Markdown("**answer:** " + str(answer)))

if not PRIVATE_TEST:
    df = pd.read_csv("/kaggle/input/ai-mathematical-olympiad-prize/train.csv").sample(n=2)
    display(df)
    for id, problem in tqdm(list(zip(df["id"], df["problem"]))):
        display(Markdown("# Problem " + id))
        solve(problem)

この部分では、学習データのサンプルを使って推論を行います。

  1. solve関数を定義します。この関数は以下の処理を行います。

    • 問題を表示します。
    • generate関数を使って問題に対するPythonコードを生成します。
    • 生成されたコードを実行し、答えを計算します。
    • 実行結果と答えを表示します。
    • 例外が発生した場合は、fix_exception関数を使ってコードを書き直し、再実行します。
    • それでも例外が発生する場合は、evaluate関数を使って答えを推定します。
  2. プライベートテストでない場合、学習データからランダムに2つのサンプルを選択します。

  3. 選択したサンプルに対してsolve関数を適用し、問題の解決を行います。

追加の問題の解決

if not PRIVATE_TEST:
    solve("There are two numbers $x$ and $y$, it is known that $x+y=35$ and $2x+4y=94$, what is $x-y$?")

ここでは、追加の問題を解決しています。プライベートテストでない場合、solve関数を使って新しい問題を解決します。

予測の提出

def solve(problem):
    reply = generate(problem)
    output = None
    try:
        for i in itertools.count():
            try:
                output = execute("\n".join(extract_code_lines(reply)))
                break
            except TimeoutError:
                raise
            except Exception as e:
                if e.__traceback__.tb_next and i < 2:
                    typ = type(e).__name__
                    err = "{} (line {})".format(str(e), e.__traceback__.tb_next.tb_next.tb_lineno) if e.__traceback__.tb_next.tb_next else str(e)
                    print(f"fix unhandled {typ}: {err}", file=sys.stderr)
                    reply = fix_exception(typ, err)
                else:
                    raise
        return int(round(float(eval(output)))) % 1000
    except Exception as e:
        print("python execution error:", str(e), file=sys.stderr)
        return extract_result(evaluate()) % 1000

import aimo

env = aimo.make_env()
for test, submission in tqdm(env.iter_test()):
    submission["answer"] = solve(test['problem'][0])
    env.predict(submission)

最後に、予測を提出するための処理を行います。

  1. solve関数を再定義します。この関数は以下の処理を行います。

    • generate関数を使って問題に対するPythonコードを生成します。
    • 生成されたコードを実行し、答えを計算します。
    • 例外が発生した場合は、fix_exception関数を使ってコードを書き直し、再実行します。
    • それでも例外が発生する場合は、evaluate関数を使って答えを推定します。
  2. aimoモジュールを使って、提出用の環境を作成します。

  3. テストデータに対してsolve関数を適用し、答えを計算します。

  4. 計算された答えを提出します。

以上が、CodeGemmaを使ったAIMOコンペのベースラインノートブックの解説です。初心者の方でも、このノートブックの内容を理解し、活用できるようになることを願っています。

ノートブック

《JP》AIMO CodeGemma baseline
Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

参考サイト

https://www.kaggle.com/code/makimakiai/aimo-codegemma-baseline-e9aacf

コメント

タイトルとURLをコピーしました