GluonTSで始める深層学習による時系列予測

AI・機械学習

深層学習を用いた時系列予測の世界へようこそ!この記事では、PythonのパッケージGluonTSを用いて、初心者でも時系列予測を理解し、実践できるようになるための完全ガイドを提供します。

時系列予測とは何か?

時系列予測とは、過去のデータに基づいて未来の出来事を予測することです。特に、時系列データ(時間とともに変化するデータ)の将来の値を予測することを指します。例えば、電力会社は電力需要を予測し、それに合わせて発電量を調整します。

時系列予測では、過去の観測可能な行動が未来にも継続するという暗黙の仮定に基づいています。例えば、人々は一般的に夜間よりも日中に多くのエネルギーを消費し、夕方にはテレビを視聴し、夏にはエアコンを使用します。

しかし、予測不可能なイベント(例えば、パンデミックによる旅行制限)は予測できません。したがって、時系列予測は、根本的な変化が起こらないという条件下で、通常の事象を予測するためのツールです。

ターゲットと特徴量

時系列予測では、予測したい時系列をターゲット時系列と呼びます。過去のターゲット値は、モデルが正確な予測を行うために最も重要な情報です。

加えて、モデルは特徴量(ターゲット値に影響を与える追加の値)を利用できます。特徴量には、静的特徴量と動的特徴量があります。

  • 動的特徴量: 各時点ごとに異なる可能性のある特徴量です(例:製品の価格、気温)。
  • 静的特徴量: 時間に依存しない、時系列を記述する特徴量です(例:店舗ID、製品ID)。

さらに、特徴量にはカテゴリ型と連続型があります。

  • 連続型特徴量: 数値自体に意味がある特徴量です(例:価格)。
  • カテゴリ型特徴量: 数値自体に意味がなく、個々の値が異なるカテゴリーを表す特徴量です(例:店舗ID)。

確率的予測

GluonTSの核となるアイデアの一つは、単純な値ではなく、確率分布を予測することです。これは、100回予測を行い、その結果得られた100個の時系列サンプルから分布を生成するようなものです。

分布は、予測値の範囲を示すという利点があります。例えば、レストランのオーナーは、食材をどれだけ購入すべきか悩むかもしれません。少なすぎると顧客の需要に応えられず、多すぎると無駄になります。需要を予測する際、モデルが「おそらく50皿の需要があるが、60皿を超える可能性は低い」と教えてくれれば、非常に役立ちます。

ローカルモデルとグローバルモデル

GluonTSでは、ローカルモデルとグローバルモデルという概念を使用します。

  • ローカルモデル: 単一の時系列に対して適合し、その時系列の予測に使用されます。
  • グローバルモデル: 多くの時系列にわたって学習され、単一のグローバルモデルがデータセット内のすべての時系列の予測に使用されます。

グローバルモデルの学習には時間がかかるため(数時間から数日)、予測リクエストの一部としてモデルを学習することは現実的ではなく、別途「オフライン」で行われます。一方、ローカルモデルの学習は通常はるかに高速であり、「オンライン」で予測の一部として行われます。

GluonTSでは、ローカルモデルはPredictorとして直接利用できますが、グローバルモデルはEstimatorとして提供され、最初に学習させる必要があります。

GluonTSのインストール

GluonTSはPyPiからインストールできます。

pip install gluonts

注意: GluonTSはバージョン管理にセマンティックバージョニングを使用しています。活発に開発されているため、メジャーバージョンとしてv0を使用しています。四半期ごとに新しいマイナーバージョンをリリースする予定です。現在のリリース予定はGitHubで確認できます。

オプションと追加の依存関係

GluonTSは最小限の依存関係モデルを使用しています。つまり、ほとんどのモデルと機能を使用するには、追加の依存関係をインストールする必要があります。

Pythonには、パッケージの特定の機能のロックを解除するためにオプションでインストールできる「extras」という概念があります。パッケージをインストールするとき、[...]でパッケージ名とバージョン指定子の間に渡されます。

pip install "some-package[extra-1,extra-2]==version"

GluonTSでは、必要な依存関係の量を最小限に抑えるために、オプションの依存関係を幅広く利用しています。それでもユーザーが特定の機能を選択できるように、多くの追加の依存関係を公開しています。

例えば、Apache Arrowを使用してArrowおよびParquetベースのデータセットを読み書きするためのサポートを提供しています。ただし、特に必要がない場合は、これは必須の依存関係としては大きすぎます。したがって、必要なパッケージをインストールし、次のように簡単に有効にできるarrow-extraを提供しています。

pip install "gluonts[arrow]"

GluonTSの使い方

データの読み込みと確認

ここでは、サンプルデータとしてelectricityデータセットを使用します。これは、電力消費量の時系列データです。

import pandas as pd
from gluonts.dataset.common import ListDataset

# サンプルデータの準備
url = "https://raw.githubusercontent.com/numenta/NAB/master/data/realKnownCause/nyc_taxi.csv"
df = pd.read_csv(url)
df = df.rename(columns={"timestamp": "start"})
df["target"] = df["value"]
df["start"] = pd.to_datetime(df["start"])
df = df[["start","target"]]
# ListDatasetとして読み込み
train_data = ListDataset(
    df.iloc[:-6 * 24, :].to_dict("records"),
    freq="H"
)
test_data = ListDataset(
    df.iloc[-6 * 24:, :].to_dict("records"),
    freq="H"
)

モデルの定義と学習

ここでは、GluonTSで提供されているDeepAREstimatorを使用します。DeepARは、深層学習を用いた確率的時系列予測モデルの一つです。

from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.torch.trainer import Trainer

estimator = DeepAREstimator(
    freq="H",
    prediction_length=24,
    context_length=168,
    trainer=Trainer(
        epochs=5
    )
)
predictor = estimator.train(training_data)

予測

学習済みモデルを使って、未来の電力消費量を予測します。

from gluonts.dataset.util import to_pandas

for test_entry, forecast in zip(test_data, predictor.predict(test_data)):
    to_pandas(test_entry)[-24:].plot(linewidth=2)
    forecast.plot(color='g')
plt.grid(which='both')
plt.legend(["observations", "median prediction", "90% confidence interval"], loc="upper left")
plt.show()

まとめ

この記事では、GluonTSを用いた時系列予測の基本的な流れを紹介しました。GluonTSは、深層学習を用いた時系列予測を手軽に始められる強力なツールです。ぜひ、GluonTSを使って、様々な時系列データの予測に挑戦してみてください。

参考資料

GluonTS documentation
GluonTS documentation

コメント

タイトルとURLをコピーしました