Python > ARIMA/SARIMAによる時系列データの予測

更新日 2019-08-16
広告

Pythonで時系列データの予測を行います。

(時系列データに限らず)予測をするためには、過去のデータから「モデル」を構築し、そのモデルに基づいて未来を予測するのが定石です。 モデルには多くの種類がありますが、有名な時系列データのモデルとしてARIMAやSARIMAがありますので、それを使ってみます。 ググれば、これらのモデルの解説ページはたくさん見つかると思います。

予測に使うサンプルデータとして、AirPassengers.csvを使います。 このデータは色々なところで勉強用に使われているので、"AirPassengers.csv"でググれば、データは見つかります。例えば、ここに、データがあります。

以下の説明ではJupyter Notebook環境を前提とします。 以下、NotebookやPandasの使い方については詳しく説明しません。 Jupyter Notebookのインストールは「Jupyter Notebookを使う」の記事を参照してください。 また、pandasの使い方については、Pandasの使い方を参照してください。

必要なモジュールの準備

いくつかモジュールが必要なので、importしておきます。 (モジュールのインストールはpipなどのモジュール管理システムでインストールしてください)
import numpy as np
import pandas as pd
from scipy import stats
from matplotlib import pylab as plt
import seaborn as sns
%matplotlib inline
sns.set()
import statsmodels.api as sm

データの読み込み

今回使うデータ"AirPassengers.csv"は、以下のような形式になっています。1列目が「年月」で、2列目が乗客数ですね。
Month,#Passengers
1949-01,112
1949-02,118
1949-03,132
1949-04,129
1949-05,121
これを、以下のようにして読み込みます。
# 日付をインデックスとして指定
dateparse = lambda dates: pd.datetime.strptime(dates, '%Y-%m')
data = pd.read_csv('AirPassengers.csv', index_col='Month', date_parser=dateparse, dtype='float')
乗客数のデータを読み込んでおきます。
ts = data['#Passengers']

データの分析

いよいよ分析開始です。 まずはグラフを表示してみます。 可視化することで、データの特徴を視覚的につかむことができます。
plt.plot(ts)

グラフを見ると、一年周期のパターンがあることがわかると思います。 このようなパターンは、分析/予測において重要になります。

ARIMAモデルは、3つのパラメータを持ちます。

  1. AR(自己回帰)モデルにおける「回帰数」
  2. 差分を取る「回数」。ここでの差分とは、時系列データにおいて隣り合うデータの差分です。
  3. MA(移動平均)モデルにおける「平均を計算する際に考慮するデータ数」
ARIMAモデルを用いてデータ予測するためには、この3つのパラメータを適切に決める必要があります。 Pythonは、上記2.(差分を取る回数)は自動的に計算できないみたいなので、今回は差分回数は「1」で決め打ちします。

差分データはdiffメソッドで簡単に計算できます。

diff = ts.diff()
diff = diff.dropna() # NaN要素を削除
statsmodelsを用いて、上記1.と3.の値を計算します。
params = sm.tsa.arma_order_select_ic(diff, ic='aic', trend='nc')
paramsを表示すると以下になります。
{'aic':              0            1            2
 0          NaN  1397.257791  1397.093436
 1  1401.852641  1412.615224  1385.496795
 2  1396.587654  1378.338024  1353.175766
 3  1395.021214  1379.614000  1351.138814
 4  1388.216680  1379.616584  1373.560615, 'aic_min_order': (3, 2)}

最後の(3, 2)が、上記1.と3.の値になります。

では、ARIMAモデルを構築してみます。

from statsmodels.tsa.arima_model import ARIMA
arima_model = ARIMA(ts, order=(3,1,2)).fit(dist=False)
tsは対象となる時系列データです。そのあとのorderパラメータが、上記1.、2.、3.のパラメータになります。 少し時間がかかるかもしれませんが、モデル構築が終わったら、残差を表示してみましょう。 (残差(residual)とは、モデルから推測される値と、実測値の差です。)
resid = arima_model.resid
plt.plot(resid)

どんどん残差が大きくなるグラフが表示されたかと思います。 残差が大きくなっているわけなので、モデルとしてはイマイチ、ということです。

では次に、SARIMAモデルを構築してみます。 SARIMAモデルは、ARIMAモデルに「季節的な周期パターン」を加えたモデルです。 今回の乗客数データは一年間ごとの繰り返しがあるので、なんだかマッチしそうです。 以下が、モデル構築のコードになります。

sarima_model = sm.tsa.SARIMAX(ts, order=(3,1,2), seasonal_order=(1,1,1,12)).fit()

seasonal_orderというパラメータが追加されています。 このパラメータは4つの引数を持ちます。 4つ目の値(上記だと12)が、周期が発生する間隔です。 今回のデータは、月ごとのデータで、一年周期が存在するので、12としています。

1~3番目の引数は全部1にしていますが、これはえいやで決めています。 引数の意味としては、季節的な周期パターンにおける自己回帰モデルや差分回数などのパラメータになります。

先ほどと同じように残差をグラフ化してみると、ARIMAモデルよりも残差が小さくなっていることがわかると思います。

plt.plot(sarima_model.resid)
いよいよ、構築したモデルに基づいて予測してみます。
predict = sarima_model.predict('1960-01-01', '1962-12-01')
plt.plot(ts) # 実データをプロット
plt.plot(predict) # 予測データをプロット

predicメソッドで予測する範囲(開始/終了)を指定するわけですが、開始時期は実測値がある時期を指定する必要があります。 AirPassengers.csvは1960年12月までデータがあるので、ここでは1960年1月を開始時期としています。

このように、非常にそれっぽい予測値を得ることができました!
広告
お問い合わせは sweng.tips@gmail.com まで。
inserted by FC2 system