Fate Stay Nightで学ぶGraphRAG(GoogleColab付)

AI・機械学習

はじめに

Graph retrieval augmented generation (Graph RAG) は、従来のベクター検索による情報検索手法に強力な手法として注目を集めています。Graph RAGは、データをノードと関係性で構造化するグラフデータベースの特性を活かし、検索された情報の深さと文脈性を高めます。

本記事では、人気アニメ「Fate Stay Night」のWikipediaデータを使って、LangChainとNeo4jを用いたGraph RAGの実践的な構築方法を初心者向けに解説します。

環境のセットアップ

まずは必要なライブラリをインストールしましょう。

%%capture
%pip install --upgrade --quiet langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs

次に、Fate Stay NightのWikipediaデータをダウンロードします。

!git clone https://huggingface.co/datasets/MakiAi/DemoDocs.git

必要なモジュールをインポートします。

import os
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough, ConfigurableField
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate  
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
from langchain_core.output_parsers import StrOutputParser
from langchain_community.graphs import Neo4jGraph
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.vectorstores import Neo4jVector  
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars

try:
    import google.colab
    from google.colab import output
    output.enable_custom_widget_manager()
except:
    pass

APIキーと接続情報を設定します。NEO4J_URINEO4J_USERNAMENEO4J_PASSWORDは自身の環境に合わせて変更してください。

import getpass
from google.colab import userdata

os.environ["OPENAI_API_KEY"]  = userdata.get('OPENAI_API_KEY')  
os.environ["NEO4J_URI"]       = userdata.get('NEO4J_URI')
os.environ["NEO4J_USERNAME"]  = userdata.get('NEO4J_USERNAME')
os.environ["NEO4J_PASSWORD"]  = userdata.get('NEO4J_PASSWORD')

データの取り込みと分割

Wikipediaデータを読み込み、テキストを適切な長さに分割します。

raw_documents = TextLoader('/content/DemoDocs/FateStayNight_Wiki_mini2.txt').load()
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=125)  
documents = text_splitter.split_documents(raw_documents)
  • TextLoaderでテキストファイルを読み込みます。
  • TokenTextSplitterでテキストを指定したトークン数(chunk_size)ごとに分割します。
  • chunk_overlapは分割されたチャンク間でオーバーラップさせるトークン数を指定します。

知識グラフの構築

分割したドキュメントからLLMを使って知識グラフを構築します。

llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo-0125") 
llm_transformer = LLMGraphTransformer(llm=llm)
graph_documents = llm_transformer.convert_to_graph_documents(documents)
  • ChatOpenAIでGPT-3.5-turboモデルを指定します。temperature=0で確定的な出力を得ます。
  • LLMGraphTransformerにLLMを渡してグラフ変換器を初期化します。
  • convert_to_graph_documentsメソッドで分割したドキュメントをグラフ形式に変換します。

Neo4jのグラフインスタンスを作成します。

graph = Neo4jGraph()

生成したグラフをNeo4jデータベースに保存します。

graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True  
)
  • add_graph_documentsメソッドでグラフをNeo4jに追加します。
  • baseEntityLabel=Trueで各ノードに__Entity__ラベルを付与し、インデックス作成とクエリのパフォーマンスを向上させます。
  • include_source=Trueでノードと元のドキュメントをリンクし、データのトレーサビリティと文脈理解を容易にします。

グラフの可視化

Cypher クエリを使ってグラフを可視化するための関数を定義します。

default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"

def showGraph(cypher: str = default_cypher):
    driver = GraphDatabase.driver(
        uri = os.environ["NEO4J_URI"],  
        auth = (os.environ["NEO4J_USERNAME"],
                os.environ["NEO4J_PASSWORD"]))
    session = driver.session()
    widget = GraphWidget(graph = session.run(cypher).graph()) 
    widget.node_label_mapping = 'id'
    return widget
  • default_cypherはデフォルトのCypherクエリを定義します。
  • showGraph関数は、与えられたCypherクエリの結果をグラフとして可視化します。
    • GraphDatabase.driverでNeo4jデータベースに接続します。
    • session.run(cypher).graph()でクエリを実行し、結果をグラフオブジェクトとして取得します。
    • GraphWidgetでグラフを可視化します。
    • node_label_mapping='id'でノードのラベルをidプロパティにマッピングします。

関数を呼び出してグラフを表示します。

showGraph()

ベクター検索インデックスの作成

既存のグラフからベクター検索インデックスを作成します。

vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"], 
    embedding_node_property="embedding"
)  
  • OpenAIEmbeddingsでOpenAIの埋め込みモデルを使用します。
  • search_type="hybrid"でキーワードとベクターの両方を使用したハイブリッド検索を設定します。
  • node_label="Document"Documentラベルを持つノードをインデックス対象とします。
  • text_node_properties=["text"]でテキストプロパティを指定します。
  • embedding_node_property="embedding"で埋め込みベクトルを格納するプロパティを指定します。

全文検索インデックスの作成

エンティティノードに全文検索インデックスを作成します。

graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")
  • CREATE FULLTEXT INDEXで全文検索インデックスを作成します。
  • FOR (e:__Entity__)__Entity__ラベルを持つノードをインデックス対象とします。
  • ON EACH [e.id]idプロパティにインデックスを作成します。

エンティティの抽出

テキストからエンティティを抽出するためのクラスとプロンプトを定義します。

class Entities(BaseModel):
    """エンティティの識別情報"""

    names: List[str] = Field(
        ...,
        description="テキスト中に登場する人、組織、ビジネスエンティティ全て",
    )

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system", 
            "テキストから組織と人物のエンティティを抽出します。",
        ),
        (
            "human",
            "指定されたフォーマットを使って以下の入力から情報を抽出してください: {question}",  
        ),
    ]
)

entity_chain = prompt | llm.with_structured_output(Entities)
  • Entitiesクラスは、抽出されたエンティティの名前を保持します。
  • ChatPromptTemplateでエンティティ抽出のためのプロンプトを定義します。
  • llm.with_structured_output(Entities)でLLMの出力をEntitiesクラスの構造に合わせてパースします。

テストしてみましょう。

entity_chain.invoke({"question": "士郎とセイバーは戦った"}).names

全文検索クエリの生成

全文検索クエリを生成するための関数を定義します。

def generate_full_text_query(input: str) -> str:
    """
    与えられた入力文字列に対する全文検索クエリを生成する。

    この関数は、全文検索に適したクエリ文字列を構築する。入力文字列を単語に分割し、  
    各単語に類似性のしきい値(~2文字の変更)を追加してANDでつなぎ、処理する。
    ユーザーの質問からエンティティをデータベースの値にマップする際に有用で、 
    ある程度のミススペルを許容する。
    """
    full_text_query = ""
    words = [el for el in remove_lucene_chars(input).split() if el]
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()
  • 入力文字列を単語に分割し、各単語に~2(2文字までの変更を許容)を追加します。
  • 単語をANDで連結して全文検索クエリを生成します。

構造化データのレトリーバー

質問に含まれるエンティティの近傍を取得する構造化データのレトリーバーを定義します。

def structured_retriever(question: str) -> str: 
    """
    質問中で言及されているエンティティの近傍を収集する。
    """
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:20})
            YIELD node,score
            CALL { 
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node  
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output  
            }
            RETURN output LIMIT 1000
            """,
            {"query": generate_full_text_query(entity)},
        )
        result += "\n".join([el['output'] for el in response])
    return result    
  • entity_chainを使って質問からエンティティを抽出します。
  • 抽出されたエンティティごとに以下のクエリを実行します:
    • db.index.fulltext.queryNodesで全文検索インデックスを使ってエンティティノードを検索します。
    • (node)-[r:!MENTIONS]->(neighbor)でエンティティノードから外向きのリレーションシップを持つ近傍ノードを取得します。
    • (node)<-[r:!MENTIONS]-(neighbor)でエンティティノードへの内向きのリレーションシップを持つ近傍ノードを取得します。
    • 取得した近傍ノードをoutput変数に集約し、最大1000件まで返します。
  • 取得した近傍ノードの情報を結果文字列に追加します。

テストしてみましょう。

print(structured_retriever("士郎と関わりがあるエンティティを知りたい"))  

最終的なレトリーバー

構造化データと非構造化データのレトリーバーを組み合わせて、最終的なコンテキストを生成します。

def retriever(question: str):
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
    {structured_data} 
    Unstructured data:
    {"#Document ". join(unstructured_data)} 
    """
    return final_data
  • structured_retrieverで構造化データを取得します。
  • vector_index.similarity_searchでベクター検索を行い、非構造化データを取得します。
  • 構造化データと非構造化データを結合して最終的なコンテキストを生成します。

RAGチェーンの定義

検索クエリを抽出するためのRunnableLambdaを定義します。

_search_query = RunnableLambda(lambda x: x["question"])

回答生成のためのプロンプトテンプレートを定義します。

template = """あなたは優秀なAIです。下記のコンテキストを利用してユーザーの質問に丁寧に答えてください。
必ず文脈からわかる情報のみを使用して回答を生成してください。
{context}

ユーザーの質問: {question}"""
prompt = ChatPromptTemplate.from_template(template)
  • {context}プレースホルダーにレトリーバーから取得したコンテキストが挿入されます。
  • {question}プレースホルダーにユーザーの質問が挿入されます。

RAGチェーンを定義します。

chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)
  • RunnableParallelで検索クエリとユーザーの質問を並列に処理します。
    • "context"には検索クエリをレトリーバーにパイプして取得したコンテキストが格納されます。
    • "question"にはユーザーの質問がそのまま渡されます。
  • プロンプトテンプレートにコンテキストと質問を渡します。
  • LLMにプロンプトを渡して回答を生成します。
  • StrOutputParserで回答を文字列として取得します。

テストしてみましょう。

chain.invoke({"question": "士郎と仲が良いのは誰?"})

まとめ

本記事では、Fate Stay NightのWikipediaデータを使って、LangChainとNeo4jを用いたGraph RAGの実践的な構築方法を初心者向けに解説しました。

主なステップは以下の通りです:

  1. 環境のセットアップ
  2. データの取り込みと分割
  3. 知識グラフの構築
  4. グラフの可視化
  5. ベクター検索インデックスの作成
  6. 全文検索インデックスの作成
  7. エンティティの抽出
  8. 全文検索クエリの生成
  9. 構造化データのレトリーバー
  10. 最終的なレトリーバー
  11. RAGチェーンの定義

LLMGraphTransformerの登場により、知識グラフの生成プロセスがよりスムーズで利用しやすくなりました。これにより、知識グラフの深さと文脈性を活かしてRAGベースのアプリケーションを強化したい人にとって、取り組みやすくなったと言えるでしょう。

Graph RAGはまだ発展途上の分野ですが、非常に有望なアプローチです。本記事が、より多くの人がGraph RAGに興味を持ち、活用していくきっかけになれば幸いです。

コードは全てGitHubに公開していますので、ぜひ参考にしてみてください。ご意見・ご感想などありましたら、お気軽にお問い合わせください。

以上で、「Fate Stay Nightで学ぶGraphRAG」の解説を終わります。お読みいただきありがとうございました。

ノートブック

Google Colab

参考資料

YouTube
作成した動画を友だち、家族、世界中の人たちと共有
Enhancing RAG-based application accuracy by constructing and leveraging knowledge graphs
A practical guide to constructing and retrieving information from knowledge graphs in RAG applications with Neo4j and LangChainEditor's Note: the following is a...

コメント

タイトルとURLをコピーしました