Variational Inference (Part 1): Variational Inference & CAVI

Overview

In this post I will present to you what's known as variational inference (VI), a family of approximate Bayesian inference methods. In particular, we will focus on one of the more basic VI methods called coordinate ascent variational inference (CAVI). We will go through the theory for VI, and apply CAVI to the case of a Gaussian mixture, following and adding to what's presented in blei16_variat_infer. With the theory clear we will go through an implementation in Turing.jl as a simple (and slightly contrived) example of how to implement "custom" variational inference methods. Hopefully you'll learn something. The goal is that you walk away with (in descending order of importance):

  1. A basic understanding of VI and why we bother
  2. A general idea of how to use the (work in progress) VI interface in Turing.jl
  3. A vague familiarity with CAVI

Just a heads up: in contrast to the previous post, this one is less focused on memes, involving a fair bit of mathematics and code. At the ripe age of 24, there's certainly one thing I've learned: memes can only get you so far in life.

Notation

  • VI abbreviates variational inference
  • CAVI abbr. coordinate ascent VI
  • 2019-06-25-variational-inference-part-1-cavi_5403f6bbface4889f05450c96efe5dfafd041d71.png denotes the number of data points
  • 2019-06-25-variational-inference-part-1-cavi_2379137a70f55d3e67efab50fe29f05dbee4ce3a.png denotes a set of data
  • Evidence lower bound (ELBO):

    2019-06-25-variational-inference-part-1-cavi_c10fdef1e34636c0d0cc82ed4266306cf24cded8.png

    and

    2019-06-25-variational-inference-part-1-cavi_b9706b8fff3aadef76992524a882a153b273e12a.png

Motivation

In Bayesian inference one usually specifies a model as follows: given data 2019-06-25-variational-inference-part-1-cavi_03d068b4949ffc917c6ef07dd7455c00089aa282.png,

2019-06-25-variational-inference-part-1-cavi_4c7f5c78d51648152e123790485b2011df30b132.png

where 2019-06-25-variational-inference-part-1-cavi_5179e66dd4274af2f7d0b87eb3ace1f19d37414d.png denotes that the samples are independently and identically distributed. Our goal in Bayesian inference is then to find the posterior

2019-06-25-variational-inference-part-1-cavi_aa4c76f4dc318e4d4bb2f5a7c65e115b4a4dcb02.png

In general one cannot obtain a closed form expression for 2019-06-25-variational-inference-part-1-cavi_897c7fada95f8af980ffc865424400e015b13d3a.png, but one might still be able to sample from 2019-06-25-variational-inference-part-1-cavi_897c7fada95f8af980ffc865424400e015b13d3a.png with guarantees of convergence to the target posterior 2019-06-25-variational-inference-part-1-cavi_897c7fada95f8af980ffc865424400e015b13d3a.png as the number of samples goes to 2019-06-25-variational-inference-part-1-cavi_8135ffb7ca4c5b92ced3fcd4204b8705472e3c5c.png, e.g. using MCMC.

"So what's the problem? We have these nice methods which we know converge to the true posterior in the infinite limit. What more do we really need?!"

Weeell, things aren't always that easy.

  1. As the model 2019-06-25-variational-inference-part-1-cavi_b346b8eaf5913ed81b033dacc1d6568c142bbb4a.png becomes increasingly complex, e.g. more variables or multi-modality (i.e. separate "peaks"), convergence of these unbiased samplers can slow down dramatically. Still, in the infinite limit, these methods should converge to the true posterior, but infinity is fairly large, like, at least more than 12!
  2. Large amounts of data mean more evaluations of the log-likelihood, which, if 2019-06-25-variational-inference-part-1-cavi_ab20ba1962a9eb230e624066148820ba492659f0.png is sufficiently expensive to compute, can make direct sampling infeasible.

In fact, there is also an application of VI in the context of Bayesian model selection, but we will not dig into that here pmlr-v80-chen18k.

Therefore it might at times be a good idea to use an approximate posterior, which we'll denote 2019-06-25-variational-inference-part-1-cavi_fe00df71e7543e86c03dfb375840b14ae64dba82.png. Clearly we'd like it to be close to the target 2019-06-25-variational-inference-part-1-cavi_b346b8eaf5913ed81b033dacc1d6568c142bbb4a.png, but what does "close" even mean in the context of probability densities?

This question really deserves much more than a paragraph and is indeed quite an interesting topic. There are plenty of approaches to take here; some come down to putting actual metrics (defining a "distance") on spaces of probability densities, e.g. the Wasserstein / earth-mover's distance, while other approaches are less strict, e.g. the Kullback-Leibler (KL) divergence, which, in contrast to a metric, is not symmetric in its arguments.
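To make the asymmetry concrete, here is a small sketch (not part of the model in this post; just two Gaussians and a crude Monte Carlo estimate of the KL-divergence, with mc_kl a throwaway helper):

using Distributions, Statistics

p = Normal(0.0, 1.0)
q = Normal(0.0, 3.0)

# crude Monte Carlo estimate of KL(a || b) = E_a[log a(z) - log b(z)]
function mc_kl(a, b; num_samples = 100_000)
    zs = rand(a, num_samples)
    return mean(logpdf.(a, zs) .- logpdf.(b, zs))
end

mc_kl(p, q), mc_kl(q, p)  # ≈ (0.65, 2.90): clearly not symmetric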

Objective

For now we will consider the KL-divergence, but as we will probably see in later posts, there are several viable alternatives which can also be employed in VI. The KL-divergence also has the nice property that minimizing the forward KL-divergence from the data distribution to the model is equivalent to maximizing the likelihood.

With the KL-divergence as our measure of how "good" an approximation the approximate posterior 2019-06-25-variational-inference-part-1-cavi_fe00df71e7543e86c03dfb375840b14ae64dba82.png is to 2019-06-25-variational-inference-part-1-cavi_6e68cb0ca7095a69db7adc49eeec5ab146d7134c.png, the objective is to find the optimal 2019-06-25-variational-inference-part-1-cavi_2b6f58b235cda96698852897ddd4abf7d1473a77.png, where 2019-06-25-variational-inference-part-1-cavi_27e16d785f7ae685d71efb97ad73aa13b3ea32ac.png is the space of densities we have chosen to consider:

2019-06-25-variational-inference-part-1-cavi_c080302dfb89af9f748ca75b4c02cc2a376ec18b.png

where the KL-divergence is defined

2019-06-25-variational-inference-part-1-cavi_0d37c53b9cdc1028ffdf63526d7225668bea4fd6.png

So the idea is to construct some approximate posterior 2019-06-25-variational-inference-part-1-cavi_9974194a836f8c06efda1c78a55ab184c6c2d746.png which is feasible to sample from, allowing us to do approximate Bayesian inference. But there are a number of practical issues with the above objective:

  1. We can't evaluate 2019-06-25-variational-inference-part-1-cavi_37bf5958eebb9907bab22437e5bfe74e4452164f.png! Evaluating 2019-06-25-variational-inference-part-1-cavi_37bf5958eebb9907bab22437e5bfe74e4452164f.png is the entire goal of Bayesian inference; if we could simply evaluate this, we would be done.
  2. Integrating over all 2019-06-25-variational-inference-part-1-cavi_01f7b7127bdc73de1b82efda8d5a045166ffb9fa.png is of course in general intractable.

In the case of KL-divergence, the standard approach to address (1) is to instead optimize what's known as the evidence lower bound (ELBO).

The famous ELBO1

Using the fact that 2019-06-25-variational-inference-part-1-cavi_c147244f76e7b608dd6810bf6da1af0c645b6df7.png, the RHS of the KL-divergence becomes

2019-06-25-variational-inference-part-1-cavi_04616c291d424a379530bfb9fa2af2d270aac8bd.png

where we've simply brought the 2019-06-25-variational-inference-part-1-cavi_2d71aa14217932e7c1dd02bbf12c38f25b05654c.png term outside of the expectation since 2019-06-25-variational-inference-part-1-cavi_edba165435c97711c6c316b9b57b45649434e403.png is independent of 2019-06-25-variational-inference-part-1-cavi_01f7b7127bdc73de1b82efda8d5a045166ffb9fa.png. We can therefore re-arrange the original expression to

2019-06-25-variational-inference-part-1-cavi_d767b943a2cf5b53bddd5ffef8ff0690a9269a30.png

This then implies that if we maximize the second term on the RHS we're implicitly minimizing the KL-divergence between 2019-06-25-variational-inference-part-1-cavi_37bf5958eebb9907bab22437e5bfe74e4452164f.png and 2019-06-25-variational-inference-part-1-cavi_fe00df71e7543e86c03dfb375840b14ae64dba82.png! This term is called the evidence lower bound (ELBO), stemming from the fact that 2019-06-25-variational-inference-part-1-cavi_048f6c98e916ed8890f70a26012e2f994b6f56f5.png is often referred to as the "evidence" in the Bayesian literature and it turns out that the KL-divergence is indeed non-negative, i.e. 2019-06-25-variational-inference-part-1-cavi_ca91cfb8277cbeb1d4cf27badb533dad1b9dbcb1.png. We therefore define

2019-06-25-variational-inference-part-1-cavi_c10fdef1e34636c0d0cc82ed4266306cf24cded8.png

This still requires us to feasibly compute 2019-06-25-variational-inference-part-1-cavi_f07b4155a393eae07c6e121526e0674b5310f633.png!

Also, the fact that this lower bounds the evidence 2019-06-25-variational-inference-part-1-cavi_048f6c98e916ed8890f70a26012e2f994b6f56f5.png is how one can use VI to do model selection, but again, not getting into that here.
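For reference, the rearrangement above, written out in plain notation, is the identity

\log p(x) = \mathrm{KL}\big( q(z) \,\|\, p(z \mid x) \big) + \underbrace{\mathbb{E}_{q(z)}\big[ \log p(x, z) \big] - \mathbb{E}_{q(z)}\big[ \log q(z) \big]}_{\mathrm{ELBO}(q)}

and since the KL term is non-negative, \log p(x) \ge \mathrm{ELBO}(q).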

Since the expectation is linear and we're assuming independence, if we want to take into account all the samples in 2019-06-25-variational-inference-part-1-cavi_4896aad829445b20849ca132a1e290bbd1d47df3.png, we simply sum these contributions to get

2019-06-25-variational-inference-part-1-cavi_5ba27d6fb5868e5734d4dd116fc1ed20d7b58440.png

in the case where we use the same 2019-06-25-variational-inference-part-1-cavi_01f7b7127bdc73de1b82efda8d5a045166ffb9fa.png for the different samples.2
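Written out, under the i.i.d. assumption and with the same z shared across observations, this reads

\mathrm{ELBO}(q) = \sum_{i=1}^{N} \mathbb{E}_{q(z)}\big[ \log p(x_i \mid z) \big] + \mathbb{E}_{q(z)}\big[ \log p(z) \big] - \mathbb{E}_{q(z)}\big[ \log q(z) \big]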

Awesome. We now have an objective to maximize, whose maximizer then gives us the "optimal" approximate posterior 2019-06-25-variational-inference-part-1-cavi_fe00df71e7543e86c03dfb375840b14ae64dba82.png from which we can sample.3

But computing 2019-06-25-variational-inference-part-1-cavi_3d20895be87bfffabe9d5dd2ecea45a7d87139b7.png involves computing an expectation, which of course again requires evaluating a (usually) intractable integral, i.e. problem 2 for the KL-divergence noted in the previous section. Luckily, in this case we can simply substitute the intractable expectations with empirical expectations, i.e. if we want to estimate the expectation of some function 2019-06-25-variational-inference-part-1-cavi_f017a9b0f9e8a176a3db97e0116f98ff496ac318.png wrt. a density 2019-06-25-variational-inference-part-1-cavi_048f6c98e916ed8890f70a26012e2f994b6f56f5.png, then

2019-06-25-variational-inference-part-1-cavi_4ccddddcd0d905f7428e2d4bccf14735c8bfa732.png

which, by the Strong Law of Large Numbers (SLLN), converges to the true expectation in the infinite limit. We can apply this to the ELBO:

2019-06-25-variational-inference-part-1-cavi_5f59e1a798278ba065c2aaba68c82396f007ef62.png

which is a much more approachable problem. Granted, this makes our objective non-deterministic / stochastic / random / wibbly-wobbly, which in turn means our estimates will have higher variance. This usually means slower convergence, and possibly even complete failure to converge.
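As a concrete sketch of such an estimator (nothing Turing-specific: q is any distribution we can call rand and logpdf on, and logjoint is a stand-in for the log joint density of the model, both made up for this illustration):

using Distributions

# Monte Carlo estimate of ELBO(q) = E_q[log p(x, z) - log q(z)]
function elbo_estimate(logjoint, q; num_samples = 100)
    est = 0.0
    for _ = 1:num_samples
        z = rand(q)
        est += (logjoint(z) - logpdf(q, z)) / num_samples
    end
    return est
end

# toy usage: the "joint" is a standard normal and q is a slightly-off Gaussian
elbo_estimate(z -> logpdf(Normal(0, 1), z), Normal(0.2, 1.1))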

In fact, the approach to estimating the 2019-06-25-variational-inference-part-1-cavi_ae6e662081e287949ed603dc8f9faa9a9997cbd7.png is often what distinguishes VI approaches. We will discuss this further in the future; it's even likely that we'll dedicate one entire post only to the discussion / comparison of different approaches to this estimation.

In the particular example we will consider now, we actually have access to a closed form expression for 2019-06-25-variational-inference-part-1-cavi_0810dffa96937613ef0d06e4fd638c47e098e122.png, and therefore only need to make use of the empirical estimate for the first term of the ELBO: 2019-06-25-variational-inference-part-1-cavi_a4784f36cb2c74752531b8310ea19932e4772258.png. As mentioned above, this will lead to reduced variance in our updates and faster convergence.

How do I use this in Turing.jl?

Before diving into a particular example, let's just quickly go through the VI interface in Turing.jl.

Unfortunately VI is not officially supported yet, but you can get the WIP version if you check out my VI branch tor/vi-v3.4 All code that follows assumes you have indeed done that and that you've activated the local environment by doing ]activate . in the repository directory.

using Turing
using Turing: Variational

Just one more time: this is work in progress! What I present to you here is basically the first draft and thus unlikely to survive for long. But I put it here for you to see; maybe you even have ideas on how to improve it! If so, hit me up on the Julia slack channel @torfjelde :)

VI in Turing.jl is mainly supposed to be used as an approximate substitute for the sample(model, sampler) interface:

?Variational.vi
vi(model::Model, alg::VariationalInference)
vi(model::Model, alg::VariationalInference, q::VariationalPosterior)

Constructs the variational posterior from the model and performs the
optimization following the configuration of the given VariationalInference
instance.

Therefore, in general we'll have the following pseudo-code:

# 1. Instantiate `Turing.Model` with observations
m = model(data)

# 2. Construct optimizer and VI algo.
opt = Variational.ADAGrad()                   # proxy for `Flux.Optimise`
alg = Alg(args...)                            # <= `Alg <: VariationalInference`

# 3. Perform variational inference
q = vi(m, alg; optimizer = opt)               # =>  <: `VariationalPosterior`
q = vi(m, alg, q_before_opt; optimizer = opt) # => optimized `q <: VariationalPosterior`
rand(q, 1000)

If q is not provided, Turing will extract the latent variables in your model and construct a VariationalPosterior (i.e. a Distribution{Multivariate, Continuous}) and optimize the corresponding objective, if the VariationalInference method indeed has a default family.

The most significant difference from the existing sample(...) interface is that vi(...) returns a VariationalPosterior which we can call rand on to obtain approximate posterior samples, while sample returns a Chains object from MCMCChains containing the posterior samples. Currently VariationalPosterior is simply an alias for Distribution{Multivariate, Continuous} from Distributions.jl. The idea is that the variational posteriors should act as standard distributions, implementing both rand and logpdf. This way, any multivariate distribution from Distributions.jl will work (if supported by the method).
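For instance (a sketch only, not tied to any particular model), any multivariate distribution from Distributions.jl that implements rand and logpdf could play the role of the variational posterior:

using Distributions

q = MvNormal(zeros(2), [1.0 0.0; 0.0 1.0])  # 2D Gaussian with diagonal covariance
zs = rand(q, 1000)                          # 2×1000 matrix of approximate posterior samples
logpdf(q, zs[:, 1])                         # log-density of a single sample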

A VariationalInference instance is supposed to hold the "configuration" for a VI method, e.g. the number of samples used in the empirical expectation for the estimate of the ELBO. Additionally, even though two VariationalInference types have exactly the same struct fields, they might still have very different behaviour.

The following methods might have different behaviour based on the VariationalInference type:

  • optimize!:
    • If the VI method is not gradient-based (e.g. CAVI as we'll soon see), then this is where the main chunk of work occurs, and thus is the method to implement.
    • If the VI method is gradient-based, optimize! will usually contain the "training-loop", where the gradients are obtained by calling grad!. Then the parameters are updated by applying the suitable update rule decided by the optimizer used, e.g. ADAGrad.
      • There are also certain cases where we need to do more than just update the variational variables every iteration, e.g. min-max steps needed for certain objectives liu16_stein_variat_gradien_descen.
    • Default implementation is provided which simply takes a given number of optimization steps by calling grad! and apply!(optimizer, ...).
  • grad!: If the VI method is gradient-based, then this is usually where the heavy lifting is done. As mentioned earlier, the gradient estimator is often the distinguishing factor between different VI methods, e.g. reparameterization trick (next posts guys; calm down).
    • Default implementations are provided by using ForwardDiff and Flux.Tracker to differentiate the objective call, depending on which AD-type is specified for the VariationalInference{AD} instance.

And of course we need to implement some objective function, e.g. the ELBO. Internally in Turing.jl we define types and implement calling an instance, e.g. struct ELBO end and a (elbo::ELBO)(...). This allows different VariationalInference methods to have different behaviour for different objectives. If you're interested in implementing your own VI objective using a simple function, you can just specialize the optimize! call with an obj::typeof(objective_function) argument.
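Here is a toy version of that callable-struct pattern (the name MeanLogJoint and its fields are made up for illustration; this is not Turing's actual ELBO type):

using Distributions

struct MeanLogJoint{F}
    logjoint::F  # z -> log p(x, z), supplied by the user
end

# calling an instance evaluates the objective by simple Monte Carlo
function (obj::MeanLogJoint)(q, num_samples)
    return sum(obj.logjoint(rand(q)) for _ = 1:num_samples) / num_samples
end

obj = MeanLogJoint(z -> logpdf(Normal(0, 1), z))
obj(Normal(0.1, 1.0), 100)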

?Variational.optimize!
optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = ADAGrad())

Iteratively updates parameters by calling grad! and using the given
optimizer to compute the steps.

?Variational.grad!
grad!(vo, vi::VariationalInference, q::VariationalPosterior, model::Model, θ, out, args...)

Computes the gradients used in optimize!. Default implementation is provided
for VariationalInference{AD} where AD is either ForwardDiffAD or TrackerAD.
This implicitly also gives a default implementation of optimize!.

Variance reduction techniques, e.g. control variates, should be implemented
in this function.

Aight, aight; that's a lot of information. Let's do an example!

Example: 1D Gaussian Mixture

Notation

  • 1D Gaussian mixture:

    2019-06-25-variational-inference-part-1-cavi_4844c5a5da39deeb6c5435d46616b988f174f814.png

    where 2019-06-25-variational-inference-part-1-cavi_afff5e4a964db8f1a5d7b14e41e981ed24b7779e.png

  • Variational distribution

    2019-06-25-variational-inference-part-1-cavi_5d86866ea0451581d45283689271a473af4e4c6f.png

    where

  • 2019-06-25-variational-inference-part-1-cavi_ef8353830be8e017ae0b98c5c37ad3c2b43889c2.png is a Gaussian with mean 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png and variance 2019-06-25-variational-inference-part-1-cavi_ecbdc4ce3de5acbc3faf63535cf687fc8987c25a.png for 2019-06-25-variational-inference-part-1-cavi_4a9d41329be445dfda57b0275c43ecbe7f1442f0.png
  • 2019-06-25-variational-inference-part-1-cavi_e2b6cb82e2f2dbcf424839a7643e2ec3c238e5ac.png is a Categorical distribution where 2019-06-25-variational-inference-part-1-cavi_934c24dc230da8e2fbab1fffb773717da4ece5f4.png (a K-dim vector) s.t. 2019-06-25-variational-inference-part-1-cavi_6fa5792006d0d41bcc934ffcf10d9503ca80a5f3.png
  • 2019-06-25-variational-inference-part-1-cavi_ab7d542ae6d44aebfa144db3ef0e67eb83090aeb.png denote 2019-06-25-variational-inference-part-1-cavi_9974194a836f8c06efda1c78a55ab184c6c2d746.png is a parameterized density of 2019-06-25-variational-inference-part-1-cavi_01f7b7127bdc73de1b82efda8d5a045166ffb9fa.png with param 2019-06-25-variational-inference-part-1-cavi_eb13eedc64828a894bed17ed8b010e5fea060eff.png.
  • 2019-06-25-variational-inference-part-1-cavi_b83ea156588fec75856a7276b25c5fa05da55c89.png means we assume 2019-06-25-variational-inference-part-1-cavi_eb13eedc64828a894bed17ed8b010e5fea060eff.png is also random variable itself, i.e. latent variable in Bayesian inference setting.
  • Expectation of some function 2019-06-25-variational-inference-part-1-cavi_142c5f350072a393e80de30083151077316c979f.png wrt. density 2019-06-25-variational-inference-part-1-cavi_165ca9304fe2167e21f8e1d42dec893850d712d3.png is denoted

    2019-06-25-variational-inference-part-1-cavi_0c6add97637a3b542b598aec7c87eaf88274234c.png

Definition

Let's apply VI to a simple Gaussian mixture problem, because why not? We will closely follow blei16_variat_infer with a slight variation (other than notational differences): we allow non-uniform mixture weights, i.e. being assigned to cluster 2019-06-25-variational-inference-part-1-cavi_bb118e347933fbc8dbe1012095a954c8fff97f94.png might have a different probability than being assigned to cluster 2019-06-25-variational-inference-part-1-cavi_e75af26294e87c419eaa035752a6c4d7539d4133.png.

This example lends itself nicely as a base-case for VI since we can derive closed form expressions for the updates. Though it takes a bit of work, it nicely demonstrates the underlying mechanics of VI and shows why we really, really want more general VI approaches :) Aight, let's get started.

For a 1D Gaussian mixture, one particular generative model can be

2019-06-25-variational-inference-part-1-cavi_4844c5a5da39deeb6c5435d46616b988f174f814.png

where

2019-06-25-variational-inference-part-1-cavi_3ca830dc5387a82403ef4abaeaf55c1c962694cf.png

In this case we then have

2019-06-25-variational-inference-part-1-cavi_7e7d4e5b53184b5e52646e300351e20063b63e08.png

where 2019-06-25-variational-inference-part-1-cavi_993c838215c0a1f5213700aac34101133fd2315c.png denotes summing over all possible values 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png can take on. Following blei16_variat_infer we can actually rewrite 2019-06-25-variational-inference-part-1-cavi_048f6c98e916ed8890f70a26012e2f994b6f56f5.png to

2019-06-25-variational-inference-part-1-cavi_45b64fd9bade619be60ae340ec2c2e763f14a3b8.png

Due to the conjugacy of the Gaussian prior 2019-06-25-variational-inference-part-1-cavi_98b88f9b79fe9fd933c6a80e593ec828cc00fb2c.png and the Gaussian likelihood 2019-06-25-variational-inference-part-1-cavi_900ade0ca2561c190505ff2a9939a1d004353d65.png, i.e. the fact that a Gaussian prior combined with a Gaussian likelihood gives a Gaussian posterior, we can indeed obtain a closed form expression for this. Even so, there are then 2019-06-25-variational-inference-part-1-cavi_f27f7267f731fc4d442d1f64209c81691dcb28d7.png such terms, since we're summing over all the different cluster assignments: 2019-06-25-variational-inference-part-1-cavi_6f5e48c2022723307aafa4cc44933aa17d28e91d.png and 2019-06-25-variational-inference-part-1-cavi_2109a37107ee89e94c82381eaa6d34b8a2bdc051.png give 2019-06-25-variational-inference-part-1-cavi_f27f7267f731fc4d442d1f64209c81691dcb28d7.png terms. Thus, for the sake of efficiency, we reach for VI!

Coordinate ascent (mean-field) variational inference (CAVI)

As noted earlier, to perform VI we need to specify a variational family 2019-06-25-variational-inference-part-1-cavi_27e16d785f7ae685d71efb97ad73aa13b3ea32ac.png. To make life easy for ourselves, we will make what's known as the mean-field assumption, i.e. that the latent variables are independent, so 2019-06-25-variational-inference-part-1-cavi_d4c44ca4684c8f12406efe5d89ac84a5843a3f85.png means that

2019-06-25-variational-inference-part-1-cavi_e5176051010bf998cbf9464311547dd0980dc579.png

for some independent densities 2019-06-25-variational-inference-part-1-cavi_6b31c36a145dbffb1ccbd43d16e70eafe873a24c.png. To choose these 2019-06-25-variational-inference-part-1-cavi_603b3cfceaf8a537967899b2c0ca88a0c9652c7e.png, a natural thing to do is to take the densities used in the original model and apply the mean-field assumption to this model, i.e.

2019-06-25-variational-inference-part-1-cavi_5d86866ea0451581d45283689271a473af4e4c6f.png

where

  • 2019-06-25-variational-inference-part-1-cavi_ef8353830be8e017ae0b98c5c37ad3c2b43889c2.png is a Gaussian with mean 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png and variance 2019-06-25-variational-inference-part-1-cavi_ecbdc4ce3de5acbc3faf63535cf687fc8987c25a.png for 2019-06-25-variational-inference-part-1-cavi_4a9d41329be445dfda57b0275c43ecbe7f1442f0.png
  • 2019-06-25-variational-inference-part-1-cavi_e2b6cb82e2f2dbcf424839a7643e2ec3c238e5ac.png is a Categorical distribution where 2019-06-25-variational-inference-part-1-cavi_934c24dc230da8e2fbab1fffb773717da4ece5f4.png (a K-dim vector) s.t. 2019-06-25-variational-inference-part-1-cavi_6fa5792006d0d41bcc934ffcf10d9503ca80a5f3.png

Recall we want to find 2019-06-25-variational-inference-part-1-cavi_a6f97f9c6951a838233eda678a6b02c180a83075.png s.t. we maximize the 2019-06-25-variational-inference-part-1-cavi_8aa5cd31a08eec68d0718d66ee31fcb324a207a1.png. To do this, we're going to use one of the most basic VI methods, called coordinate ascent variational inference (CAVI). This works by iteratively optimizing each of the latent variables conditioned on the rest of the latent variables, similar to the iterative process of Gibbs sampling (which, of course, doesn't involve any optimization).

Observe that we can write the ELBO under the mean-field assumption as

2019-06-25-variational-inference-part-1-cavi_bf7bc428ddfdcde6cb88ce3b7616a6236352efe9.png

where 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png denotes the parameters of 2019-06-25-variational-inference-part-1-cavi_8b34a5842f09b033248a46e3623d35bc77033a09.png. Then observe that

2019-06-25-variational-inference-part-1-cavi_81419af05e1a8d7bda6aac431856770e35d01232.png

Therefore we can also absorb the second term above into the constant wrt. 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png term:

2019-06-25-variational-inference-part-1-cavi_bd0aebe999f42e6df582524bcfe32145fc8c9cf4.png

We therefore make the justified guess

2019-06-25-variational-inference-part-1-cavi_24f6b32bcf273dea0684ee6bd8ba7d10ec82ce0f.png

(up to a constant normalization factor) which means that if we substitute this into the above expression for the ELBO, the first two terms cancel each other out, and we're left with

2019-06-25-variational-inference-part-1-cavi_4f6160b73b5d908d8438a9bce3f47671544d0c24.png

If 2019-06-25-variational-inference-part-1-cavi_603b3cfceaf8a537967899b2c0ca88a0c9652c7e.png was parameterized by some parameter 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png (and differentiable wrt. 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png), we could differentiate wrt. 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png and see that the LHS vanishes completely when 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png is set to satisfy the 2019-06-25-variational-inference-part-1-cavi_8eed54548e9ec0473c2c02bdafd10ad0f13f9cbb.png above (assuming such a 2019-06-25-variational-inference-part-1-cavi_824c17af3c3c13146e942952cef999fda6fa8621.png exists of course). At the very least, this suggests an iterative process to maximize the ELBO:

2019-06-25-variational-inference-part-1-cavi_ebd64eeb5b509617a4fe0ae37d6987195a847180.png

And indeed this converges to a local optimum blei16_variat_infer.
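For reference, the coordinate update this corresponds to, i.e. the standard CAVI update of blei16_variat_infer, can be written

q_j^{*}(z_j) \propto \exp\Big( \mathbb{E}_{q_{-j}}\big[ \log p(z_j \mid z_{-j}, x) \big] \Big) \propto \exp\Big( \mathbb{E}_{q_{-j}}\big[ \log p(z_j, z_{-j}, x) \big] \Big)

where \mathbb{E}_{q_{-j}} denotes the expectation over all factors except q_j.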

Now, we still need to compute 2019-06-25-variational-inference-part-1-cavi_b53497307f3bb346cac8c4a606be3c1d3c63c895.png. But of course, the entire reason for choosing this model and variational family is because in this case these expressions are available in closed form:

2019-06-25-variational-inference-part-1-cavi_695c8b77eb123496a0e1afbced4b86673691cc9e.png

where we have used the fact that if 2019-06-25-variational-inference-part-1-cavi_6e1d3ff5d22a90c650dd418e4eb7bc0f612ce851.png then

2019-06-25-variational-inference-part-1-cavi_96ab2ee4ef31551b81c05bcb25d422817df9e708.png

The derivations of these can be found in the appendix (or without 2019-06-25-variational-inference-part-1-cavi_9f656ddb23a905a01ee310dcfd4c1da139a1024a.png in blei16_variat_infer).
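Concretely, the updates we will implement below are (transcribed to match the implementation further down, with \varphi_{ik} the assignment probability of x_i to cluster k, m_k and s_k^2 the variational mean and variance of the k-th cluster mean, \pi_k the mixture weight, and \sigma^2 the prior variance):

\varphi_{ik} \propto \pi_k \exp\Big( m_k x_i - \tfrac{1}{2} \big( s_k^2 + m_k^2 \big) \Big), \qquad m_k = \frac{\sum_{i=1}^{n} \varphi_{ik} x_i}{1 / \sigma^2 + \sum_{i=1}^{n} \varphi_{ik}}, \qquad s_k^2 = \frac{1}{1 / \sigma^2 + \sum_{i=1}^{n} \varphi_{ik}}

with each row \varphi_{i1}, \ldots, \varphi_{iK} normalized to sum to one.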

This example is a special case of what's known as conditionally conjugate models with "local" and "global" variables.

  • A variable is local if only a single data point depends on this variable
  • A variable is global if multiple data points depend on this variable

We use this terminology of "global" and "local" for both the latent variables and the variational variables. E.g.

  • Latent:
    • Local: cluster assignments 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png
    • Global: cluster means 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png
  • Variational:
    • Local: parameters of cluster assignments 2019-06-25-variational-inference-part-1-cavi_ef555ff2069ebd01f48d11b2c36e6732ffde9266.png
    • Global: means 2019-06-25-variational-inference-part-1-cavi_1f16af3cf4c8f31242256354a1103d47a6c017f6.png and variances 2019-06-25-variational-inference-part-1-cavi_ecbdc4ce3de5acbc3faf63535cf687fc8987c25a.png for the different clusters

In a general conditionally conjugate model, where 2019-06-25-variational-inference-part-1-cavi_31aa0b82f4e4fe4be2c026d5eddd655a517a7f53.png denotes the global latent variables and 2019-06-25-variational-inference-part-1-cavi_01f7b7127bdc73de1b82efda8d5a045166ffb9fa.png the local latent variables, the joint density is given by

2019-06-25-variational-inference-part-1-cavi_e6402ae35f5cf81b992d585992458d7db046f737.png

In the case where these factors are also in the exponential family, closed form updates for CAVI can be derived in a similar way as we did for the Gaussian mixture model above.blei16_variat_infer
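In symbols (using \beta for the global and z_i for the local latent variables, as in blei16_variat_infer), this factorization reads

p(\beta, z, x) = p(\beta) \prod_{i=1}^{n} p(z_i, x_i \mid \beta)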

Implementation

Now, before we start, remember that we're trying to get a more efficient way of sampling from our model, right? We're not trying to infer the parameters that generated the data; instead we're working under the assumption that the data came from the given model with known parameters (which in this case is of course true), and we want to produce approximate samples from this model. With that in mind, let's begin!

First we import some necessary packages and set the random seed for reproducibility.

using Random
Random.seed!(1)

using Plots, StatsPlots, LaTeXStrings

The Turing.Model can then be defined

σ² = 1.0

K = 3               # number of clusters
n = 1000            # number of samples

@model mixture(x =  Vector{Float64}(undef, 1), μ = Vector{Float64}(undef, K), π = (1 / K) .* ones(K)) = begin
    for i = 1:K
        μ[i] ~ Normal(0.0, σ²)
    end

    c = tzeros(Int, length(x))
    for i ∈ eachindex(x)
        c[i] ~ Categorical(π)

        x[i] ~ Normal(μ[Int(c[i])], 1)
    end
end
mixture (generic function with 4 methods)

This might look a bit complicated, but it really isn't.

  • K is the number of clusters
  • x = Vector{Float64}(undef, 1) means that if we don't specify any data x when instantiating the Model, x will be treated as a random variable to be sampled rather than as an observation. As we will see in a bit, this allows us to generate samples for x from the Model!
  • μ = Vector{Float64}(undef, K) will hold the means of the different clusters; again, the undef will simply allow us to choose between fixing or sampling μ
  • π will denote the weights of the clusters, i.e. π[k] is the probability of assignment to cluster k. The default value (1 / K) .* ones(K) corresponds to uniform cluster weights.
  • tzeros is simply an initialization procedure for latent random variables which is compatible with Turing's particle samplers (we're going to use Particle Gibbs, PG). See Turing's documentation for more info.

Let's produce some samples from the model with 2019-06-25-variational-inference-part-1-cavi_5851361b9629632a09b97aaaa7decb2847d7c755.png, 2019-06-25-variational-inference-part-1-cavi_f873b75d8f6bca5dc5ace58f2488c184faad6ca8.png and 2019-06-25-variational-inference-part-1-cavi_e3c947b61ef3f5a7cfe931a6d4470b59592b45c3.png together with the cluster weights 2019-06-25-variational-inference-part-1-cavi_eff4425bef531ee55a8539cf115bc40eeacbd90d.png, 2019-06-25-variational-inference-part-1-cavi_0211c12c3ae50453b85e8cf9341d461cdef27a2c.png and 2019-06-25-variational-inference-part-1-cavi_63f226b2259afa8f4637876a31565e54762ff838.png, e.g. there's a probability 2019-06-25-variational-inference-part-1-cavi_eff4425bef531ee55a8539cf115bc40eeacbd90d.png of being assigned cluster 2019-06-25-variational-inference-part-1-cavi_6b314f00387982277c754a67e9002cc0f3dd7144.png which is normally distributed with mean 2019-06-25-variational-inference-part-1-cavi_5851361b9629632a09b97aaaa7decb2847d7c755.png.

# sample from the model
π = [0.1, 0.2, 0.7] # cluster weights
mix = mixture(μ = [-5.0, 0.0, 5.0], π = π)
samples = sample(mix, PG(200, n));
[PG] Sampling...100% Time: 0:00:29
┌ Info: [PG] Finished with
└ @ Turing.Inference /home/tor/Projects/mine/Turing.jl/src/inference/AdvancedSMC.jl:195
┌ Info:   Running time    = 28.558703612000002;
└ @ Turing.Inference /home/tor/Projects/mine/Turing.jl/src/inference/AdvancedSMC.jl:196

plot(samples[:x], size = (1000, 400))

(Figure: plot of the sampled observations x.)

Looks nice! Let's extract the cluster assignments and the observations:

cluster_assignments = vec([v for v ∈ samples[:c].value])
x = vec([e for e ∈ samples[:x].value])
1000-element Array{Float64,1}:
  6.147820300612501  
  0.4663086943131637 
 -0.6258083879904177 
  5.724759360321863  
  4.852150430501788  
  8.069098838605914  
  6.02043226784142   
  5.25245793098362   
  5.236303942371799  
  4.918764327220161  
 -0.40921684006032205
  3.727671136403713  
  5.121310585284786  
  ⋮                  
  6.344736353801734  
  0.4270545469128376 
  4.834727745113844  
 -4.668084240525111  
  4.090751518471361  
  4.557843176639102  
  3.9554058940625487 
  0.4796996702606038 
 -4.078674006599593  
 -5.560181598563516  
  5.773778693804238  
 -4.9926860296827575 

Now we implement the VariationalInference interface for Turing. First we define the VariationalInference type CAVI:

struct CAVI
    max_iter # maximum number of optimization steps
end

And how the ELBO is evaluated for this particular VI algorithm and this MeanField model:

function (elbo::Variational.ELBO)(alg::CAVI, q::MeanField, model, num_samples)
    res = 0.0

    for i = 1:num_samples
        z = rand(q)
        res += logdensity(model, z) / num_samples
    end

    return res + entropy(q)  # ELBO = E_q[log p(x, z)] + H(q)
end

Observe that here we are using an empirical estimate of the expectation 2019-06-25-variational-inference-part-1-cavi_45292e533573f568f9b9f564017e867806779e4b.png while we can compute 2019-06-25-variational-inference-part-1-cavi_92a4298888bbbd3de80d00bdcea8c259911c9a4f.png in closed form. entropy is provided by StatsBase for both Categorical and Normal, so we don't even have to implement those. Under a mean-field assumption, the entropy of q is simply the sum of the entropies of the independent variables.
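A quick sanity check of that last point, using nothing but two independent factors from Distributions.jl:

using Distributions

d_mean = Normal(0.0, 2.0)
d_assignment = Categorical([0.2, 0.3, 0.5])

# under the mean-field factorization q(μ, c) = q(μ) q(c), the entropies add up
entropy(d_mean) + entropy(d_assignment)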

Both the logdensity and MeanField definitions you can find in the appendix.

As mentioned before, we now need to at the very least define optimize! and vi. This is kind of a contrived example since the update derived before is only applicable to a model of this particular form. Because of this, we're going to instantiate the mixture with the generated data and then dispatch on typeof(model).

model = mixture(x = x, π = π)
Turing.Model{Tuple{:c,:μ},Tuple{:x},getfield(Main, Symbol("###inner_function#375#11")){Array{Float64,1}},NamedTuple{(:x,),Tuple{Array{Float64,1}}},NamedTuple{(:μ, :x),Tuple{Array{Float64,1},Array{Float64,1}}}}(getfield(Main, Symbol("###inner_function#375#11")){Array{Float64,1}}(Core.Box(getfield(Main, Symbol("###inner_function#375#11")){Array{Float64,1}}(#= circular reference @-2 =#)), [0.1, 0.2, 0.7]), (x = [6.14782, 0.466309, -0.625808, 5.72476, 4.85215, 8.0691, 6.02043, 5.25246, 5.2363, 4.91876  …  4.83473, -4.66808, 4.09075, 4.55784, 3.95541, 0.4797, -4.07867, -5.56018, 5.77378, -4.99269],), (μ = [6.93058e-310, 6.93058e-310, 6.93058e-310], x = [6.93058e-310]))

function optimize!(alg::CAVI, elbo::Variational.ELBO, q, model::typeof(model), m, s², φ)
    # ideally we'd pack and unpack the params `m`, `s²` and `φ`
    # but for sake of exposition we skip that bit here
    K = length(m)
    n = size(φ, 1)

    for step_idx = 1:alg.max_iter
        # update cluster assignments
        for i = 1:n
            for k = 1:K
                # If the cluster weights were uniform, we could do without accessing π:
                # φ[i, k] = exp(m[k] * x[i] - 0.5 * (s²[k] + m[k]^2))

                # HACK: π is global variable :/ 
                φ[i, k] = π[k] * exp(m[k] * x[i] - 0.5 * (s²[k] + m[k]^2))
            end

            φ[i, :] ./= sum(φ[i, :])  # normalize
        end

        # HACK: σ² here is a global variable :/

        # update Gaussians
        for k = 1:K
            φ_k = φ[:, k]
            denominator = (1 / σ²) + sum(φ_k)
            m[k] = sum(φ_k .* x) / denominator
            s²[k] = 1 / denominator  # closed-form update of the variational variance
        end
    end

    m, s², φ
end
optimize! (generic function with 1 method)

function Variational.vi(model::typeof(model), alg::CAVI, q::MeanField)
    discrete_idx = findfirst(d -> d isa DiscreteNonParametric, q.dists)
    K = length(q.dists[discrete_idx].p)           # extract the number of clusters

    n = sum(isa.(q.dists, DiscreteNonParametric)) # each data-point has a cluster assignment

    # initialization of variables
    φ = rand(n, K)
    φ ./= sum(φ; dims = 2)
    m = randn(K)
    s² = ones(K)

    # objective
    elbo = Variational.ELBO()

    # optimize
    optimize!(alg, elbo, q, model, m, s², φ)

    # create new distribution
    q_new = MeanField(copy(q.dists), q.ranges)

    for k = 1:K
        q_new.dists[k] = Normal(m[k], sqrt(s²[k]))
    end

    for i = 1:n  # one Categorical factor per observation
        q_new.dists[i + K] = Categorical(φ[i, :])
    end

    q_new
end

Now we construct our mean-field approximation from model.

# extract priors into mean-field
var_info = Turing.VarInfo();
model(var_info, Turing.SampleFromUniform())

q = MeanField(deepcopy(var_info.dists), var_info.ranges)
MeanField{Array{Distribution,1}}(
dists: Distribution[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7])  …  DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7])]
ranges: UnitRange{Int64}[1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10  …  994:994, 995:995, 996:996, 997:997, 998:998, 999:999, 1000:1000, 1001:1001, 1002:1002, 1003:1003]
)

Let's check that it indeed produces the expected samples:

rand(q)
1003-element Array{Float64,1}:
 -0.7810409728512453
 -0.7015975666891903
 -1.2133827111882445
  3.0               
  3.0               
  3.0               
  3.0               
  2.0               
  3.0               
  1.0               
  3.0               
  3.0               
  3.0               
  ⋮                 
  3.0               
  3.0               
  3.0               
  3.0               
  2.0               
  3.0               
  1.0               
  2.0               
  3.0               
  2.0               
  1.0               
  2.0               

The only issue with this MeanField implementation is that to satisfy the <: Distribution{Multivariate, Continuous} type constraint, we need to return AbstractArray{<: Real}. This means that the cluster assignments, which ideally would be integers, are now returned as floats.

This is the reason why, in the definition of mixture, we added an Int(c[i]). It's kind of ugly, but, yah know, it works.
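To see why the conversion is needed, here is a tiny standalone illustration:

μs = [-5.0, 0.0, 5.0]
c_float = 3.0      # what `rand(q)` hands back for a cluster assignment

μs[Int(c_float)]   # works
# μs[c_float]      # would throw an ArgumentError: invalid index of type Float64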

Finally we can perform CAVI on this mixture model:

# perform VI
cavi = CAVI(10)
q_new = vi(model, cavi, q)
MeanField{Array{Distribution,1}}(
dists: Distribution[Normal{Float64}(μ=-0.0720897, σ=1.42392e-11), Normal{Float64}(μ=-4.91145, σ=5.46488e-11), Normal{Float64}(μ=4.92814, σ=4.9844e-15), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.1981e-9, 1.67617e-27, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999615, 1.21824e-6, 0.000384242]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999758, 0.000240409, 1.63404e-6]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[9.93394e-9, 1.07655e-25, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[7.7963e-7, 5.76301e-22, 0.999999]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[8.06611e-14, 1.03465e-35, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[2.26518e-9, 5.87004e-27, 1.0])  …  DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[8.50589e-7, 6.84062e-22, 0.999999]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.32992e-5, 0.999987, 3.62968e-20]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[3.50883e-5, 1.03289e-18, 0.999965]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[3.39579e-6, 1.04282e-20, 0.999997]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[6.90294e-5, 3.91168e-18, 0.999931]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999588, 1.14177e-6, 0.000410837]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.000230365, 0.99977, 1.19755e-17]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.77431e-7, 1.0, 5.59742e-24]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[7.77467e-9, 6.64628e-26, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[2.76477e-6, 0.999997, 1.4889e-21])]
ranges: UnitRange{Int64}[1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10  …  994:994, 995:995, 996:996, 997:997, 998:998, 999:999, 1000:1000, 1001:1001, 1002:1002, 1003:1003]
)

Let's compare the ELBO of q_new to q:

elbo = Variational.ELBO()
elbo(cavi, q, model, 100), elbo(cavi, q_new, model, 100)
(-13270.643508809633, -2240.377147461269)

Hey, it improved! Neat. Let's inspect the obtained parameters.

q_new.dists[1:K]  # <= distributions for the cluster means
3-element Array{Distribution,1}:
 Normal{Float64}(μ=-0.07208966158030183, σ=1.4239202985246593e-11)
 Normal{Float64}(μ=-4.911453128528145, σ=5.464878645528e-11)      
 Normal{Float64}(μ=4.928144006028532, σ=4.984398204586815e-15)    

As we can see, the means are fairly close to the true values and the variances are small. This is what we expected (or rather hoped for, but since it worked out, we just say it was expected, right guys?). In the data generation process we had the means fixed, i.e. "variance was zero", thus the variance of the estimates should also be small.

This further displays how this is a "contrived" example in Bayesian inference. Since we fixed the means 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png during the generation process, and the likelihood 2019-06-25-variational-inference-part-1-cavi_d323085fdf061c1f37aaa855af8a562d06f0bfa7.png had unit-variance, the posterior distributions we get should (and do) converge to, basically, point-estimates. As a result, sampling from the resulting variational posterior q_new will almost always give the same result.

This kind of "loses" the purpose of doing Bayesian inference in the first place since the "true" distribution of the latent variables is a discrete distribution with a single outcome, i.e. fixed distribution. At the same time, most interesting examples will probably not be in the conditionally conjugate family, thus CAVI won't be applicable. As mentioned before, there are more general approaches to come!

We can also inspect the learned cluster weights:

vec(mean(transpose(hcat([d.p for d ∈ q_new.dists[K + 1:end]]...)); dims = 1))
3-element Array{Float64,1}:
 0.19451210528824994
 0.08900147050764164
 0.7164864242041084 

Comparing to the cluster assignments in the dataset:

[sum(cluster_assignments .== v) / length(cluster_assignments) for v ∈ sort(unique(cluster_assignments))]
3-element Array{Float64,1}:
 0.089
 0.199
 0.712

Notice that the cluster weights indeed correspond to the correct means, e.g. the cluster with 2019-06-25-variational-inference-part-1-cavi_e3c947b61ef3f5a7cfe931a6d4470b59592b45c3.png indeed has weight 2019-06-25-variational-inference-part-1-cavi_63f226b2259afa8f4637876a31565e54762ff838.png. Neat stuff, ain't it?

Observe that the order of the cluster means is different from what we used in the data generation; we used the ordering [-5.0, 0.0, 5.0]. This is because the clusters are exchangeable, i.e. any permutation of the cluster-indices will result in the same approximate posterior (modulo stochasticity). Which order we get these in simply comes down to the initialization values: if 2019-06-25-variational-inference-part-1-cavi_e7eade0cd650c5ec6ca3682d82834b8a9dccae26.png is closer to 2019-06-25-variational-inference-part-1-cavi_daf30f9f2c4ddf9bc5bec19f4acf22e064cc6284.png than 2019-06-25-variational-inference-part-1-cavi_69db63664777b86be30ce960ca277b9c54d37d3a.png after initialization, then it's very likely that the resulting variational posterior will have an ordering with index 2 corresponding to the "true" index 1, if yah feel me.

Let's finish off with a neat GIF of the cluster assignments. In creating this GIF, we'll also see how the user can write VI for Turing as a "training" loop, more similar to other machine learning techniques. The general idea is to

  1. Initialize the parameters 2019-06-25-variational-inference-part-1-cavi_a14d8487fae6909c522cdd676b01387e7c5d213d.png, 2019-06-25-variational-inference-part-1-cavi_8f44de754519a6b4b737f2e24c2083a1a7cb4a03.png, and 2019-06-25-variational-inference-part-1-cavi_b614273a2c6406d98daf1108e22aa446959c69eb.png, together with the initial mean-field distribution q.
  2. Perform an optimization step using optimize!
  3. Plot the inferred cluster assignments and ELBO.

# Let's do one where we plot the progress
q = MeanField(deepcopy(var_info.dists), var_info.ranges)

cavi = CAVI(1)  # perform 1 step for each `optimize!` call

φ = rand(n, K)  # cluster assignments
φ ./= sum(φ; dims = 2)
m = randn(K)
s² = ones(K)

# update distributions
for k = 1:K
    q.dists[k] = Normal(m[k], sqrt(s²[k]))
end

for i ∈ eachindex(x)
    q.dists[i + K] = Categorical(φ[i, :])
end

objectives = []  # history of objective evaluations

anim = @animate for step_idx = 1:20
    # compute empirical estimate of ELBO
    step_elbo = elbo(cavi, q, model, 100)
    push!(objectives, step_elbo)

    @info "[$step_idx]" step_elbo

    # plotting of the cluster assignments and the ELBO history
    # sort so that the cluster in the middle is coloured red for visibility
    mean_idx_sort = sortperm(m)
    color_map = Dict(zip(mean_idx_sort, [colorant"#41b8f4", colorant"#f44155", colorant"#13c142"]))
    colors = map(c -> color_map[c], Int.(rand(q)[K + 1:end]))

    plt1 = scatter(x, color = colors, label = "")
    title!("Step $step_idx with ELBO $step_elbo")
    xlabel!("Index of data point")

    plt2 = plot(max(1, step_idx - 10):step_idx, objectives[max(1, step_idx - 10):step_idx], label = "ELBO")
    xlabel!("Step index")

    # run CAVI for 1 iteration to update parameters m, s², φ
    optimize!(cavi, elbo, q, model, m, s², φ)

    # update distributions so we can evaluate the ELBO with updated parameters
    for k = 1:K
        q.dists[k] = Normal(m[k], sqrt(s²[k]))
    end

    for i ∈ eachindex(x)
        q.dists[i + K] = Categorical(φ[i, :])
    end

    plot(plt1, plt2, layout = grid(2, 1, heights = [0.7, 0.3]), size = (1000, 500))
end
gif(anim, ".2019-06-25-variational-inference-part-1-cavi/figures/gaussian_mixture_cluster_assignments.gif", fps = 5);
┌ Info: [1]
│   step_elbo = -14505.608414741417
└ @ Main In[35]:28
┌ Info: [2]
│   step_elbo = -4295.3039673855865
└ @ Main In[35]:28
┌ Info: [3]
│   step_elbo = -3322.581925954455
└ @ Main In[35]:28
┌ Info: [4]
│   step_elbo = -3302.4541140390593
└ @ Main In[35]:28
┌ Info: [5]
│   step_elbo = -3300.152447964722
└ @ Main In[35]:28
┌ Info: [6]
│   step_elbo = -3295.6891878053775
└ @ Main In[35]:28
┌ Info: [7]
│   step_elbo = -3289.9682776026852
└ @ Main In[35]:28
┌ Info: [8]
│   step_elbo = -3284.6262893228004
└ @ Main In[35]:28
┌ Info: [9]
│   step_elbo = -3273.6817405178094
└ @ Main In[35]:28
┌ Info: [10]
│   step_elbo = -3266.9336190510426
└ @ Main In[35]:28
┌ Info: [11]
│   step_elbo = -3259.8310113380035
└ @ Main In[35]:28
┌ Info: [12]
│   step_elbo = -3248.8725276001446
└ @ Main In[35]:28
┌ Info: [13]
│   step_elbo = -3243.003893619992
└ @ Main In[35]:28
┌ Info: [14]
│   step_elbo = -3235.485410286028
└ @ Main In[35]:28
┌ Info: [15]
│   step_elbo = -3226.4825612330606
└ @ Main In[35]:28
┌ Info: [16]
│   step_elbo = -3218.4085612846575
└ @ Main In[35]:28
┌ Info: [17]
│   step_elbo = -3209.639842972914
└ @ Main In[35]:28
┌ Info: [18]
│   step_elbo = -3202.432054017625
└ @ Main In[35]:28
┌ Info: [19]
│   step_elbo = -3193.836909448166
└ @ Main In[35]:28
┌ Info: [20]
│   step_elbo = -3184.2329537025767
└ @ Main In[35]:28
┌ Info: Saved animation to 
│   fn = /home/tor/org-blog/posts/.2019-06-25-variational-inference-part-1-cavi/figures/gaussian_mixture_cluster_assignments.gif
└ @ Plots /home/tor/.julia/packages/Plots/oiirH/src/animation.jl:90

gaussian_mixture_cluster_assignments.gif

In the future

Though we have indeed derived the algorithm for a particular example, and the current implementation is specialized to an instance of a particular mixture model, one could imagine generalizing this approach. We noted earlier that this is an instance of a more general family of models on which we can perform CAVI with closed form updates blei16_variat_infer. One could imagine automatically detecting whether a Model consists of exponential-family distributions and whether the entire model is conditionally conjugate. Then we could simply construct a suitable mean-field approximation and run CAVI with these more general updates!

That would be neat.

Final remarks

Aight, so this was interesting and all, but we had to do a lot of manual labour to obtain the variational updates. In an ideal world you have defined some Model and you just want to call vi(model, alg), maybe even make the choice of the variational family yourself by calling vi(model, alg, q), and then Turing does the rest. Not having to derive closed form updates, not having to worry about whether the densities in your model are in the exponential family, etc. In the next post we will have a look at a variational inference algorithm which gets us further in that direction: automatic differentiation variational inference (ADVI).

The animation below is obtained simply by calling vi(model, ADVI(10, 1000)) on a simple generative model:

advi_w_elbo_fps15_1_forward_diff.gif

It's going to be neat.

Bibliography

  • [blei16_variat_infer] Blei, Kucukelbir & McAuliffe, Variational Inference: A Review for Statisticians, CoRR, (2016). link.
  • [pmlr-v80-chen18k] Chen, Tao, Zhang, Henao & Duke, Variational Inference and Model Selection with Generalized Evidence Bounds, 893-902, in: Proceedings of the 35th International Conference on Machine Learning, edited by Dy & Krause, PMLR (2018)
  • [liu16_stein_variat_gradien_descen] Liu & Wang, Stein Variational Gradient Descent: a General Purpose Bayesian Inference Algorithm, CoRR, (2016). link.

Appendix

Code

# Mean-field for arbitrary distributions
import Distributions: _rand!, _logpdf

struct MeanField{TDists <: AbstractVector{<: Distribution}} <: Distribution{Multivariate, Continuous}
    dists::TDists
    ranges::Vector{UnitRange{Int}}
end

MeanField(dists) = begin
    ranges::Vector{UnitRange{Int}} = []
    idx = 1
    for d ∈ dists
        push!(ranges, idx:idx + length(d) - 1)
        idx += length(d)
    end

    MeanField(dists, ranges)
end

# Base.length(mf::MeanField) = sum(length(d) for d ∈ mf.dists)
Base.length(mf::MeanField) = mf.ranges[end][end]
_rand!(rng::AbstractRNG, mf::MeanField, x::AbstractVector{T} where T <: Real) = begin
    for i ∈ eachindex(mf.ranges)
        d = mf.dists[i]
        r = mf.ranges[i]

        x[r] = rand(rng, d, 1)
    end

    x
end

_logpdf(mf::MeanField, x::AbstractVector{T} where T <: Real) = begin
    tot = 0.0
    for i ∈ eachindex(mf.ranges)
        r = mf.ranges[i]
        if length(r) == 1
            tot += logpdf(mf.dists[i], x[r[1]])
        else
            tot += logpdf(mf.dists[i], x[r])
        end
    end

    return tot
end

import StatsBase: params, entropy
entropy(mf::MeanField) = begin
    sum(entropy.(mf.dists))
end
logdensity(model::Turing.Model, z) = begin
    # kind of hacky; improved impl coming to Turing soon:
    # https://github.com/TuringLang/Turing.jl/issues/817

    # initialize VarInfo
    var_info = Turing.VarInfo()
    model(var_info)

    var_info.vals .= z
    model(var_info)
    var_info.logp
end

Derivation of variational updates

Recall the equation we found ealier for updates when the model is conditionally conjugate:

2019-06-25-variational-inference-part-1-cavi_24f6b32bcf273dea0684ee6bd8ba7d10ec82ce0f.png

Now we consider the different mean-field factors.

Mixture assignments

In this case we simply have

2019-06-25-variational-inference-part-1-cavi_aea433c58a949b9f7278af5c08105ebffc69b594.png

since the cluster assignment of 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png for 2019-06-25-variational-inference-part-1-cavi_c1a095bc430b73abaf63eae78d315295aaa78937.png is (conditionally) independent of the cluster assignments of the rest of the observations, denoted 2019-06-25-variational-inference-part-1-cavi_f49096d4f4e40dfc7417661094e875f0ea9957f3.png. The first term is simply

2019-06-25-variational-inference-part-1-cavi_f04cb0c776735c2a51ad81067a4ae2070de985a8.png

where 2019-06-25-variational-inference-part-1-cavi_9f656ddb23a905a01ee310dcfd4c1da139a1024a.png is the cluster-weight of the i-th cluster, or equivalently, the probability of assignment to the i-th cluster. Thus the interesting term is the expectation of the log-likelihood. Though in our implementation of this Gaussian mixture, 2019-06-25-variational-inference-part-1-cavi_6f5e48c2022723307aafa4cc44933aa17d28e91d.png, we can equivalently consider 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png as an indicator vector, with the k-th component given by

2019-06-25-variational-inference-part-1-cavi_4910e129d0f8550f7d7ffefccdc389f91cfe8504.png

where

2019-06-25-variational-inference-part-1-cavi_967070508c46abb16a8b314d467bed35d0a64d56.png

Then we can write the log-likelihood

2019-06-25-variational-inference-part-1-cavi_1e9cbb88d58ccdd478d3ba895fae371cda9e9bb1.png

Then

2019-06-25-variational-inference-part-1-cavi_2a1cfb313e75ffd43eef967d9cec916f8d743b23.png

where we have used the fact that the

2019-06-25-variational-inference-part-1-cavi_c6acb741039b9a5f97e1e62408ff2ce3585d0297.png

to "move" the constant expression out of the sum, and, recalling that 2019-06-25-variational-inference-part-1-cavi_a500ea3ff82c9cea65d5902ccb313b429cc78963.png is simply a unit variance Gaussian with mean 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png,

2019-06-25-variational-inference-part-1-cavi_89f4dbb6350a0389b4a505dd2ff5f62c4ca13b6d.png

Moreover, observe that

2019-06-25-variational-inference-part-1-cavi_7c3e8fbe8ac5b703b84e0272d633ba3020290b5a.png

and that

2019-06-25-variational-inference-part-1-cavi_a297d70894d2450748d446885a4fe66bfaf56d74.png

Substituting back into the above expression

2019-06-25-variational-inference-part-1-cavi_efc9ab5eae1333a1d46763f252a2563975821a3b.png

where we have absorbed the 2019-06-25-variational-inference-part-1-cavi_9ec8e7c68bf7b7b4b982acf6542c353e63f55062.png into the 2019-06-25-variational-inference-part-1-cavi_130bc2b12c9dd1fe8791a2571f820aaf6b06845c.png term. This is indeed equivalent to the expression derived in blei16_variat_infer, as can be seen by simply observing that

2019-06-25-variational-inference-part-1-cavi_68a6cb3b69a55c80749fe3f889ed44c66cf65e2d.png

This finally gives us the full update for 2019-06-25-variational-inference-part-1-cavi_e2b6cb82e2f2dbcf424839a7643e2ec3c238e5ac.png:

2019-06-25-variational-inference-part-1-cavi_a835bc764dde3e33ec6534ca368c0531b12841ed.png

Means of mixture components

For the means 2019-06-25-variational-inference-part-1-cavi_daf4c98ec256e58b010370c12204e88945e82b90.png we have

2019-06-25-variational-inference-part-1-cavi_5811eb78fc952a0ef474c2c487c9204537ca3b88.png

Observe then that we can complete the square

2019-06-25-variational-inference-part-1-cavi_57cd48040bf6a3beb86c0a8187ddb1a0eecd5a24.png

Exponentiating this and absorbing the term in red into a multiplicative constant (wrt. 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png) 2019-06-25-variational-inference-part-1-cavi_a4680369a3334fa4e273bbd161e5ca60a5d7c8a5.png leaves us with

2019-06-25-variational-inference-part-1-cavi_8191c94d1e7af02a46dac3343e435c7270757f36.png

where

2019-06-25-variational-inference-part-1-cavi_9b5be977aa0cb25d3f72f975c76c470576fd28ce.png

Congratulations! It's a bo…Gaussian!

Actually, in blei16_variat_infer they insist on using 2019-06-25-variational-inference-part-1-cavi_983afa4cb71952f7c1e1111b04fbf529d4677059.png instead of just writing 2019-06-25-variational-inference-part-1-cavi_1f16af3cf4c8f31242256354a1103d47a6c017f6.png, etc. At first I didn't understand why, but now I realize: the above is basically deriving that the conditionally conjugate density is a Gaussian, rather than assuming so! Neato.

Footnotes:

1

It's at least famous in certain circles. Maybe few circles, but still, they exist… There are dozens of us! Dozens!

2

As we will see in later posts in this series, this is not always the case. E.g. the class of variational inference methods called Amortized VI uses a different latent variable 2019-06-25-variational-inference-part-1-cavi_f4d02d8b757b78d2991836b88709a29d2d6c38f9.png for each observation 2019-06-25-variational-inference-part-1-cavi_c1a095bc430b73abaf63eae78d315295aaa78937.png.

3

Usually the variational family 2019-06-25-variational-inference-part-1-cavi_dc8fda671dad6051d988d7dd31211745f3f81823.png is chosen s.t. 2019-06-25-variational-inference-part-1-cavi_34724c66b3e97a635a63f6916f5be50dd45cf282.png is cheap to sample from.

4

You noticed the v3? What happened to v1 and v2? Don't worry about it, they're fine.