Python > ARIMA/SARIMAによる時系列データの予測
広告
Pythonで時系列データの予測を行います。
(時系列データに限らず)予測をするためには、過去のデータから「モデル」を構築し、そのモデルに基づいて未来を予測するのが定石です。 モデルには多くの種類がありますが、有名な時系列データのモデルとしてARIMAやSARIMAがありますので、それを使ってみます。 ググれば、これらのモデルの解説ページはたくさん見つかると思います。
予測に使うサンプルデータとして、AirPassengers.csvを使います。 このデータは色々なところで勉強用に使われているので、"AirPassengers.csv"でググれば、データは見つかります。例えば、ここに、データがあります。
以下の説明ではJupyter Notebook環境を前提とします。 以下、NotebookやPandasの使い方については詳しく説明しません。 Jupyter Notebookのインストールは「Jupyter Notebookを使う」の記事を参照してください。 また、pandasの使い方については、Pandasの使い方を参照してください。
必要なモジュールの準備
いくつかモジュールが必要なので、importしておきます。 (モジュールのインストールはpipなどのモジュール管理システムでインストールしてください)import numpy as np import pandas as pd from scipy import stats from matplotlib import pylab as plt import seaborn as sns %matplotlib inline sns.set() import statsmodels.api as sm
データの読み込み
今回使うデータ"AirPassengers.csv"は、以下のような形式になっています。1列目が「年月」で、2列目が乗客数ですね。Month,#Passengers 1949-01,112 1949-02,118 1949-03,132 1949-04,129 1949-05,121これを、以下のようにして読み込みます。
# 日付をインデックスとして指定 dateparse = lambda dates: pd.datetime.strptime(dates, '%Y-%m') data = pd.read_csv('AirPassengers.csv', index_col='Month', date_parser=dateparse, dtype='float')乗客数のデータを読み込んでおきます。
ts = data['#Passengers']
データの分析
いよいよ分析開始です。 まずはグラフを表示してみます。 可視化することで、データの特徴を視覚的につかむことができます。plt.plot(ts)
グラフを見ると、一年周期のパターンがあることがわかると思います。 このようなパターンは、分析/予測において重要になります。
ARIMAモデルは、3つのパラメータを持ちます。
- AR(自己回帰)モデルにおける「回帰数」
- 差分を取る「回数」。ここでの差分とは、時系列データにおいて隣り合うデータの差分です。
- MA(移動平均)モデルにおける「平均を計算する際に考慮するデータ数」
差分データはdiffメソッドで簡単に計算できます。
diff = ts.diff() diff = diff.dropna() # NaN要素を削除statsmodelsを用いて、上記1.と3.の値を計算します。
params = sm.tsa.arma_order_select_ic(diff, ic='aic', trend='nc')paramsを表示すると以下になります。
{'aic': 0 1 2 0 NaN 1397.257791 1397.093436 1 1401.852641 1412.615224 1385.496795 2 1396.587654 1378.338024 1353.175766 3 1395.021214 1379.614000 1351.138814 4 1388.216680 1379.616584 1373.560615, 'aic_min_order': (3, 2)}
最後の(3, 2)が、上記1.と3.の値になります。
では、ARIMAモデルを構築してみます。
from statsmodels.tsa.arima_model import ARIMA arima_model = ARIMA(ts, order=(3,1,2)).fit(dist=False)tsは対象となる時系列データです。そのあとのorderパラメータが、上記1.、2.、3.のパラメータになります。 少し時間がかかるかもしれませんが、モデル構築が終わったら、残差を表示してみましょう。 (残差(residual)とは、モデルから推測される値と、実測値の差です。)
resid = arima_model.resid plt.plot(resid)
どんどん残差が大きくなるグラフが表示されたかと思います。 残差が大きくなっているわけなので、モデルとしてはイマイチ、ということです。
では次に、SARIMAモデルを構築してみます。 SARIMAモデルは、ARIMAモデルに「季節的な周期パターン」を加えたモデルです。 今回の乗客数データは一年間ごとの繰り返しがあるので、なんだかマッチしそうです。 以下が、モデル構築のコードになります。
sarima_model = sm.tsa.SARIMAX(ts, order=(3,1,2), seasonal_order=(1,1,1,12)).fit()
seasonal_orderというパラメータが追加されています。 このパラメータは4つの引数を持ちます。 4つ目の値(上記だと12)が、周期が発生する間隔です。 今回のデータは、月ごとのデータで、一年周期が存在するので、12としています。
1~3番目の引数は全部1にしていますが、これはえいやで決めています。 引数の意味としては、季節的な周期パターンにおける自己回帰モデルや差分回数などのパラメータになります。
先ほどと同じように残差をグラフ化してみると、ARIMAモデルよりも残差が小さくなっていることがわかると思います。
plt.plot(sarima_model.resid)
predict = sarima_model.predict('1960-01-01', '1962-12-01') plt.plot(ts) # 実データをプロット plt.plot(predict) # 予測データをプロット
predicメソッドで予測する範囲(開始/終了)を指定するわけですが、開始時期は実測値がある時期を指定する必要があります。 AirPassengers.csvは1960年12月までデータがあるので、ここでは1960年1月を開始時期としています。