stanで状態空間モデル(ローカルレベル)

author: Unadon (見習い飯炊き兵) 動作環境:Mac OS Sierra 10.12.1; R version3.3.1; rstan 2.10.1

 

状態空間モデル: ローカルレベルモデル

LocalLevelBayes2.jpeg

はじめに

rstanで色々ベイズしてみたい!という方に向けて、基本的な解析とその周辺Tipsをご紹介していきます。

今回は状態空間モデル。時系列データ分析の王道です。

なかでも一番シンプルなローカルレベルモデルというのをやってみます。この形を理解していれば、周期性やイベントなどを考慮したモデルに展開していくことができます。

ローカルレベルモデルの構造は簡単で、次のような仮定をおきます。

  1. 時系列で変化する観測データは、同じく時系列で変化する真の状態+観測誤差から生成されている。

  2. 真の状態は、一時点前の状態+移動範囲で動く

これだけです。

たとえば、観測データを”人工知能の検索回数1年分”としましょう。真の状態は”人工知能のネット注目度”あたりでしょうか。

今日の注目度は、昨日の注目度にある程度依存していると考えられますので、昨日の注目度+移動範囲のなかで、今日の注目度が求められます。

今日の注目度+誤差で、実際の今日の検索回数が生まれる、と考えられます。

さて、ベイズでは、”真の状態”、”観測誤差”、”移動範囲”をパラメータとして求めます。

さっそく、データを準備しましょう。 GoogleTrendから時系列データをとってきます。

データの取得

Googleアカウントのアドレスとパスワードを入力して、下記を実行しましょう。 データの可視化まで行きます。

取ってくるデータは、”人工知能”、”ディープラーニング”、”機械学習”の3つのキーワードの検索人気度を過去13年分です。

確か、週ごとに間引かれた値が取得されたはずです。

rm(list=ls(all=TRUE))

#Initial Setting------------------------------------
#set your google accounts
usr <- "" #mail adress
psw <- "" # PassWord

#Keywors setting
keyWords<-c("人工知能","ディープラーニング","機械学習")

#samoling term start to end
startDate<-c("2004-01-01")
endDate<-c("2017-01-01")
#-----------------------------------------------------

library(gtrendsR) # for Scraping Google Trends
library(ggplot2) # Graph
library(plotly)# Graph
library(stringi) # Change Encoding
library(stringr) # Change Encoding
library(rstan) #stan
library(reshape2) #data reshape
library(shinystan) #shiny stan

#Stan Parallel(each chain) processing 
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

#############################################################
#1. Data scraping from Google trend with R package "gtrendsR" 
#############################################################

#login
gconnect(usr, psw)

#Scraping with gtrends
GetTrend <- gtrends(query = stri_encode(keyWords, "", "UTF-8"),  # Set UTF-8 or shift-jis
                    geo = "JP", # get data of "Japan Regions: JP"
                    start_date = startDate, 
                    end_date = endDate)



# Geometoric point&line plot with ggplot2-----------------------------------------
Dat<-as.data.frame(GetTrend$trend)# ここでデータフレームにする
g<-ggplot(Dat, aes(x = start, y = hits, colour = keyword)) +   
    geom_point(size = 1.0,alpha=0.4) +theme_set(theme_bw(base_family="HiraKakuProN-W3"))+
    labs(colour = "検索クエリ", x = "日付", y = "検索人気度", title = GetTrend$meta) +
    geom_line(size = 0.3,alpha=0.8)+
    scale_colour_brewer(palette = "Set2")+ 
    facet_wrap(~keyword,ncol=3,scales = "free")

plot(g)

ぃぃぃよっこらせっと!

AItrend.jpeg

急激な右肩上がりです。特にディープラーニング。

10年前の”人工知能”はSFの夢物語だった。

ところが、ディープラーニングの登場によって人口知能は一気に”現実の技術に変わった、というような印象を持ちますね。

前半と後半で”人工知能”というキーワードのもつ意味が、変わっているように思えます。

図がちょっと見にくい(時系列データ感がない)ので、見せ方を変えておきます。

AItrend2.jpeg

ローカルレベルモデルの実行

さて、それぞれのキーワードの検索人気度から、真の状態(ネット注目度?)を求めていきましょう。

682のタイムポイントで、検索人気度hitsが得られたので、キーワードごとに整理します。

そしてそれらをリスト化し、Stanに渡すデータとします。

AI<-Dat$hits[1:682]#AIの検索人気度
DL<-Dat$hits[c(682+1):c(682*2)] #ディープラーニング
ML<-Dat$hits[c(682*2+1):c(682*3)] #機械学習
N=3 #キーワードの数
t=682 #タイムポイントの数

#stanに渡すデータ
datastan= list(AI=AI,ML=ML,DL=DL,t=t,N=N)

#LocalLevelのパラメータ。それぞれ観測誤差、状態の誤差(移動範囲)、最初の心の状態、2時点目以降の真の状態
parameters<-c("sigmaObs","sigmaState","muZero","mu")

Stanモデルの定義

ここからがStanによるローカルレベルモデルの定義です。 ポイントは、

  1. タイムポイント(t)の数だけ観測データ(AI,DL,ML)がある。

  2. タイムポイント(t)とキーワード[N]の数だけ真の状態(mu[N,t])がある。

  3. 真の状態は、一時点前の真の状態muを平均とし、移動範囲の標準偏差sigmaStateをもつ正規分布から出てくる

  4. 観測データは、同時点の真の状態を平均とし、観測誤差sigmaObsを標準偏差とする正規分布から出てくる

  5. 一番最初の真の状態は、一次点前の状態が未知なので、muZeroで別定義する。

というところ。そのまんまコード化します。

#Stan Model
model<-"
data{
    int N;
    int t;
    int AI[t];
    int ML[t];
    int DL[t];
}


parameters{
    matrix <lower=0>[N,t]mu;
    real <lower=0>sigmaObs[N];
    real <lower=0>sigmaState[N];
    real <lower=0>muZero[N];
}

model{
    for(i in 1: N){
        sigmaObs[i] ~ cauchy(0,25);
        sigmaState[i] ~ cauchy(0,25);
        muZero[i] ~ normal(0,100);
        mu[i,1] ~ normal(muZero[i],sigmaState[i]);
        AI[1] ~ normal(mu[1,1],sigmaObs[1]);
        ML[1] ~ normal(mu[2,1],sigmaObs[2]);
        DL[1] ~ normal(mu[3,1],sigmaObs[3]);
    }
    for(j in 2:t){
        mu[1,j] ~ normal(mu[1,j-1],sigmaState[1]);
        AI[j] ~ normal(mu[1,j],sigmaObs[1]);
        mu[2,j] ~ normal(mu[2,j-1],sigmaState[2]);
        ML[j] ~ normal(mu[2,j],sigmaObs[2]);
        mu[3,j] ~ normal(mu[3,j-1],sigmaState[3]);
        DL[j] ~ normal(mu[3,j],sigmaObs[3]);
    }
}
"

#Stan の実行
fit<-stan(model_code=model,data=datastan,
          pars=parameters,
          iter=3000,
          chains = 3,thin=10,
          warmup = 1000)

推定結果を確認します。

#推定平均値mean(EAP推定量)と95%信用区間のチェック
summary(fit)$summary[c(1:40),c(1,3,4,8,9,10)]

                    mean         sd        2.5%      97.5%    n_eff      Rhat
sigmaObs[1]    7.0734838 0.30182107  6.52066070  7.7106476 295.5247 1.0088588
sigmaObs[2]    1.0038284 0.03064590  0.94683458  1.0677913 600.0000 0.9956023
sigmaObs[3]    1.6038556 0.06361654  1.47588848  1.7264848 276.1647 0.9998402
sigmaState[1]  2.8810838 0.37461289  2.08731027  3.6344393 179.7387 1.0127090
sigmaState[2]  0.1512111 0.01235371  0.13011098  0.1792376 505.3003 1.0072224
sigmaState[3]  0.5533958 0.06956030  0.42860917  0.6942975 127.1683 1.0104506
muZero[1]     38.9935236 4.43705481 29.94447755 47.3272554 475.1568 1.0003497
muZero[2]      0.3831252 0.24216710  0.02598954  0.8897831 549.1964 1.0075318
muZero[3]      1.3353684 0.74665199  0.09320793  2.8533133 532.9074 0.9982005
mu[1,1]       39.0017912 3.40690357 32.61380346 45.4170545 466.4055 1.0005448
mu[1,2]       34.2043599 3.10276557 28.27472529 40.6449492 542.5005 0.9986699
mu[1,3]       33.3712065 3.16538799 27.53473722 39.7055171 473.4783 0.9971092
mu[1,4]       30.4860268 3.10364818 24.29522326 36.8529825 476.6611 0.9997451
mu[1,5]       27.0637477 3.24676803 20.93094314 33.8924692 545.4625 1.0031939
mu[1,6]       23.0841939 3.27929378 16.33959523 29.0162999 566.2405 0.9963408
mu[1,7]       21.3165559 3.30246586 14.41639328 27.3339473 535.6559 1.0008654
mu[1,8]       20.8365971 3.29943299 14.24204392 26.9673208 600.0000 1.0043029

<以下略>

収束にも問題なさそうです。一応shinystanのDIAGNOSEから、Neffとかを見ておきましょう。

RhatとMCSEとNeffについて、全てNoneが表示されていたらOKです。

#shinystan
launch_shinystan(fit)

推定結果の可視化

さて、結果の図を出していきます。stanの推定結果からsampleを抜き出してきて、 95%信用区間とかを計算していきます。

汚いコードですみません。あんまり参考にしないでください。

お疲れモードです。

#MCMCサンプルを取り出してくる。
mu<-rstan::extract(fit)$mu
#3次元のデータをデータフレームとして二次元に縮約
d2<-data.frame(mu)
#事後分布の平均、95%上限下限を求める
up<-c()
lo<-c()
M<-c()
for( i in 1:length(d2)){
    up[i]<-quantile(d2[,i],0.975)
    lo[i]<-quantile(d2[,i],0.025)
    M[i] <-mean(d2[,i])
}

#列の名前を取ってくる
df.name<-names(d2)
#Transrate row to col
d3<-data.frame(t(d2))
#col bind
t.d3<-cbind(df.name,d3)
#quantile bind
df.q<-cbind(M,up,lo,t.d3)
df.q$lo[df.q$lo<0]<-0
#要約値だけを取り出す
df.comp<-df.q[,c(1,2,3,4)]

#col name separation  2次元縮約された変数を、ロングデータに展開
df.comp2<-df.comp %>% separate(df.name, c("Nvar","mu"))
df.comp2$mu<-as.numeric(df.comp2$mu)
df.plot<-df.comp2

#検索クエリのヒット数だけ抜いてくる
hits<-Dat[,c("hits")]
str(hits)
raw<-melt(hits)
str(raw$value)
raw$Nvar <- rep(c("X1","X2","X3"),each=682)
raw$mu  <- rep(c(1:682))
merge_dat<-merge(df.plot, raw, by=c("Nvar","mu"), all=T) 
merge_dat$Nvar<-rep(c("人工知能","ディープラーニング","機械学習"),each=682)
df.plot.comp<-merge_dat

#ggplotでプロットしていく。最初に、全てのグラフに共通するスタイルとフォントを決める
ggplot()+theme_set(theme_bw(base_size = 14,base_family="HiraKakuProN-W3"))
g<-ggplot(df.plot.comp)
g<-g+geom_line(aes(x=mu,y=value,colour=Nvar),size=0.3,alpha=0.3)

#95%信用区間をエラーバーで示す
g<- g+ geom_errorbar(aes(x=mu,ymin=lo,ymax = up),colour="gray40",alpha=0.15,size=1)
g<-g+geom_point(aes(x=mu,y=value,colour=Nvar),alpha=0.3,size=0.5)
g<-g+geom_line(aes(x=mu,y=M,colour=Nvar),size=0.75)
g<-g+scale_y_continuous(expand = c(0,0))
g<-g+scale_x_continuous(expand = c(0,0))
g<-g+facet_wrap(~Nvar,scales="free",ncol=1) #参加者ごとに図を分けて書き出す。
g<-g+labs(colour = "検索クエリ", x = "タイムポイント", y = "検索人気度",title = GetTrend$meta) +scale_colour_brewer(palette = "Set2")

g #ggplotで可視化

よっこらせ。

LocalLevelBayes.jpeg

むむ。見にくい。

もういっちょ。

LocalLevelBayes2.jpeg

ローカルレベルモデルの展開

今回は、「真の状態muは、一次点前の状態+移動範囲の正規分布に従う」ことを仮定しました。

「2時点前と1時点前の平均+移動範囲の正規分布に従う」と仮定すれば、もう少しなめらかに状態遷移を捉えることができます。

また、「状態mu = 一次点前の状態(mu[t-1])+係数✕イベント変数(event[t])」というように定義すれば、

イベントの影響[係数]や、それをを統制したmuを得ることができます。

イベントを”曜日”や、”週末”にしてダミー変数を作れば、周期性の効果を推定&統制することもできます。

さらに、「状態muはある時点Z [t>Z]をもって、mu+deltaで定義される」というモデルを組めば、

変化点検出のモデルへと拡張することができます。

ご自身のデータと関心にあわせて、もっともっと

EnjoyStan!!

Written on January 29, 2017