TensorFlow Probabilityでベイズ構造時系列モデル

- arata-furukawa

TensorFlow Probabilityは、確率的推論と統計的分析のためのライブラリです。

もともと、Edwardなどのライブラリを取り込んで急速に発展しました。以前はEdwardそのままな部分も多かったのですが、現在はだいぶAPIが整理されていて、とても使いやすくなっています。

TensorFlowで実装されていることで、GPUやTPUなどの最新のハードウェアアクセラレーションを活用することができるほか、ニューラルネットワークと組み合わせるといったことも可能で、Kerasのインターフェイスで簡単にモデルを構築できます。

TensorFlow Probabilityでできること

  • 確率に関する一通りの基本的な演算
  • 確率的統計モデリングのための機能
  • 変分推論とMCMC
  • Nelder-MeadやBFGS、SGLDなど多数の最適化アルゴリズム
  • 構造時系列モデル
  • などなど。。。

普段、RやStanで統計分析などを行っている人が求めている機能がだいぶ揃ってきているのではないかと思います。

構造時系列モデル

この記事では、TensorFlow Probabilityで構造時系列を扱います。

TensorFlow Probabilityに構造時系列サポートが登場したのは2年前、2019年の春でした。

Googleのサイエンティストが作ったRのbsts(ベイジアン構造時系列の略)パッケージを踏襲しており、構造時系列のための状態空間モデルを変分推論とHMCでフィッティングできます。

  • 自己回帰
  • ローカル線形トレンド
  • 季節性
  • 外部共変量での回帰

といった機能が予め用意されています。

もともと、業務で機械学習を扱う際はいつもTensorFlowを採用しているので、TensorFlow Probabilityは個人的にとても関心の高いライブラリでした。しかし当時はEager実行のサポートが十分でなくSession等を使う必要があり、これが非常に不便であったため、残念ながら採用することはありませんでした。2年でだいぶライブラリが成熟してきたようなので、今回はTensorFlow v2サポートにも注目しながら、構造時系列モデルを再び試してみます。

この記事は下記のバージョンで動作を確認しています。TensorFlow Probabilityは本体に比べてまだAPIが不安定なところもあるので、ご注意ください。

import tensorflow as tf
import tensorflow_probability as tfp
print('TensorFlow', tf.__version__)
print('TensorFlow Probability', tfp.__version__)
TensorFlow 2.5.0
TensorFlow Probability 0.13.0

今回は簡単に、疑似データを作って使用します。簡単に一通り試したいので、過程誤差は0で、適当に周期性とトレンドを入れておきます。

%matplotlib inline
%config InlineBackend.figure_formats = {'png', 'retina'}
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(0)
def create_data(t, n_observed):
  tn = len(t)
  ti = np.r_[:tn]
  trend = 0.025 * ti
  interc = 10
  obs_noise = np.random.normal(0, 0.5, tn)

  s = np.sin(ti * np.pi / 3) + np.sin(ti * np.pi / 6) + trend + interc
  y = s + obs_noise

  return s, y, (t[:n_observed], y[:n_observed]), (t[n_observed:], y[n_observed:])
t = np.arange("1996-01", "2021-01", dtype="datetime64[M]")
tn = len(t)
n_observed = int(tn*0.8)
n_forecast = tn - n_observed
t_now = t[n_observed]
print(f'all={tn}, observed={n_observed}, forecast={n_forecast}')

state, y, (t_observed, y_observed), (t_forecast, y_forecast) = create_data(t, n_observed)

plt.figure(figsize=(10, 4))
plt.axvline(t_now, linestyle="--")
plt.plot(t, state, lw=2, color='gray', alpha=.5, label='ground truth of state')
plt.plot(t_observed, y_observed, marker='x', markersize=3, ls='--', lw=.3, color='b', alpha=.5, label='observed for train')
plt.plot(t_forecast, y_forecast, marker='x', markersize=3, ls='--', lw=.3, color='r', alpha=.5, label='observed for eval forecast')
plt.legend()
plt.show()
all=300, observed=240, forecast=60
png

青点が学習に用いる観測データです。赤点が予想したい部分の正答とする観測データです。

モデリングしましょう。

def build_model(observed_time_series):
  """ ローカル線形トレンド+季節性モデル """
  return tfp.sts.Sum([
      tfp.sts.LocalLinearTrend(observed_time_series=observed_time_series),
      tfp.sts.Seasonal(num_seasons=12, observed_time_series=observed_time_series),
  ], observed_time_series=observed_time_series)

tfp.stsモジュールに構造時系列用の機能が含まれています。tfp.sts.Sumで複数のコンポーネントを組み合わせることが可能です。ここでは疑似データの特徴がわかりきっているので、そのままローカル線形トレンドと12個周期の季節効果をモデル化します。

model = build_model(y_observed)

変分近似事後分布を求めます。

variational_posteriors = tfp.sts.build_factored_surrogate_posterior(model=model)

ELBOの最小化

変分損失関数ELBO(Evidence Lower BOund)をAdamで最小化します。tfp.viモジュールに変分推論のための機能が含まれています。

num_steps = int(200)
optimizer = tf.optimizers.Adam(learning_rate=.1)

@tf.function(jit_compile=True)
def train():
  elbo_loss_curve = tfp.vi.fit_surrogate_posterior(
      target_log_prob_fn=model.joint_log_prob(
          observed_time_series=y_observed),
      surrogate_posterior=variational_posteriors,
      optimizer=optimizer,
      num_steps=num_steps)
  return elbo_loss_curve
%%time
elbo_loss_curve = train()

plt.plot(elbo_loss_curve)
plt.show()
png
CPU times: user 16.6 s, sys: 415 ms, total: 17.1 s
Wall time: 16.6 s

動きました! 以前検証したときは、この段階からSessionが必要でしたが、とてもシンプルに実行できるようになっています。また、手元の同一PC上では、2019年に試した際は200ステップに30秒ほどかかっていたのが、10秒以上速くなっています。v2に対応してtf.functionをJITコンパイルした効果が出ているのかもしれません。

推定事後分布のサンプルから推定パラメータを確認します。

ps_vi = variational_posteriors.sample(50)
def print_params(parameter_samples):
  for name, value in parameter_samples.items():
    print(f'{name:40}: {np.mean(value, axis=0):.2f}'+
          f' ± {np.std(value, axis=0):.2f}')

print_params(ps_vi)
observation_noise_scale                 : 0.51 ± 0.02
LocalLinearTrend/_level_scale           : 0.02 ± 0.04
LocalLinearTrend/_slope_scale           : 0.00 ± 0.00
Seasonal/_drift_scale                   : 0.01 ± 0.02

Forecast

疑似データではありますが、良い感じに出ました。Forecastしてみましょう。

def forecast(model, observed_time_series, parameter_samples, num_steps_forecast):
  dist = tfp.sts.forecast(
      model,
      observed_time_series=observed_time_series,
      parameter_samples=parameter_samples,
      num_steps_forecast=num_steps_forecast)

  loc = dist.mean()[..., 0]
  scale = dist.stddev()[..., 0]
  samples = dist.sample(10)[..., 0]
  return dist, loc, scale, samples
dist, loc, scale, samples = forecast(model, y_observed, ps_vi, n_forecast)
def score(y, p):
  se = np.power(y - p, 2)
  mse = np.mean(se)
  r2 = np.power(np.corrcoef(y, p)[0,1], 2)
  return mse, r2

def plot_score(y, s, p):
  def pos(co,r):
    a, b = co
    return a+(b-a)*r

  mse, r2 = score(y, p)
  mse_s, r2_s = score(s, p)
  print(f'       mse: {mse:7.4f}, R2: {r2:7.4f}')
  print(f' state mse: {mse_s:7.4f}, R2: {r2_s:7.4f}')

  plt.figure(figsize=(10,5))
  plt.subplot(1,2,1)
  plt.scatter(p, y)
  plt.xlabel('forecast')
  plt.ylabel('ground truth observed')
  xlim = plt.xlim()
  ylim = plt.ylim()
  plt.text(pos(xlim,0.1), pos(ylim,0.9), f'mse={mse:.4f}\n$R^2$={r2:.4f}')
  plt.plot(xlim, ylim, c='gray')
  plt.xlim(xlim)
  plt.ylim(ylim)
  plt.subplot(1,2,2)
  plt.scatter(p, s)
  plt.xlabel('forecast')
  plt.ylabel('ground truth state')
  xlim = plt.xlim()
  ylim = plt.ylim()
  plt.text(pos(xlim,0.1), pos(ylim,0.9), f'mse={mse_s:.4f}\n$R^2$={r2_s:.4f}')
  plt.plot(xlim, ylim, c='gray')
  plt.xlim(xlim)
  plt.ylim(ylim)

def plot_forecast_results(t_observed, y_observed, t_forecast, y_forecast, t, state,
                          forecast_samples, forecast_mean, forecast_scale,
                          num_step_forecast):
  plt.figure(figsize=(10,5))
  plt.axvline(t_forecast[0], linestyle='--')
  plt.plot(t, state, lw=2, color='gray', alpha=.5, label='ground truth of state')
  plt.plot(t_observed, y_observed, marker='x', markersize=3, ls='--', lw=.3, color='b', alpha=.5, label='observed')
  plt.plot(t_forecast, tf.transpose(forecast_samples), lw=.6, ls='--', color='r', alpha=0.1)
  plt.plot(t_forecast, forecast_mean, lw=1, ls='-', color='r', alpha=1, label='forecast')
  plt.fill_between(t_forecast,
                   forecast_mean-2*forecast_scale,
                   forecast_mean+2*forecast_scale, color='r', alpha=0.1)
  plt.legend()
plot_score(y[n_observed:], state[n_observed:], loc)
plt.show()
plot_forecast_results(t_observed, y_observed, t_forecast, y_forecast, t, state,
                      samples, loc, scale, n_forecast)
plt.show()
       mse:  0.2731, R2:  0.8228
 state mse:  0.0245, R2:  0.9809
png
png

よくできました!

コンポーネント分解

コンポーネント分解もしてみましょう。

def get_components(model, observed_time_series, parameter_samples, forecast_dist):
  # 観測分 component dists for observed
  o = tfp.sts.decompose_by_component(
      model,
      observed_time_series=observed_time_series,
      parameter_samples=parameter_samples)
  means_o = {k.name: c.mean() for k, c in o.items()}
  stddevs_o = {k.name: c.stddev() for k, c in o.items()}

  # 予測分
  f = tfp.sts.decompose_forecast_by_component(
      model,
      forecast_dist=forecast_dist,
      parameter_samples=parameter_samples)
  means_f = {k.name: c.mean() for k, c in f.items()}
  stddevs_f = {k.name: c.stddev() for k, c in f.items()}

  # 観測分と予測分をキー毎に結合
  def concat_per_key(dicts, keys=None):
    if keys is None:
      keys = dicts[0].keys()
    return {
      k: tf.concat([d[k] for d in dicts], axis=0) for k in keys
    }
  means = concat_per_key([means_o, means_f])
  stddevs = concat_per_key([stddevs_o, stddevs_f])
  return means, stddevs
means, stddevs = get_components(model, y_observed, ps_vi, dist)
def plot_components(t, means, stddevs, vline=None):
  N = len(means)
  fig = plt.figure(figsize=(12, 3 * N))
  for i, name in enumerate(means.keys()):
    mean = means[name]
    stddev = stddevs[name]
    ax = fig.add_subplot(N,1,1+i)
    if vline is not None:
      ax.axvline(vline, ls='--', c='b', lw=2, alpha=.5)
    ax.plot(t, mean, lw=2)
    ax.fill_between(t, mean-2*stddev, mean+2*stddev, alpha=0.2)
    ax.set_title(name)
  fig.autofmt_xdate()
  fig.tight_layout()
plot_components(t, means, stddevs, t_now)
plt.show()
png
def forecast_and_plot(model, parameter_samples,
                      t_observed, y_observed, t_forecast, y_forecast, t, state,
                      num_steps_forecast):
  """ ここまでの処理をまとめた """
  print_params(parameter_samples)

  dist, loc, scale, samples = forecast(model, y_observed, parameter_samples, num_steps_forecast)

  plot_score(y_forecast, state[len(t_observed):], loc)
  plt.show()
  plot_forecast_results(t_observed, y_observed, t_forecast, y_forecast, t, state,
                        samples, loc, scale, num_steps_forecast)
  plt.show()
  means, stddevs = get_components(model, y_observed, parameter_samples, dist)
  plot_components(t, means, stddevs, t_forecast[0])
  plt.show()

HMC

ついでにHMC(ハミルトニアンモンテカルロ法)を用いて事後分布からサンプルを取得してみましょう。ここではパラメータは調整せずに実行してみす。変分推論に比べてかなり時間がかかります。サンプリング関係はかなり機能が奥深いので、今回は触りだけにします。

%%time
model2 = build_model(y_observed)
samples_hmc, kernel_results = tfp.sts.fit_with_hmc(model2, y_observed)
CPU times: user 18min 55s, sys: 4min 38s, total: 23min 34s
Wall time: 6min 20s
print(f'Acceptance Rate: {np.mean(kernel_results.inner_results.inner_results.is_accepted, axis=0):.2%}')
Acceptance Rate: 98.00%
ps_hmc = {p.name: v for p, v in zip(model2.parameters, samples_hmc)}
forecast_and_plot(model2, ps_hmc,
                  t_observed, y_observed, t_forecast, y_forecast, t, state,
                  n_forecast)
observation_noise_scale                 : 0.47 ± 0.02
LocalLinearTrend/_level_scale           : 0.02 ± 0.02
LocalLinearTrend/_slope_scale           : 0.00 ± 0.00
Seasonal/_drift_scale                   : 0.01 ± 0.02
       mse:  0.2744, R2:  0.8225
 state mse:  0.0253, R2:  0.9812
png
png
png

今回はここまで。Eager実行がちゃんとサポートされていることと、構造時系列モデルの基本が簡単に構築できることがわかりました。

Statsmodelsといったライブラリに比べると、要約や可視化に関する機能が存在していない分、そこは少々手間に感じます。しかしTensorFlow Probabilityには構造時系列以外にもかなりの機能があり、TensorFlowやKerasと連携できることや、GPUなどのリソースをフル活用できるなど、メリットもたくさんあります。

Eager実行が安定していることがわかったので、気が向いたときにTensorFlow Probabilityの他の機能や、実際のデータを用いて構造時系列モデリングの例などもブログにしようかなと思っています。