내가 보려고 만든 블로그

<Bayesian> MCMC, pyro 로 모델링 (HMC) 본문

Data Science/Bayesian

<Bayesian> MCMC, pyro 로 모델링 (HMC)

정의김 2022. 10. 30. 17:17

앞써서 베이지안 리그레션을 pyro로 모델링하는 방법에 대하여 포스팅 했다. svi를 이용해서 사후분포를 추론하는 방법을 다루었는데 이번에는 MCMC를 이용해 사후분포를 추정하는 것을 pyro를 이용해 모델링 하는 방법에 대하여 포스팅함. 

사실, 매우 간단함. 라이브러리를 사용하는 것은.

 

예제는 앞써서 사용했던 log_gdp를 두개의 설명변수 ( rugged ,  cont_africa , rugged* cont_africa)를 통해 모델링 하는 것을 그대로 사용하였다.

 

1. 데이터를 받아와서 전처리하는 과정 . 특별한 것은 없고 df_rugged , df_cont_africa , df_log_gdp 를 torch의 텐서로 바꿔줌. (pyro가 토치를 기반으로 만들어진 프레임워크라)

 

2. 모델 정의, 앞서 베이지안 리그레션에서 사용한 모델 그대로이다.

 

3. Variational Inferece를 사용할 때는 근사분포 q를 위해 guide를 만들어줘야 했지만 mcmc는 샘플링 방법이므로 당연하게도 필요 없다.

4. mcmc 

우선 pyro의 mcmc 는 hmc라는 방법을 사용하여 mcmc를 수행한다고 한다.

사실, 이건 처음 봤는데 기존에 metropolis hasting, gibbs sampling 등에 비하여 연산량이 많이 들지만 대신 적은 샘플링만으로도 꽤 훌륭하게 사후분포를 추정하는 것이 가능하다고 한다. 다만 단점으로는 parameter의 space가 conitnuous 하지 않고 discrete할 경우에는 사용 할 수 없다고 한다 . 이럴때에는 깁스샘플링을 사용해야할 것 같은데 깁스샘플링은 pyro 에서 제공하지 않아서 따로 작업을 조금해주어야한다. 다행히도 example이 있어서 후에 깁스샘플링도 올려보겠음!

그리고 첫번째 줄에 보면 nuts라는 것이 있는데 no u-turn sampling 으로 hmc를 수행할 때 더 효율적으로 수행할 수 있다고 한다. 

kernel = NUTS(model) # no u-turn sampling 
mcmc = MCMC(kernel, num_samples=1000)

mcmc.run(df_rugged, df_cont_africa, df_log_gdp) # hmc 를 사용

 

4. 추론한 결과에 대하여 다음과 같이 확인 가능하다. 

각 parameter에 대한 확률분포

 

https://github.com/todtjs92/Bayesian/blob/master/pyro/1.MCMC.ipynb

 

GitHub - todtjs92/Bayesian: Pyro 또는 torch 를 이용한 베이지안 프로그래밍

Pyro 또는 torch 를 이용한 베이지안 프로그래밍. Contribute to todtjs92/Bayesian development by creating an account on GitHub.

github.com