llama.cpp × Gemma3nでlogprobs出力!詳細な確率分析(GoogleColab📒ノートブック付)

開発環境

このノートブックでは、llama.cppでGemma3nを動かし、トークンの生成確率(logprobs)を詳細に分析する方法を解説します。

このノートブックでできること

  • Gemma3nの修正済みGGUFを使用(Google公式版は動作しません!)
  • logprobs対応のサーバー実装でトークン確率を詳細取得
  • 確率分析機能でサプライズ値やパープレキシティを計算
  • 代替候補分析で生成の不確実性を可視化

⚠️ 重要:なぜUnsloth修正版が必要?

Google公式のGemma3n-E4Bはllama.cppで動作しません! Unslothチームが以下の重要なバグを修正済み:

  • チャットテンプレートとトークナイザーの修正
  • RoPEスケーリングの修正
  • 未初期化ウェイトの修正
  • BOSトークンの適切な処理

環境準備

# Google Driveをマウント
from google.colab import drive
drive.mount('/content/drive')

import os
import subprocess
import time
import requests
import json

llama.cppのセットアップ

既存のビルドがあれば復元、なければ新規ビルドを行います。

def setup_llama_cpp():
    """既存ビルドがあれば復元、なければ新規ビルド"""
    drive_build_path = "/content/drive/MyDrive/llama_cpp_build.tar.gz"

    if os.path.exists(drive_build_path):
        print("🔄 既存のビルドファイルを発見しました。復元中...")

        # llama.cppをクローン(軽量)
        if not os.path.exists('/content/llama.cpp'):
            os.system('git clone https://github.com/ggerganov/llama.cpp.git /content/llama.cpp')

        # ビルド済みファイルを復元
        os.chdir('/content/llama.cpp')
        os.system(f'tar -xzf {drive_build_path}')

        print("✅ ビルドファイルの復元が完了しました。")
        return True
    else:
        print("🔨 新規ビルドを開始します...")
        return False

# セットアップ実行
if not setup_llama_cpp():
    # 新規ビルドの場合
    print("新規ビルドを実行中...")

    # llama.cppをクローンしてビルド
    if not os.path.exists('/content/llama.cpp'):
        os.system('git clone https://github.com/ggerganov/llama.cpp.git /content/llama.cpp')

    os.chdir('/content/llama.cpp')

    # CMakeでCUDAサポートビルド
    os.system('cmake -B build -DGGML_CUDA=ON')
    os.system('cmake --build build --config Release -j$(nproc)')

    # ビルド結果を保存
    print("💾 ビルドファイルを保存中...")
    os.system('tar -czf llama_cpp_build.tar.gz build/bin/ $(find build -name "*.so" -o -name "*.a" 2>/dev/null)')
    os.system('cp llama_cpp_build.tar.gz /content/drive/MyDrive/')

    print("✅ ビルドが完了し、Driveに保存されました。")
else:
    os.chdir('/content/llama.cpp')

print("🎉 llama.cpp セットアップ完了!")

Gemma3n修正済みモデルのダウンロード

Unsloth修正済みGemma3nをダウンロードします。

def download_gemma3n_fixed():
    """Unsloth修正済みGemma3nをダウンロード"""
    model_path = "/content/gemma3n-e4b-fixed.gguf"

    # 既存ファイルをチェック
    if os.path.exists(model_path):
        file_size = os.path.getsize(model_path) / (1024**3)
        print(f"✅ 修正済みモデルが既に存在します: {model_path} ({file_size:.1f}GB)")
        return model_path

    print("📥 Gemma3n修正済みモデルをダウンロード中...")
    print("🔧 Unslothによるllama.cpp対応版をダウンロード中...")

    # huggingface_hubをインストール
    os.system('pip install huggingface_hub')

    # Unsloth修正済みE4B版をダウンロード
    download_url = "https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF/resolve/main/gemma-3n-E4B-it-UD-Q4_K_XL.gguf"

    print(f"🔄 ダウンロード中: Gemma3n E4B版(Unsloth修正済み)(約7.5GB)")
    print("   ⚠️ Google公式版は動作しませんが、これなら動きます!")
    os.system(f'wget -O {model_path} {download_url}')

    if os.path.exists(model_path):
        file_size = os.path.getsize(model_path) / (1024**3)
        print(f"✅ 修正済みモデルのダウンロード完了: {model_path} ({file_size:.1f}GB)")
        print("🎉 このモデルはllama.cppで正常に動作します!")
        return model_path
    else:
        print("❌ ダウンロードに失敗しました。")
        return None

# モデルをダウンロード
selected_model = download_gemma3n_fixed()

logprobs対応サーバーの起動

logprobs取得に対応したサーバーを起動します。

def start_llama_server_with_logprobs():
    """logprobs対応でサーバーを起動"""
    if not selected_model:
        print("❌ モデルファイルが見つかりません")
        return None

    print("🚀 logprobs対応サーバー起動中...")

    # サーバープロセスを起動
    process = subprocess.Popen([
        '/content/llama.cpp/build/bin/llama-server',
        '-m', selected_model,
        '--host', '0.0.0.0',
        '--port', '8081',
        '--n-gpu-layers', '99',
        '--ctx-size', '32768',     # Gemma 3nは32Kコンテキスト対応
        '--temp', '1.0',           # Unsloth推奨設定
        '--top-k', '64',
        '--top-p', '0.95',
        '--min-p', '0.0',
        '--repeat-penalty', '1.0',
        '--log-verbose',           # 詳細ログ出力
        '--threads', '8'           # CPUスレッド数
    ])

    print("⏱️ サーバー起動待機中...")
    time.sleep(15)  # 起動まで待機

    # サーバーの起動確認
    try:
        response = requests.get("http://localhost:8081/health", timeout=5)
        if response.status_code == 200:
            print("✅ logprobs対応サーバー起動完了!")
            print("🎯 トークン確率分析の準備完了")
        else:
            print("⚠️ サーバー起動中... 少し待ってください")
            time.sleep(5)
    except requests.exceptions.RequestException:
        print("⚠️ サーバー起動中... 少し待ってください")
        time.sleep(5)

    return process

# サーバー起動
server_process = start_llama_server_with_logprobs()

logprobs取得・分析機能の実装

トークンの生成確率を取得・分析する機能を実装します。

def query_with_logprobs(prompt, logprobs_count=5, max_tokens=128):
    """logprobs付きでテキスト生成を実行"""
    try:
        response = requests.post("http://localhost:8081/completion",
            json={
                "prompt": prompt,
                "n_predict": max_tokens,
                "temperature": 1.0,
                "top_k": 64,
                "top_p": 0.95,
                "min_p": 0.0,
                "repeat_penalty": 1.0,
                "logprobs": logprobs_count,  # logprobs取得数
                "stream": False
            },
            timeout=30
        )

        if response.status_code == 200:
            result = response.json()

            # 生成されたテキストを表示
            generated_text = result.get('content', '')
            print(f"🎭 Generated text: {generated_text}")

            return result
        else:
            print(f"❌ エラー: {response.status_code}")
            print(f"レスポンス: {response.text}")
            return None
    except Exception as e:
        print(f"❌ 接続エラー: {e}")
        return None

def analyze_logprobs(result):
    """logprobsの詳細分析"""
    if not result or "completion_probabilities" not in result:
        print("⚠️ 確率データが見つかりません")
        return

    completion_probs = result["completion_probabilities"]
    print(f"\n📈 === Log Probabilities 分析 ({len(completion_probs)} tokens) ===")

    for i, token_data in enumerate(completion_probs[:10]):  # 最初の10トークンのみ表示
        token = token_data.get("token", "")
        logprob = token_data.get("logprob", 0)
        prob = 2 ** logprob  # 対数確率を確率に変換

        print(f"🎯 Token {i:2d}: '{token:12s}' | Prob: {prob:7.4f} | LogProb: {logprob:7.3f}")

        # 上位候補トークンを表示
        if "top_logprobs" in token_data and token_data["top_logprobs"]:
            print("   📋 Top alternatives:")
            for j, top_token in enumerate(token_data["top_logprobs"][:3]):
                alt_token = top_token.get("token", "")
                alt_logprob = top_token.get("logprob", 0)
                alt_prob = 2 ** alt_logprob
                print(f"      {j+1}. '{alt_token:10s}' | Prob: {alt_prob:7.4f} | LogProb: {alt_logprob:7.3f}")
        print()

def detailed_probability_analysis(result):
    """詳細な確率統計分析"""
    if not result or "completion_probabilities" not in result:
        print("⚠️ 確率データがありません")
        return

    completion_probs = result["completion_probabilities"]
    print(f"\n🔬 === 詳細統計分析 ===")

    total_surprise = 0
    confidence_scores = []

    for i, token_data in enumerate(completion_probs):
        token = token_data.get("token", "")
        logprob = token_data.get("logprob", 0)
        prob = 2 ** logprob

        # サプライズ値(情報量)
        surprise = -logprob  # ビット単位の情報量
        total_surprise += surprise

        # 信頼度スコア(1位と2位の差)
        if "top_logprobs" in token_data and len(token_data["top_logprobs"]) > 1:
            top_probs = token_data["top_logprobs"]
            first_prob = 2 ** top_probs[0]["logprob"]
            second_prob = 2 ** top_probs[1]["logprob"] if len(top_probs) > 1 else 0
            confidence = first_prob - second_prob
            confidence_scores.append(confidence)

        if i < 5:  # 最初の5トークンのみ詳細表示
            print(f"&#x1f3b2; Token {i}: '{token:10s}' | Surprise: {surprise:5.2f} bits | Confidence: {confidence:.3f}")

    # 統計サマリー
    avg_surprise = total_surprise / len(completion_probs) if completion_probs else 0
    avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0
    perplexity = 2 ** avg_surprise

    print(f"\n&#x1f4ca; === 統計サマリー ===")
    print(f"&#x1f4c8; 総サプライズ: {total_surprise:.2f} bits")
    print(f"&#x1f4ca; 平均サプライズ: {avg_surprise:.2f} bits/token")
    print(f"&#x1f3af; 推定パープレキシティ: {perplexity:.2f}")
    print(f"&#x1f3aa; 平均信頼度: {avg_confidence:.3f}")

    # 解釈
    print(f"\n&#x1f50d; === 解釈 ===")
    if avg_surprise < 3:
        print("&#x2705; 予測しやすいテキスト(低サプライズ)")
    elif avg_surprise < 6:
        print("&#x1f504; 中程度の予測可能性")
    else:
        print("&#x1f3b2; 予測困難なテキスト(高サプライズ)")

    if avg_confidence > 0.5:
        print("&#x1f4aa; 高い生成信頼度")
    elif avg_confidence > 0.2:
        print("&#x1f504; 中程度の生成信頼度")
    else:
        print("&#x26a0;&#xfe0f; 低い生成信頼度(不確実性が高い)")

実験とテスト

様々なプロンプトでlogprobs実験を実行してみましょう。

def run_logprobs_experiments():
    """様々なプロンプトでlogprobs実験を実行"""
    print("&#x1f9ea; === Logprobs 実験開始 ===\n")

    test_prompts = [
        "The weather today is",
        "Once upon a time",
        "The capital of Japan is",
        "In machine learning, the term 'gradient descent' refers to",
        "2 + 2 equals"
    ]

    for i, prompt in enumerate(test_prompts):
        print(f"{'='*70}")
        print(f"&#x1f52c; 実験 {i+1}: '{prompt}'")
        print('='*70)

        # logprobs付きで生成
        result = query_with_logprobs(prompt, logprobs_count=3, max_tokens=50)

        if result:
            # 基本分析
            analyze_logprobs(result)

            # 詳細分析
            detailed_probability_analysis(result)

        print("\n" + "-"*70 + "\n")

# 実験実行
run_logprobs_experiments()

独自プロンプトでの分析

ここであなた独自のプロンプトで分析を行ってみてください。

# あなたの独自プロンプトでの分析
your_prompt = "AI and machine learning will"  # ← ここを変更してください

print(f"&#x1f3aa; 独自プロンプト分析: '{your_prompt}'")
result = query_with_logprobs(your_prompt, logprobs_count=5, max_tokens=40)

if result:
    analyze_logprobs(result)
    detailed_probability_analysis(result)

# より長いテキスト生成での分析
print(f"\n&#x1f3ad; === 長文生成での確率分析 ===")
long_result = query_with_logprobs("Explain the concept of artificial intelligence", 
                                  logprobs_count=3, max_tokens=100)
if long_result:
    detailed_probability_analysis(long_result)

システム状態確認

def check_system_status():
    """システム状態の最終確認"""
    print("&#x1f50d; === システム状態確認 ===")

    # llama.cppの状態確認
    if os.path.exists('/content/llama.cpp/build/bin/llama-cli'):
        print("&#x2705; llama.cpp: ビルド済み")
    else:
        print("&#x274c; llama.cpp: ビルドが必要")

    # モデルファイルの状態確認  
    if selected_model and os.path.exists(selected_model):
        file_size = os.path.getsize(selected_model) / (1024**3)
        print(f"&#x2705; 修正済みモデル: {selected_model} ({file_size:.1f}GB)")
        print("&#x1f527; Unsloth修正版を使用中")
    else:
        print("&#x274c; モデルファイル: 見つかりません")

    # サーバー状態確認
    try:
        response = requests.get("http://localhost:8081/health", timeout=3)
        if response.status_code == 200:
            print("&#x2705; サーバー: 稼働中")
            print("&#x1f3af; logprobs対応: 有効")
        else:
            print("&#x26a0;&#xfe0f; サーバー: 応答異常")
    except:
        print("&#x274c; サーバー: 停止中")

    # GPU確認
    print("\n&#x1f5a5;&#xfe0f; === GPU情報 ===")
    !nvidia-smi --query-gpu=name,memory.total,memory.used --format=csv,noheader,nounits

# 状態確認実行
check_system_status()

まとめ

このノートブックでは、llama.cppとGemma3nを使用したlogprobs分析の実装方法を解説しました。

実現できたこと

  • 修正済みGemma3nでllama.cppでの安定動作
  • logprobs取得でトークン生成確率の詳細分析
  • 統計分析でサプライズ値・パープレキシティ・信頼度の計算
  • 代替候補分析で生成の不確実性可視化

分析可能な指標

  • 確率値: 各トークンの生成確率
  • サプライズ値: 情報量(bits)
  • 信頼度: 1位と2位候補の差
  • パープレキシティ: モデルの困惑度

活用場面

  • モデル評価: 生成品質の定量的評価
  • プロンプト最適化: 効果的なプロンプト設計
  • 不確実性分析: モデルの確信度測定
  • 研究・開発: 言語モデルの挙動分析

📒ノートブック

Google Colab

参考リンク

コメント

タイトルとURLをコピーしました