내가 보려고 만든 블로그

<Bayesian> Variational Inference ( Pyro) 본문

Data Science/Bayesian

<Bayesian> Variational Inference ( Pyro)

정의김 2022. 11. 12. 15:15

앞 선 포스팅에서 pyro를 이용해 mcmc를 사용하는 법을 써보았다. mcmc의 경우 샘플링수를 충분히 늘린다면 사후분포를 잘 추정할 수 있겠지만 샘플링수가 늘어나는 만큼 사후분포를 추정하는 시간이 오래 걸릴 것이다. 

사후분포를 좀더 빠르게 추정하기 위해서 기존의 모델들이 그랬듯이 최소화할 Loss를 지정하고 이를 줄여서 근사한 해를 구하는 방식을 사용하여 사후분포를 추정할 수 있다. 이 방법이 바로 Vairational Inference = VI = 변분추론 .

VI에서는 후선 사후분포를 추정하기 위해 q라는 분포를 가정한다. 사후분포는 일반적으로 구하기에는 너무 복잡하기 때문. (notation은 당연히 다를 수 있는데 보통 q로 표현) 

이 q라는 분포를 우리가 구해야하는 사후분포에 최대한 근사하게 만들면 그것이 바로 사후분포에 대한 좋은 추정값이 될 것이다. 

사후분포와 q가 근사하다는 기준을 KL- diverence을 통해 구한다. KL- divergence가 최소가 되도록하는 q가 사후분포와 가장 근사한 분포가 되는 것. 이 KL-divergence도 구하기가 매우 어렵기 때문에 KL-divergence를 최소로 하기 위해서는 ELBO를 최대로 하면 된다 .
이 ELBO가 일종의 loss function 역할을 해줄 것. 

(KL- divergence와 ELBO에 대해서는 설명해주는 블로그 들이 매우 많으므로 그걸 참고. 이 글에서는 pyro를 통해 프로그래밍 하기 위해서 최소한의 지식만을 적어둠. ) 

 

Pyro 를 통한 프로그래밍 .

역시나 pyro Doc에 있는 예제 가져옴. 동전던지기 예제.  베이지안을 공부하면 거의 항상 처음에 나오는 앞면이 나올 확률이 50프로가 정말 맞을까 예제임.  첫번째 포스팅에서 설명했던 것들은 생략하고 적어봄. 

 

https://pyro.ai/examples/svi_part_i.html

 

SVI Part I: An Introduction to Stochastic Variational Inference in Pyro — Pyro Tutorials 1.8.2 documentation

Pyro has been designed with particular attention paid to supporting stochastic variational inference as a general purpose inference algorithm. Let’s see how we go about doing variational inference in Pyro. Setup We’re going to assume we’ve already de

pyro.ai

 

import pyro.distributions as dist

def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the Bernoulli
        # likelihood Bernoulli(f)
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

1. 먼저 앞면이 나올 확률 p 를 f = 베타(10,10)으로 가정한다.  베타분포의 두 모수 alpha, beta 값이 같을 경우 x = 0.5에서 극대값을 가지면 대칭인 분포를 가진다. 샘플을 관측하기 전에 p= 0.5 즉 , 반반이라고 가정하는 것이 가장 타당한 사전분포 일 것이다. 그리고
for i in range(len(data)):

    pyro.sample~~~

을 통해 샘플링한 값을 넣어준다. 여기에는 for을 이용해서 샘플링한 값을 반영해주었지만 앞서 사용햇드이 plate를 이용해도 상관없다. 

 

다음으로 f라는 사후분포를 추정하는데 사용할 q를 지정해줘야 한다. 이를 pyro 에서는 guide 라고 함 . 

def guide(data):
    # register the two variational parameters with Pyro.
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

1.  첫번째 포스팅에서 언급한적 있는데 "latent_fairness"라는 이름은 guide와 model 각각에서 똑같이 지정을 해주어야함. 

2. torch.param으로 지정해두면 default로 require_grad = True가 되어 그라디언트를 통해 학습이 됨 . 

 

마지막으로 모델을 만들고 fit해주면 끝. 

# set up the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 5000
# do gradient steps
for step in range(n_steps):
    svi.step(data)

1. Optimizer로는 아담을 사용ㅎ.

2. loss는 앞서 말했듯이 elbo를 사용할거고 Trace_ELbo() 를 통해 사용할 수 있다.

3. iteration은 5000번 . 

아직 data를 지정 안해주었는데 6번 앞면 이 나오고 4번 뒷면이 나온 상황을 가정할 것이다 
data= [1,1,1,1,1,1,0,0,0,0] 

 

종합하면 코드는 다음과 같음. 

import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

assert pyro.__version__.startswith('1.8.2')

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the ernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nBased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))