はじめに
このノートブックは、lua-cgemmaというライブラリを使用しています。lua-cgemmaは、gemma.cppのLuaバインディングを提供するライブラリです。gemma.cppは、Google Highway Libraryを利用してポータブルSIMDを活用したCPU推論を行うため、アクセラレータなしでもスムーズに実行できます。詳細については、README.mdを参照してください。
こちらの記事もおすすめ
lua-cgemma、gemma.cpp、Google Highway Libraryの解説
近年、AI分野での発展は目覚ましく、特に自然言語処理の分野ではTransformerベースのモデルが大きな成功を収めています。そんな中、Googleが開発したGemini(ジェミニ)というモデルが注目を集めています。このGeminiモデルをC++で実装したのが、今回紹介するgemma.cppです。さらに、gemma.cppをLuaから簡単に利用できるようにしたのがlua-cgemmaです。また、これらのライブラリでは、Google Highway Libraryという、SIMDを活用した高速な計算を可能にするライブラリが使われています。
gemma.cpp: GeminiモデルのC++実装
gemma.cppは、Googleが開発したGeminiという大規模言語モデルをC++で実装したライブラリです。Geminiは、Transformerアーキテクチャをベースとした大規模な言語モデルで、自然言語の理解や生成に優れた性能を示しています。
C++で実装されているgemma.cppは、Pythonなどの高級言語に比べて高速に動作します。また、gemma.cppではSIMD(Single Instruction Multiple Data)という技術を活用することで、さらに高速な処理を実現しています。
lua-cgemma: gemma.cppのLuaバインディング
C++で書かれたgemma.cppを、Luaというスクリプト言語から簡単に利用できるようにしたのがlua-cgemmaです。Luaは軽量で高速なスクリプト言語で、ゲーム開発などでよく使われています。
lua-cgemmaを使うことで、Luaのシンプルで分かりやすい文法を使って、gemma.cppの機能を呼び出すことができます。これにより、C++の知識がなくてもGeminiモデルを利用したアプリケーションを開発することが可能になります。
Google Highway Library: SIMDを活用した高速計算ライブラリ
Google Highway Libraryは、SIMDを活用して高速な計算を実現するC++のライブラリです。SIMDとは、Single Instruction Multiple Dataの略で、一つの命令で複数のデータを同時に処理する技術のことを指します。
現代のCPUは、SIMDを利用することで、同じ命令を複数のデータに対して並列に実行できます。これにより、特に行列計算などで大きな速度向上が期待できます。Google Highway Libraryは、このSIMDを簡単に利用できるようにしたライブラリで、gemma.cppやlua-cgemmaでも活用されています。
以上が、lua-cgemma、gemma.cpp、Google Highway Libraryの概要です。これらのライブラリを使うことで、高速かつ高性能な自然言語処理システムを構築することができます。
環境設定を確認
import os
PRIVATE_TEST = True if os.getenv('KAGGLE_IS_COMPETITION_RERUN') else False
PRIVATE_TEST
上記のコードは、環境変数KAGGLE_IS_COMPETITION_RERUN
の値を確認し、プライベートテストかどうかを判定しています。プライベートテストの場合はPRIVATE_TEST
がTrue
に、そうでない場合はFalse
になります。
インストール手順
lupaとportionのインストール
!pip install lupa --no-index --find-links=file:///kaggle/input/aimo-lua-cgemma-environment/lupa/
!pip install portion --no-index --find-links=file:///kaggle/input/aimo-lua-cgemma-environment/portion/
最初のステップでは、lupa
とportion
というPythonパッケージをインストールします。これらのパッケージは、事前にダウンロードされたファイルから直接インストールされます。
lua-cgemmaとその他の依存関係の解凍
import os
os.chdir("/usr/local")
!tar xzvf /kaggle/input/aimo-lua-cgemma-environment/archive.tar.gz && ldconfig
os.chdir("/kaggle/working/")
次に、コンパイル済みのlua-cgemmaとその他の依存関係を含むarchive.tar.gz
ファイルを解凍します。解凍されたファイルは/usr/local
ディレクトリに配置されます。
モジュールのインストール確認
import sys
dlflags = sys.getdlopenflags()
sys.setdlopenflags(258)
from lupa.luajit21 import LuaRuntime
sys.setdlopenflags(dlflags)
lua = LuaRuntime(unpack_returned_tuples=True)
lua.execute('require("cgemma").info()')
最後に、lua-cgemmaモジュールが正しくインストールされているかを確認します。LuaRuntime
を使用して、Luaコードを実行し、cgemmaモジュールの情報を表示します。
事前学習済みモデルの読み込み
グローバルなGemmaインスタンスとチャットセッションを作成し、generate
、fix_exception
、evaluate
関数を定義します。
lua.execute('''
gemma, err = require("cgemma").new({
tokenizer = "/kaggle/input/codegemma/gemmacpp/7b-it-sfp/1/tokenizer.spm",
model = "7b-it",
weights = "/kaggle/input/codegemma/gemmacpp/7b-it-sfp/1/7b-it-sfp.sbs"
})
if not gemma then
error("Opoos! "..err)
end
session, seed = gemma:session({
max_tokens = 8192,
temperature = 0.1
})
if not session then
error("Opoos! "..seed)
end
''')
generate = lua.eval('''
function(problem)
local ok, err = session:load("/kaggle/input/aimo-lua-cgemma-environment/dump.bin")
if not ok then
error("Opoos! "..err)
end
local template = "Problem: %s\\nInstruction: Write a Python program to print the numerical answer to this math problem.\\nNote: The code must show detailed problem-solving step by step, no user input, no description, and no explanation. As shown in these examples above, make effective use of modules such as `sympy`, `scipy`, `portion`, `fractions`, and `ortools` to enhance the readability of the code."
local reply, err = session(string.format(template, problem))
if not reply then
error("Opoos! "..err)
end
return reply
end
''')
fix_exception = lua.eval('''
function(typ, err)
local template = "Instruction: Rewrite this program to avoid `%s` exception: %s"
local reply, err = session(string.format(template, typ, err))
if not reply then
error("Opoos! "..err)
end
return reply
end
''')
evaluate = lua.eval('''
function()
local reply, err = session("Instruction: Write the output of this code snippet.\\nNote: no description, and no explanation.")
if not reply then
error("Opoos! "..err)
end
return reply
end
''')
このコードでは、事前学習済みのCodeGemmaモデルを読み込み、generate
、fix_exception
、evaluate
関数を定義しています。
generate
関数は、数学の問題を受け取り、Pythonのコードを生成して答えを出力します。fix_exception
関数は、例外が発生した場合にプログラムを書き直して例外を回避します。evaluate
関数は、生成されたコードスニペットの出力を返します。
ユーティリティ関数の定義
import re
import random
import signal
from io import StringIO
from contextlib import redirect_stdout
def extract_code_lines(code):
flag = False
for line in code.splitlines():
if flag:
if line.startswith("```"):
break
else:
yield line
elif line.lower().startswith("```python"):
flag = True
def extract_result(reply):
m = re.match("^[^0-9+\\-]*([+\\-]?[0-9]+(\\.[0-9]*)?)", reply)
try:
result = int(round(float(m[1])))
except Exception:
result = random.randrange(1000)
return result
def execute(code, timeout=60):
f = StringIO()
with redirect_stdout(f):
try:
signal.alarm(timeout)
exec(code, {
"__name__": "__main__"
})
finally:
signal.alarm(0)
return f.getvalue()
def sigalrm(signum, frame):
raise TimeoutError("timeout")
signal.signal(signal.SIGALRM, sigalrm)
ここでは、以下のようなユーティリティ関数を定義しています。
extract_code_lines
: コードブロックからPythonのコードを抽出します。extract_result
: 生成された出力から数値の結果を抽出します。execute
: Pythonのコードを実行し、出力を取得します。タイムアウト機能も備えています。sigalrm
: タイムアウトが発生した際に呼び出されるシグナルハンドラです。
学習データサンプルでの推論
import itertools
import pandas as pd
from tqdm import tqdm
from IPython.display import display, Markdown
def solve(problem):
display(Markdown(problem))
reply = generate(problem)
output = None
display(Markdown("## Code"))
try:
for i in itertools.count():
try:
display(Markdown(reply))
output = execute("\n".join(extract_code_lines(reply)))
break
except TimeoutError:
raise
except Exception as e:
if e.__traceback__.tb_next and i < 2:
typ = type(e).__name__
err = "{} (line {})".format(str(e), e.__traceback__.tb_next.tb_next.tb_lineno) if e.__traceback__.tb_next.tb_next else str(e)
print(f"fix unhandled {typ}: {err}", file=sys.stderr)
reply = fix_exception(typ, err)
display(Markdown("## Rewritten code " + str(i + 1)))
else:
raise
answer = int(round(float(eval(output)))) % 1000
display(Markdown("## Python execution"))
display(Markdown("**output:** " + output))
display(Markdown("**answer:** " + str(answer)))
except Exception as e:
print("python execution error:", str(e), file=sys.stderr)
display(Markdown("## Gemma evaluation"))
reply = evaluate()
answer = extract_result(reply) % 1000
display(Markdown("**reply:** " + reply))
display(Markdown("**answer:** " + str(answer)))
if not PRIVATE_TEST:
df = pd.read_csv("/kaggle/input/ai-mathematical-olympiad-prize/train.csv").sample(n=2)
display(df)
for id, problem in tqdm(list(zip(df["id"], df["problem"]))):
display(Markdown("# Problem " + id))
solve(problem)
この部分では、学習データのサンプルを使って推論を行います。
-
solve
関数を定義します。この関数は以下の処理を行います。- 問題を表示します。
generate
関数を使って問題に対するPythonコードを生成します。- 生成されたコードを実行し、答えを計算します。
- 実行結果と答えを表示します。
- 例外が発生した場合は、
fix_exception
関数を使ってコードを書き直し、再実行します。 - それでも例外が発生する場合は、
evaluate
関数を使って答えを推定します。
-
プライベートテストでない場合、学習データからランダムに2つのサンプルを選択します。
-
選択したサンプルに対して
solve
関数を適用し、問題の解決を行います。
追加の問題の解決
if not PRIVATE_TEST:
solve("There are two numbers $x$ and $y$, it is known that $x+y=35$ and $2x+4y=94$, what is $x-y$?")
ここでは、追加の問題を解決しています。プライベートテストでない場合、solve
関数を使って新しい問題を解決します。
予測の提出
def solve(problem):
reply = generate(problem)
output = None
try:
for i in itertools.count():
try:
output = execute("\n".join(extract_code_lines(reply)))
break
except TimeoutError:
raise
except Exception as e:
if e.__traceback__.tb_next and i < 2:
typ = type(e).__name__
err = "{} (line {})".format(str(e), e.__traceback__.tb_next.tb_next.tb_lineno) if e.__traceback__.tb_next.tb_next else str(e)
print(f"fix unhandled {typ}: {err}", file=sys.stderr)
reply = fix_exception(typ, err)
else:
raise
return int(round(float(eval(output)))) % 1000
except Exception as e:
print("python execution error:", str(e), file=sys.stderr)
return extract_result(evaluate()) % 1000
import aimo
env = aimo.make_env()
for test, submission in tqdm(env.iter_test()):
submission["answer"] = solve(test['problem'][0])
env.predict(submission)
最後に、予測を提出するための処理を行います。
-
solve
関数を再定義します。この関数は以下の処理を行います。generate
関数を使って問題に対するPythonコードを生成します。- 生成されたコードを実行し、答えを計算します。
- 例外が発生した場合は、
fix_exception
関数を使ってコードを書き直し、再実行します。 - それでも例外が発生する場合は、
evaluate
関数を使って答えを推定します。
-
aimo
モジュールを使って、提出用の環境を作成します。 -
テストデータに対して
solve
関数を適用し、答えを計算します。 -
計算された答えを提出します。
以上が、CodeGemmaを使ったAIMOコンペのベースラインノートブックの解説です。初心者の方でも、このノートブックの内容を理解し、活用できるようになることを願っています。
コメント