ARCコンペ:代数をEDAに適用する

数値計算

このノートブックでは、同じ次元を持つ正方行列の例のみ を使用します。データセットには、他の分布も含まれています。

数字を使った例をプロットすることで、例の数学的分布を視覚化することができます。また、引き算によって、入力と出力の違いを明確に理解することができます。

ライブラリのインポート

まずは必要なライブラリをインポートします。

import json  # JSON データの読み込み用
import pandas as pd  # データの操作用
import numpy as np  # 数値計算用
import matplotlib.pyplot as plt  # グラフ描画用
from   matplotlib import colors  # 色の設定用

データの読み込み

データセットを読み込みます。

base_path = '/kaggle/input/arc-prize-2024/'

# JSON データを読み込む関数
def load_json(file_path):
    with open(file_path) as f:
        data = json.load(f)
    return data

# このノートブックでは、正方形で等しい行列の例のみを使用します。
# データセットには、異なる分布を持つ他のcsvファイルも含まれています。
df = pd.read_csv("/kaggle/input/arc-2024-training-explamples-by-form/equals_squared_train.csv")   
training_challenges = load_json(base_path + 'arc-agi_training_challenges.json')

入力、出力、引き算の抽出

入力と出力の行列を取得し、その差分を計算する関数を定義します。

# 行列を読み込む関数
def get_matrix_pair(challenge):
    x = pd.DataFrame(challenge['input'])  # 入力の行列
    y = pd.DataFrame(challenge['output']) # 出力の行列

    # 引き算を行う (他の演算に変更することも可能)
    z = y - x  
    return x, y, z

色の設定

ヒートマップに使用する色を設定します。

# 色のリストを定義
cmap = colors.ListedColormap(
    ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
     '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])

# 0から9までの値を色に正規化
norm = colors.Normalize(vmin=0, vmax=9)

# 色の見本を表示
plt.figure(figsize=(3, 1), dpi=150) 
plt.imshow([list(range(10))], cmap=cmap, norm=norm)
plt.xticks(list(range(10)))
plt.yticks([])
plt.show()

ヒートマップの作成

数字付きのヒートマップを作成するための関数を定義します。

def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    NumPy 配列と2つのラベルリストからヒートマップを作成する関数

    Parameters
    ----------
    data: numpy.ndarray
        (M, N) 形状の2次元 NumPy 配列
    row_labels: list または numpy.ndarray
        行のラベルを含む、長さ M のリストまたは配列
    col_labels: list または numpy.ndarray
        列のラベルを含む、長さ N のリストまたは配列
    ax: matplotlib.axes.Axes, optional
        ヒートマップを描画する Axes インスタンス。指定しない場合は、現在の Axes を使用するか、新しい Axes を作成します。
    cbar_kw: dict, optional
        matplotlib.figure.Figure.colorbar に渡す引数の辞書。
    cbarlabel: str, optional
        カラーバーのラベル。
    **kwargs: 
        imshow に渡すその他の引数。
    """

    if ax is None:
        ax = plt.gca()

    if cbar_kw is None:
        cbar_kw = {}

    # ヒートマップを描画
    im = ax.imshow(data, **kwargs)

    # カラーバーを作成
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    # すべての目盛りを表示し、それぞれのリストのエントリでラベル付けします。
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    # 横軸のラベルを上に表示します。
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # 目盛ラベルを回転させて配置します。
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # 枠線をオフにし、白いグリッドを作成します。
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar

# ヒートマップに数値のテキストを追加する関数
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    ヒートマップに注釈を付ける関数

    Parameters
    ----------
    im: matplotlib.image.AxesImage
        ラベル付けする AxesImage
    data: numpy.ndarray, optional
        注釈に使用するデータ。None の場合、画像のデータが使用されます。
    valfmt: str または matplotlib.ticker.Formatter, optional
        ヒートマップ内の注釈の書式。文字列の書式設定メソッド (例: "$ {x:.2f}") 
        または matplotlib.ticker.Formatter を使用します。
    textcolors: tuple, optional
        色のペア。最初はしきい値以下の値に使用され、2番目はしきい値以上の値に使用されます。
    threshold: float, optional
        textcolors からの色が適用されるデータ単位の値。
        None (デフォルト) の場合は、カラーマップの中間値が区切りとして使用されます。
    **kwargs: 
        テキストラベルの作成に使用される text の呼び出しごとに転送される、その他すべての引数。
    """

    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # しきい値を画像の色範囲に正規化します。
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max())/2.

    # デフォルトの配置を中央に設定しますが、textkw で上書きできるようにします。
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # 文字列が指定されている場合は、フォーマッターを取得します
    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    # データをループ処理し、各「ピクセル」の Text を作成します。
    # データに応じてテキストの色を変更します。
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts

問題と解答の表示

問題、入力、出力、入力と出力の差を表示する関数を定義します。

def ploting_exercices(challenge, x, y, z):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))  # 3つのプロットを作成
    fig.suptitle(challenge) # タイトルに問題IDを表示
    cmap = colors.ListedColormap(['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
                                      '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25']) # 色のリスト
    norm = colors.Normalize(vmin=0, vmax=9) # 0~9の値を色に正規化

    # 入力行列を表示
    ax1.imshow(x, cmap=cmap, norm=norm)
    ax1.set_title('Input') # タイトル
    # 出力行列を表示
    ax2.imshow(y, cmap=cmap, norm=norm)
    ax2.set_title('Output') # タイトル
    # 入出力の差分を表示
    ax3.imshow(z, cmap=cmap, norm=norm)
    ax3.set_title('(output) - (input)') # タイトル

    # 入出力、差分の各行列の数字を表示
    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax1.text(i, j, x[i][j],
                           ha="center", va="center", color="r", size=15)

    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax2.text(i, j, y[i][j],
                           ha="center", va="center", color="r", size=15)

    for i in range(len(x[0])):
        for j in range(len(x[0])):
            text = ax3.text(i, j, z[i][j],
                           ha="center", va="center", color="r", size=15)

    fig.tight_layout()
    plt.show()

10個の例をプロット

10個の例について、問題、入力、出力、入力と出力の差を表示します。

# 最初の10個のデータを取得
batch_1 = df['id'][0:10]  
# 10個のデータについて、問題、入力、出力、入力と出力の差を表示
for challenge_id in batch_1:
    train_dic = training_challenges[challenge_id]['train']
    for pair in train_dic:
        x, y, z = get_matrix_pair(pair)
        ploting_exercices(challenge_id, x, y, z)

改善点

  • 転置や他の演算を試してみてください。
  • 入力と出力の関係を分類問題として分析するために、解答を含まないデータセットを作成してみてください。
  • 行列のパターンを探索するために、グラフを実装することを検討してみてください。

ノートブック

Kaggle: Your Home for Data Science
Kaggle is the world’s largest data science community with powerful tools and resources to help you achieve your data science goals.

参考文献

コメント

タイトルとURLをコピーしました