Variational Inference (Part 1): Variational Inference & CAVI

In this post I will present to you what's known as variational inference (VI), a family of approximate Bayesian inference methods. In particular, we will focus on one of the more basic VI methods called coordinate ascent variational inference (CAVI). We will go through the theory for VI, and apply CAVI to the case of a Gaussian mixture, following and adding to what's presented in blei16_variat_infer. With the theory clear we will go through an implementation in Turing.jl as a simple (and slightly contrived) example of how to implement "custom" variational inference methods. Hopefully you'll learn something. The goal is that you walk away with (in descending order of importance):

  1. A basic understanding of VI and why we bother
  2. A general idea of how to use the (work in progress) VI interface in Turing.jl
  3. A vague familiarity with CAVI

Just a heads up: in contrast to the previous post, this one is less focused on memes, involving a fair bit of mathematics and code. At the ripe age of 24, there's certainly one thing I've learned: memes can only get you so far in life.


  • VI abbreviates variational inference
  • CAVI abbreviates coordinate ascent VI
  • n denotes the number of data points
  • x = {x_1, …, x_n} denotes the set of data
  • Evidence lower bound (ELBO):

        ELBO(q) = E_q[log p(x, z)] − E_q[log q(z)]


In Bayesian inference one usually specifies a model as follows: given data x = {x_1, …, x_n},

    z ~ p(z)
    x_i | z ~ p(x | z),   i = 1, …, n,   i.i.d.

where i.i.d. denotes that the samples are independently and identically distributed. Our goal in Bayesian inference is then to find the posterior

    p(z | x) = p(x | z) p(z) / p(x)

In general one cannot obtain a closed form expression for p(z | x), but one might still be able to sample from p(z | x) with guarantees of converging to the target posterior as the number of samples goes to ∞, e.g. using MCMC.

"So what's the problem? We have these nice methods which we know converge to the true posterior in the infinite limit. What more do we really need?!"

Weeell, things aren't always that easy.

  1. As the model p(x, z) becomes increasingly complex, e.g. more variables or multi-modality (i.e. separate "peaks"), convergence of these unbiased samplers can slow down dramatically. Still, in the infinite limit these methods should converge to the true posterior, but infinity is fairly large, like, at least more than 12!
  2. Large amounts of data means more evaluations of the log-likelihood log p(x | z) which, if it's sufficiently expensive to compute, can make direct sampling infeasible.

In fact, there is also an application of VI in the context of Bayesian model selection, but we will not dig into that here (pmlr-v80-chen18k).

Therefore it might at times be a good idea to use an approximate posterior, which we'll denote q(z). Clearly we'd like it to be close to the target p(z | x), but what does "close" even mean in the context of probability densities?

This question really deserves much more than a paragraph and is indeed quite an interesting topic. There are plenty of approaches one can take here: some come down to putting actual metrics (defining a "distance") on spaces of probability densities, e.g. the Wasserstein / earth-mover's distance, while other approaches are less strict, e.g. the Kullback-Leibler (KL) divergence which, in contrast to a metric, is not symmetric in its arguments.


For now we will consider the KL-divergence, but as we will probably see in later posts, there are several viable alternatives which can also be employed in VI. The KL-divergence also has a nice connection to maximum likelihood: minimizing the KL-divergence from the data distribution to the model is equivalent to maximizing the likelihood.

With KL-divergence as this measure of how "good" an approximation our approximate posterior q(z) is to p(z | x), the objective is to find the optimal q*(z), where 𝒟 is the space of densities we have chosen to consider:

    q*(z) = argmin_{q ∈ 𝒟} KL( q(z) || p(z | x) )

where the KL-divergence is defined

    KL( q(z) || p(z | x) ) = E_q[ log ( q(z) / p(z | x) ) ]

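To make these properties concrete, here is a small numeric check of the discrete KL-divergence; a minimal sketch in Python (the code in this post is otherwise Julia, but the arithmetic here is language-agnostic), just to illustrate non-negativity and the asymmetry mentioned above:

```python
import math

def kl_divergence(q, p):
    """KL(q || p) for discrete distributions given as probability lists."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

q = [0.5, 0.5]
p = [0.9, 0.1]

print(kl_divergence(q, q))  # 0.0: a distribution has divergence zero from itself
print(kl_divergence(q, p))  # positive
print(kl_divergence(p, q))  # also positive, but different: KL is not symmetric
```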
So the idea is to construct some approximate posterior q(z) which is feasible to sample from, allowing us to do approximate Bayesian inference. But there are a number of practical issues with the above objective:

  1. We can't evaluate p(z | x)! Evaluating p(z | x) is the entire goal of Bayesian inference; if we could simply evaluate this, we would be done.
  2. Integrating over all z is of course in general intractable.

In the case of KL-divergence, the standard approach to address (1) is to instead optimize what's known as the evidence lower bound (ELBO).

The famous ELBO

Using the fact that log p(z | x) = log p(x, z) − log p(x), the RHS of the KL-divergence becomes

    KL( q(z) || p(z | x) ) = E_q[ log q(z) ] − E_q[ log p(x, z) ] + log p(x)

where we've simply brought the log p(x) term outside of the expectation since p(x) is independent of z. We can therefore re-arrange the original expression to

    log p(x) = KL( q(z) || p(z | x) ) + ( E_q[ log p(x, z) ] − E_q[ log q(z) ] )

This then implies that if we maximize the second term on the RHS we're implicitly minimizing the KL-divergence between p(z | x) and q(z)! This term is called the evidence lower bound (ELBO), stemming from the fact that p(x) is often referred to as the "evidence" in the Bayesian literature and it turns out that the KL-divergence is indeed non-negative, i.e. KL(q || p) ≥ 0. We therefore define

    ELBO(q) := E_q[ log p(x, z) ] − E_q[ log q(z) ]

This still requires us to feasibly compute E_q[ log p(x, z) ]!

Also, the fact that this lower bounds the evidence p(x) is how one can use VI to do model selection, but again, we're not getting into that here.

Since the expectation is linear and we're assuming independence, if we want to take into account all the samples in x = {x_1, …, x_n}, we simply sum these contributions to get

    ELBO(q) = Σ_{i=1}^{n} E_q[ log p(x_i | z) ] + E_q[ log p(z) ] − E_q[ log q(z) ]

in the case where we use the same z for the different samples.

Awesome. We now have an objective to maximize, whose maximizer then gives us the "optimal" approximate posterior q(z) from which we can sample.

But computing ELBO(q) involves computing an expectation, which of course, again, requires evaluation of a (usually) intractable integral, i.e. problem 2 for the KL-divergence noted in the previous section. Luckily, in this case we can simply substitute the intractable expectations with the empirical expectation, i.e. if we want to estimate the expectation of some function f wrt. a density p, then

    E_p[ f(x) ] ≈ (1/m) Σ_{k=1}^{m} f(x^(k)),   x^(k) ~ p

which, by the Strong Law of Large Numbers (SLLN), converges to the true expectation as m → ∞. We can apply this to the ELBO:

    ELBO(q) ≈ (1/m) Σ_{k=1}^{m} ( log p(x, z^(k)) − log q(z^(k)) ),   z^(k) ~ q(z)

which is a much more approachable problem. Granted, this makes our objective non-deterministic / stochastic / random / wibbly-wobbly, which in turn means that our posterior estimate will have much higher variance. This usually means slower convergence, and possibly even complete failure to converge.
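The empirical-expectation trick itself is easy to see in isolation. Below is a minimal sketch in Python (not Turing code): we estimate E[x²] under a standard Gaussian, whose true value is exactly 1, by a sample average:

```python
import random
import statistics

random.seed(0)

def empirical_expectation(f, sampler, m):
    # (1/m) * sum_{k=1}^m f(x_k)  with  x_k ~ p
    return statistics.fmean(f(sampler()) for _ in range(m))

# E[x^2] under N(0, 1) is exactly 1 (the variance); the SLLN says the
# estimate converges to it as m grows, but any finite-m estimate is random
est = empirical_expectation(lambda x: x * x, lambda: random.gauss(0.0, 1.0), 100_000)
print(est)  # close to 1, but stochastic: another seed moves it around
```

That residual randomness is exactly the extra variance the paragraph above warns about.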

In fact, the approach to estimating the ELBO is often what distinguishes VI approaches. We will discuss this further in the future; it's even likely that we'll dedicate one entire post only to the discussion / comparison of different approaches to this estimation.

In the particular example we will consider now, we actually have access to a closed form expression for E_q[ log q(z) ] (the entropy term), and therefore only need to make use of the empirical estimate for the first term of the ELBO: E_q[ log p(x, z) ]. As mentioned above, this will lead to reduced variance in our updates and faster convergence.

How do I use this in Turing.jl?

Before diving into a particular example, let's just quickly go through the VI interface in Turing.jl.

Unfortunately VI is not officially supported yet, but you can get the WIP version if you check out my VI branch tor/vi-v3. All code that follows assumes you have indeed done that and that you've activated the local environment by doing ]activate . in the repository directory.

using Turing
using Turing: Variational

Just one more time: this is work in progress! What I present to you here is basically the first draft and thus unlikely to survive for long. But I put it here for you to see; maybe you even have ideas on how to improve it! If so, hit me up on the Julia slack channel @torfjelde :)

VI in Turing.jl is mainly supposed to be used as an approximate substitute for the sample(model, sampler) interface:

vi(model::Model, alg::VariationalInference)
vi(model::Model, alg::VariationalInference, q::VariationalPosterior)

Constructs the variational posterior from the model and performs the
optimization following the configuration of the given VariationalInference instance.

Therefore, in general we'll have the following pseudo-code:

# 1. Instantiate `Turing.Model` with observations
m = model(data)

# 2. Construct optimizer and VI algo.
opt = Variational.ADAGrad()                   # proxy for `Flux.Optimise`
alg = Alg(args...)                            # <= `Alg <: VariationalInference`

# 3. Perform variational inference
q = vi(m, alg; optimizer = opt)               # =>  <: `VariationalPosterior`
q = vi(m, alg, q_before_opt; optimizer = opt) # => optimized `q <: VariationalPosterior`
rand(q, 1000)

If q is not provided, Turing will extract the latent variables in your model and construct a VariationalPosterior (i.e. a Distribution{Multivariate, Continuous}) and optimize the corresponding objective, if the VariationalInference method indeed has a default family.

The most significant difference from the existing sample(...) interface is that vi(...) returns a VariationalPosterior which we can call rand on to obtain approximate posterior samples, while sample returns a Chains object from MCMCChains containing the posterior samples. Currently VariationalPosterior is simply an alias for Distribution{Multivariate, Continuous} from Distributions.jl. The idea is that the variational posteriors should act as standard distributions, implementing both rand and logpdf. This way, any multivariate distribution from Distributions.jl will work (if supported by the method).

A VariationalInference instance is supposed to hold the "configuration" for a VI method, e.g. the number of samples used in the empirical expectation for the estimate of the ELBO. Additionally, even though two VariationalInference types have exactly the same struct fields, they might still have very different behaviour.

The following methods might have different behaviour based on the VariationalInference type:

  • optimize!:
    • If the VI method is not gradient-based (e.g. CAVI as we'll soon see), then this is where the main chunk of work occurs, and thus is the method to implement.
    • If the VI method is gradient-based, optimize! will usually contain the "training-loop", where the gradients are obtained by calling grad!. Then the parameters are updated by applying the suitable update rule decided by the optimizer used, e.g. ADAGrad.
      • There are also certain cases where we need to do more than just update the variational variables every iteration, e.g. min-max steps needed for certain objectives (liu16_stein_variat_gradien_descen).
    • A default implementation is provided which simply takes a given number of optimization steps by calling grad! and apply!(optimizer, ...).
  • grad!: If the VI method is gradient-based, then this is usually where the heavy lifting is done. As mentioned earlier, the gradient estimator is often the distinguishing factor between different VI methods, e.g. reparameterization trick (next posts guys; calm down).
    • Default implementations are provided by using ForwardDiff and Flux.Tracker to differentiate the objective call, depending on which AD-type is specified for the VariationalInference{AD} instance.

And of course we need to implement some objective function, e.g. the ELBO. Internally in Turing.jl we define types and implement calling an instance, e.g. struct ELBO end and (elbo::ELBO)(...). This allows different VariationalInference methods to have different behaviour for different objectives. If you're interested in implementing your own VI objective as a simple function, you can just specialize the optimize! call with an obj::typeof(objective_function) argument.

optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = ADAGrad())

Iteratively updates parameters by calling grad! and using the given
optimizer to compute the steps.

grad!(vo, vi::VariationalInference, q::VariationalPosterior, model::Model, θ, out, args...)

Computes the gradients used in optimize!. Default implementation is provided
for VariationalInference{AD} where AD is either ForwardDiffAD or TrackerAD.
This implicitly also gives a default implementation of optimize!.

Variance reduction techniques, e.g. control variates, should be implemented
in this function.

Aight, aight; that's a lot of information. Let's do an example!

Example: 1D Gaussian Mixture


  • 1D Gaussian mixture:

        μ_k ~ N(0, σ²),               k = 1, …, K
        c_i ~ Categorical(π),         i = 1, …, n
        x_i | c_i, μ ~ N(μ_{c_i}, 1), i = 1, …, n

    where π = (π_1, …, π_K) are the (fixed) mixture weights

  • Variational distribution:

        q(μ, c) = ∏_{k=1}^{K} q(μ_k; m_k, s²_k) ∏_{i=1}^{n} q(c_i; φ_i)

  • q(μ_k; m_k, s²_k) is a Gaussian with mean m_k and variance s²_k for k = 1, …, K
  • q(c_i; φ_i) is a Categorical distribution where φ_i is a K-dim vector s.t. Σ_k φ_{ik} = 1
  • q(z; λ) denotes a parameterized density of z with parameter λ
  • λ ~ p means we assume λ is also a random variable itself, i.e. a latent variable in the Bayesian inference setting
  • The expectation of some function f wrt. the density q is denoted

        E_q[ f ]


Let's apply VI to a simple Gaussian mixture problem, because why not? We will closely follow blei16_variat_infer with a slight variation (other than notational differences): we allow non-uniform mixture weights, i.e. being assigned to cluster k might have a different probability than being assigned to cluster k'.

This example lends itself nicely as a base-case for VI since we can derive closed form expressions for the updates. Though it takes a bit of work, it nicely illustrates the underlying mechanisms of VI and demonstrates why we really, really want more general VI approaches :) Aight, let's get started.

For a 1D Gaussian mixture, one particular generative model can be

    μ_k ~ N(0, σ²),               k = 1, …, K
    c_i ~ Categorical(π),         i = 1, …, n
    x_i | c_i, μ ~ N(μ_{c_i}, 1), i = 1, …, n

In this case we then have

    p(x) = ∫ Σ_c p(x, c, μ) dμ

where Σ_c denotes summing over all possible values the assignments c = (c_1, …, c_n) can take on. Following blei16_variat_infer we can actually rewrite p(x) to

    p(x) = ∫ p(μ) ∏_{i=1}^{n} Σ_{c_i} p(c_i) p(x_i | c_i, μ) dμ

Due to the conjugacy of the Gaussian prior p(μ_k) and the Gaussian likelihood p(x_i | c_i, μ), i.e. that a Gaussian prior and Gaussian likelihood give a Gaussian posterior, we can indeed obtain a closed form expression for this. Even so, there are K^n such terms, since we're summing over all the different cluster assignments: K possible values for each of the n assignments gives K^n terms. Thus, for the sake of efficiency, we reach for VI!
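The K^n blow-up is worth a quick sanity check. A small Python computation (just arithmetic, nothing model-specific) shows how hopeless exact summation becomes even for our modest K = 3:

```python
K = 3  # number of clusters, as in our mixture example

for n in (10, 100, 1000):
    num_terms = K ** n  # one term per joint assignment (c_1, ..., c_n)
    print(f"n = {n:4d}: K^n has {len(str(num_terms))} digits")
```

With n = 1000 data points, the number of terms is a 478-digit number.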

Coordinate ascent (mean-field) variational inference (CAVI)

As noted earlier, to perform VI we need to specify a variational family 𝒟. To make life easy for ourselves, we will make what's known as the mean-field assumption, i.e. that the latent variables are independent, so q(μ, c) ∈ 𝒟 means that

    q(μ, c) = ∏_{k=1}^{K} q(μ_k) ∏_{i=1}^{n} q(c_i)

for some independent densities q_j. To choose these q_j, a natural thing to do is to take the densities used in the original model and apply the mean-field assumption, i.e.

    q(μ_k) = N(μ_k; m_k, s²_k),   q(c_i) = Categorical(c_i; φ_i)


  • q(μ_k; m_k, s²_k) is a Gaussian with mean m_k and variance s²_k for k = 1, …, K
  • q(c_i; φ_i) is a Categorical distribution where φ_i is a K-dim vector s.t. Σ_k φ_{ik} = 1

Recall we want to find q*(μ, c) s.t. we maximize the ELBO(q). To do this, we're going to use one of the most basic VI methods, called coordinate ascent variational inference (CAVI). This works by iteratively optimizing each of the latent variables conditioned on the rest of the latent variables, similar to the iterative process of Gibbs sampling (which, of course, samples rather than optimizes).

Observe that, viewed as a function of a single factor q_j (holding the others fixed), we can write the ELBO under the mean-field assumption as

    ELBO(q_j) = E_j[ E_{-j}[ log p(z_j, z_{-j}, x) ] ] − E_j[ log q_j(z_j) ] + const

where λ_j denotes the parameters of q_j, E_{-j} denotes the expectation wrt. all factors except q_j, and the constant collects all terms not depending on λ_j. Then observe that

    E_{-j}[ log p(z_j, z_{-j}, x) ] = E_{-j}[ log p(z_j | z_{-j}, x) ] + E_{-j}[ log p(z_{-j}, x) ]

Therefore we can also absorb the second term above into the constant wrt. λ_j term:

    ELBO(q_j) = E_j[ E_{-j}[ log p(z_j | z_{-j}, x) ] ] − E_j[ log q_j(z_j) ] + const

We therefore make the justified guess

    q_j*(z_j) ∝ exp( E_{-j}[ log p(z_j | z_{-j}, x) ] )

(up to a constant normalization factor), which means that if we substitute this into the above expression for the ELBO, the first two terms cancel each other out, and we're left with

    ELBO(q_j*) = const

If q_j is parameterized by some parameter λ_j (and differentiable wrt. λ_j), we could differentiate wrt. λ_j and see that the gradient vanishes completely when λ_j is set to satisfy the q_j* above (assuming such a λ_j exists, of course). At the very least, this suggests an iterative process to maximize the ELBO:

    while ELBO has not converged:
        for j = 1, …, m:
            set q_j(z_j) ∝ exp( E_{-j}[ log p(z_j | z_{-j}, x) ] )

And indeed this converges to a local optimum blei16_variat_infer.

Now, we still need to compute the updates q_{c_i}* and q_{μ_k}*. But of course, the entire reason for choosing this model and variational family is that in this case these expressions are available in closed form:

    φ_{ik} ∝ π_k exp( m_k x_i − (s²_k + m²_k) / 2 )
    m_k = ( Σ_{i=1}^{n} φ_{ik} x_i ) / ( 1/σ² + Σ_{i=1}^{n} φ_{ik} )
    s²_k = 1 / ( 1/σ² + Σ_{i=1}^{n} φ_{ik} )

where we have used the fact that if μ_k ~ N(m_k, s²_k) then

    E[μ_k] = m_k,   E[μ²_k] = m²_k + s²_k

The derivations of these can be found in the appendix (or, without the π term, in blei16_variat_infer).
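Before we wire this into Turing.jl, the three closed-form updates above can be sanity-checked in isolation. The following is a standalone sketch in Python (not the Turing implementation; variable names like `phi`, `m`, `s2` are mine) that runs the CAVI loop on synthetic data from the same generative model:

```python
import math
import random

random.seed(1)

# synthetic data from the generative model: weights pi, true means, unit obs. variance
pi = [0.1, 0.2, 0.7]
true_means = [-5.0, 0.0, 5.0]
sigma2 = 1.0  # prior variance of the cluster means
x = [random.gauss(true_means[random.choices(range(3), pi)[0]], 1.0) for _ in range(500)]

K, n = 3, len(x)
m = [random.gauss(0.0, 1.0) for _ in range(K)]  # variational means
s2 = [1.0] * K                                  # variational variances
phi = [[1.0 / K] * K for _ in range(n)]         # variational assignment probabilities

for _ in range(50):
    # local step: phi_ik ∝ pi_k exp(m_k x_i - (s2_k + m_k^2) / 2)
    for i in range(n):
        w = [pi[k] * math.exp(m[k] * x[i] - 0.5 * (s2[k] + m[k] ** 2)) for k in range(K)]
        total = sum(w)
        phi[i] = [wk / total for wk in w]

    # global step: closed-form Gaussian updates
    for k in range(K):
        denom = 1.0 / sigma2 + sum(phi[i][k] for i in range(n))
        m[k] = sum(phi[i][k] * x[i] for i in range(n)) / denom
        s2[k] = 1.0 / denom

print(sorted(m))  # typically lands near the true means (up to label permutation)
```

CAVI only guarantees a local optimum, so depending on the initialization the recovered means can merge clusters; re-running with a different seed makes that easy to observe.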

This example is a special case of what's known as conditionally conjugate models with "local" and "global" variables.

  • A variable is local if only a single data point depends on this variable
  • A variable is global if multiple data points depend on this variable

We use this terminology of "global" and "local" for both the latent variables and the variational variables. E.g.

  • Latent:
    • Local: cluster assignments c_i
    • Global: cluster means μ_k
  • Variational:
    • Local: parameters φ_i of the cluster assignments
    • Global: means m_k and variances s²_k for the different clusters

In a general conditionally conjugate model, say β is the global latent variable and z = (z_1, …, z_n) are the local latent variables, then the joint density is given by

    p(β, z, x) = p(β) ∏_{i=1}^{n} p(z_i, x_i | β)

In the case where these factors are also in the exponential family, closed form updates for CAVI can be derived in a similar way as we did for the Gaussian mixture model above (blei16_variat_infer).


Now, before we start, remember that we're trying to get a more efficient way of sampling from our model, right? We're not trying to infer the parameters that generated the data; instead we're working under the assumption that the data came from the given model with known parameters (which in this case is of course true) and we want to produce approximate samples from this model. With that in mind, let's begin!

First we import some necessary packages and set the random seed for reproducibility.

using Random
Random.seed!(1) # seed for reproducibility

using Plots, StatsPlots, LaTeXStrings

The Turing.Model can then be defined

σ² = 1.0

K = 3               # number of clusters
n = 1000            # number of samples

@model mixture(x = Vector{Float64}(undef, 1), μ = Vector{Float64}(undef, K), π = (1 / K) .* ones(K)) = begin
    for i = 1:K
        μ[i] ~ Normal(0.0, σ²)
    end

    c = tzeros(Int, length(x))
    for i ∈ eachindex(x)
        c[i] ~ Categorical(π)

        x[i] ~ Normal(μ[c[i]], 1)
    end
end
mixture (generic function with 4 methods)

This might look a bit complicated, but it really isn't.

  • K is the number of clusters
  • x = Vector{Float64}(undef, 1) means that if we don't specify any data x when instantiating the Model, x will be treated as a random variable to be sampled rather than as an observation. As we will see in a bit, this allows us to generate samples for x from the Model!
  • μ = Vector{Float64}(undef, K) will hold the means of the different clusters; again, the undef will simply allow us to choose between fixing or sampling μ
  • π will denote the weights of the clusters, i.e. π[k] is the probability of assignment to cluster k. The default value (1 / K) .* ones(K) corresponds to uniform cluster weights.
  • tzeros is simply an initialization procedure for latent random variables which is compatible with Turing's particle samplers (we're going to use particle Gibbs, PG). See Turing's documentation for more info.

Let's produce some samples from the model with μ₁ = −5, μ₂ = 0 and μ₃ = 5, together with the cluster weights π₁ = 0.1, π₂ = 0.2 and π₃ = 0.7, e.g. there's a probability 0.1 of being assigned cluster 1, which is normally distributed with mean −5.

# sample from the model
π = [0.1, 0.2, 0.7] # cluster weights
mix = mixture(μ = [-5.0, 0.0, 5.0], π = π)
samples = sample(mix, PG(200, n));
[PG] Sampling...100% Time: 0:00:29
┌ Info: [PG] Finished with
└ @ Turing.Inference /home/tor/Projects/mine/Turing.jl/src/inference/AdvancedSMC.jl:195
┌ Info:   Running time    = 28.558703612000002;
└ @ Turing.Inference /home/tor/Projects/mine/Turing.jl/src/inference/AdvancedSMC.jl:196

plot(samples[:x], size = (1000, 400))

[Figure: trace of the sampled x values]

Looks nice! Let's extract the cluster assignments and the observations:

cluster_assignments = vec([v for v ∈ samples[:c].value])
x = vec([e for e ∈ samples[:x].value])
1000-element Array{Float64,1}:

Now we implement the VariationalInference interface for Turing. First we define the VariationalInference type CAVI:

struct CAVI
    max_iter # maximum number of optimization steps
end

And how the ELBO is evaluated for this particular VI algorithm and this MeanField model:

function (elbo::Variational.ELBO)(alg::CAVI, q::MeanField, model, num_samples)
    res = 0.0

    for i = 1:num_samples
        z = rand(q)
        res += logdensity(model, z) / num_samples
    end

    # ELBO(q) = E_q[log p(x, z)] + H(q)
    return res + entropy(q)
end

Observe that here we are using an empirical estimate of the expectation E_q[ log p(x, z) ], while we can compute the entropy term −E_q[ log q(z) ] = H(q) in closed form. entropy is provided by StatsBase for both Categorical and Normal, so we don't even have to implement those. Under a mean-field assumption, the entropy of q is simply the sum of the entropies of the independent variables.
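The additivity of entropy under independence is a one-liner to verify numerically. A quick Python check (using the standard closed forms, nothing Turing-specific): the entropy of a diagonal 2-D Gaussian equals the sum of the entropies of its marginals:

```python
import math

def normal_entropy(s2):
    # H(N(m, s2)) = 0.5 * log(2 * pi * e * s2); independent of the mean
    return 0.5 * math.log(2 * math.pi * math.e * s2)

s2_a, s2_b = 1.5, 0.25

# entropy of the joint N(0, diag(s2_a, s2_b)): 0.5 * log((2*pi*e)^2 * det)
joint = 0.5 * math.log((2 * math.pi * math.e) ** 2 * (s2_a * s2_b))

print(joint, normal_entropy(s2_a) + normal_entropy(s2_b))  # equal
```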

Both the logdensity and MeanField definitions you can find in the appendix.

As mentioned before, we now need to at the very least define optimize! and vi. This is kind of a contrived example since the update derived before is only applicable to a model of this particular form. Because of this, we're going to instantiate the mixture with the generated data and then dispatch on typeof(model).

model = mixture(x = x, π = π)
Turing.Model{Tuple{:c,:μ},Tuple{:x},...}(...)

function optimize!(alg::CAVI, elbo::Variational.ELBO, q, model::typeof(model), m, s², φ)
    # ideally we'd pack and unpack the params `m`, `s²` and `φ`
    # but for sake of exposition we skip that bit here
    K = length(m)
    n = size(φ, 1)

    for step_idx = 1:alg.max_iter
        # update cluster assignments
        for i = 1:n
            for k = 1:K
                # If the cluster weights were uniform, we could do without accessing π:
                # φ[i, k] = exp(m[k] * x[i] - 0.5 * (s²[k] + m[k]^2))

                # HACK: π is a global variable :/
                φ[i, k] = π[k] * exp(m[k] * x[i] - 0.5 * (s²[k] + m[k]^2))
            end

            φ[i, :] ./= sum(φ[i, :])  # normalize
        end

        # HACK: σ² here is a global variable :/

        # update Gaussians
        for k = 1:K
            φ_k = φ[:, k]
            denominator = (1 / σ²) + sum(φ_k)
            m[k] = sum(φ_k .* x) / denominator
            s²[k] = 1 / denominator
        end
    end

    m, s², φ
end
optimize! (generic function with 1 method)
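For readers who prefer a dependency-free reference, here's the same pair of coordinate updates sketched in NumPy on hypothetical toy data (the names and values here are made up for the sketch; this is not the Turing implementation):

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy data in the same shape as the post's example
# (three clusters with weights pi, unit-variance likelihood, prior variance sigma2).
K, n, sigma2 = 3, 500, 10.0
true_means = np.array([-5.0, 0.0, 5.0])
pi = np.array([0.1, 0.2, 0.7])
c = rng.choice(K, size=n, p=pi)
x = true_means[c] + rng.standard_normal(n)

m = rng.standard_normal(K)  # variational means of q(mu_k)
s2 = np.ones(K)             # variational variances of q(mu_k)
phi = np.full((n, K), 1.0 / K)

for _ in range(10):
    # Assignment update: phi[i, k] ∝ pi_k exp(m_k x_i - (s2_k + m_k^2) / 2).
    log_phi = np.log(pi) + np.outer(x, m) - 0.5 * (s2 + m**2)
    log_phi -= log_phi.max(axis=1, keepdims=True)  # for numerical stability
    phi = np.exp(log_phi)
    phi /= phi.sum(axis=1, keepdims=True)

    # Gaussian update: precision = prior precision + total responsibility.
    denom = 1 / sigma2 + phi.sum(axis=0)
    m = phi.T @ x / denom
    s2 = 1 / denom
```

Note the log-space normalization of the responsibilities: subtracting the row maximum before exponentiating avoids underflow when the exponents are very negative, which the straightforward loop above can run into for well-separated clusters.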

function vi(model::typeof(model), alg::CAVI, q::MeanField)
    discrete_idx = findfirst(d -> d isa DiscreteNonParametric, q.dists)
    K = length(q.dists[discrete_idx].p)           # extract the number of clusters

    n = sum(isa.(q.dists, DiscreteNonParametric)) # each data-point has a cluster assignment

    # initialization of variables
    φ = rand(n, K)
    φ ./= sum(φ; dims = 2)
    m = randn(K)
    s² = ones(K)

    # objective
    elbo = Variational.ELBO()

    # optimize
    optimize!(alg, elbo, q, model, m, s², φ)

    # create new distribution
    q_new = MeanField(copy(q.dists), q.ranges)

    for k = 1:K
        q_new.dists[k] = Normal(m[k], sqrt(s²[k]))
    end

    for i ∈ eachindex(x)
        q_new.dists[i + K] = Categorical(φ[i, :])
    end

    return q_new
end


Now we construct our mean-field approximation from model.

# extract priors into mean-field
var_info = Turing.VarInfo();
model(var_info, Turing.SampleFromUniform())

q = MeanField(deepcopy(var_info.dists), var_info.ranges)
dists: Distribution[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=0.0, σ=1.0), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7])  …  DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7]), 
DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.1, 0.2, 0.7])]
ranges: UnitRange{Int64}[1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10  …  994:994, 995:995, 996:996, 997:997, 998:998, 999:999, 1000:1000, 1001:1001, 1002:1002, 1003:1003]

Let's check that it indeed produces the expected samples:

rand(q)
1003-element Array{Float64,1}:

The only issue with this MeanField implementation is that to satisfy the <: Distribution{Multivariate, Continuous} type constraint, we need to return AbstractArray{<: Real}. This means that the cluster assignments, which ideally would be integers, are now returned as floats.

This is the reason why in the definition of mixture, we added a Int(c[i]). It's kind of ugly, but, yah know, it works.

Finally we can perform CAVI on this mixture model:

# perform VI
cavi = CAVI(10)
q_new = vi(model, cavi, q)
dists: Distribution[Normal{Float64}(μ=-0.0720897, σ=1.42392e-11), Normal{Float64}(μ=-4.91145, σ=5.46488e-11), Normal{Float64}(μ=4.92814, σ=4.9844e-15), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.1981e-9, 1.67617e-27, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999615, 1.21824e-6, 0.000384242]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999758, 0.000240409, 1.63404e-6]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[9.93394e-9, 1.07655e-25, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[7.7963e-7, 5.76301e-22, 0.999999]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[8.06611e-14, 1.03465e-35, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[2.26518e-9, 5.87004e-27, 1.0])  …  DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[8.50589e-7, 6.84062e-22, 0.999999]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.32992e-5, 0.999987, 3.62968e-20]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[3.50883e-5, 1.03289e-18, 0.999965]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[3.39579e-6, 1.04282e-20, 0.999997]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[6.90294e-5, 3.91168e-18, 0.999931]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.999588, 1.14177e-6, 0.000410837]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[0.000230365, 
0.99977, 1.19755e-17]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[1.77431e-7, 1.0, 5.59742e-24]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[7.77467e-9, 6.64628e-26, 1.0]), DiscreteNonParametric{Int64,Float64,Base.OneTo{Int64},Array{Float64,1}}(support=Base.OneTo(3), p=[2.76477e-6, 0.999997, 1.4889e-21])]
ranges: UnitRange{Int64}[1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10  …  994:994, 995:995, 996:996, 997:997, 998:998, 999:999, 1000:1000, 1001:1001, 1002:1002, 1003:1003]

Let's compare the ELBO of q_new to q:

elbo = Variational.ELBO()
elbo(cavi, q, model, 100), elbo(cavi, q_new, model, 100)
(-13270.643508809633, -2240.377147461269)

Hey, it improved! Neat. Let's inspect the obtained parameters.

q_new.dists[1:K]  # <= distributions for the cluster means
3-element Array{Distribution,1}:
 Normal{Float64}(μ=-0.07208966158030183, σ=1.4239202985246593e-11)
 Normal{Float64}(μ=-4.911453128528145, σ=5.464878645528e-11)      
 Normal{Float64}(μ=4.928144006028532, σ=4.984398204586815e-15)    

As we can see, the inferred means are fairly close to the true ones and the variances are small. This is what we expected (or rather hoped for, but since it worked out, we just say it was expected, right guys?). In the data generation process we had the means fixed, i.e. "the variance was zero", thus the variance of the estimates should also be small.

This further displays how this is a "contrived" example in Bayesian inference. Since we fixed the means 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png during the generation process, and the likelihood 2019-06-25-variational-inference-part-1-cavi_d323085fdf061c1f37aaa855af8a562d06f0bfa7.png had unit-variance, the posterior distributions we get should (and do) converge to, basically, point-estimates. As a result, sampling from the resulting variational posterior q_new will almost always give the same result.

This kind of defeats the purpose of doing Bayesian inference in the first place, since the "true" distribution of the latent variables is a discrete distribution with a single outcome, i.e. a fixed distribution. At the same time, most interesting examples will probably not be in the conditionally conjugate family, thus CAVI won't be applicable. As mentioned before, there are more general approaches to come!

We can also inspect the learned cluster weights:

vec(mean(transpose(hcat([d.p for d ∈ q_new.dists[K + 1:end]]...)); dims = 1))
3-element Array{Float64,1}:

Comparing to the cluster assignments in the dataset:

[sum(cluster_assignments .== v) / length(cluster_assignments) for v ∈ sort(unique(cluster_assignments))]
3-element Array{Float64,1}:

Notice that the cluster weights indeed correspond to the correct means, e.g. the cluster with 2019-06-25-variational-inference-part-1-cavi_e3c947b61ef3f5a7cfe931a6d4470b59592b45c3.png indeed has weight 2019-06-25-variational-inference-part-1-cavi_63f226b2259afa8f4637876a31565e54762ff838.png. Neat stuff, ain't it?

Observe that the order of the cluster means is different from what we used in the data generation; we used the ordering [-5.0, 0.0, 5.0]. This is because the clusters are exchangeable, i.e. any permutation of the cluster indices results in the same approximate posterior (modulo stochasticity). Which order we get them in simply comes down to the initialization values: if 2019-06-25-variational-inference-part-1-cavi_e7eade0cd650c5ec6ca3682d82834b8a9dccae26.png is closer to 2019-06-25-variational-inference-part-1-cavi_daf30f9f2c4ddf9bc5bec19f4acf22e064cc6284.png than 2019-06-25-variational-inference-part-1-cavi_69db63664777b86be30ce960ca277b9c54d37d3a.png after initialization, then it's very likely that the resulting variational posterior will have index 2 corresponding to the "true" index 1, if yah feel me.
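Since any relabelling is equally good, a common post-hoc fix is to align the fitted means to a reference ordering by searching over permutations. A small NumPy sketch, using rounded values from the run above:

```python
from itertools import permutations

import numpy as np

true_means = np.array([-5.0, 0.0, 5.0])
# Rounded variational means from the run above; note the permuted order.
fitted = np.array([-0.0721, -4.9115, 4.9281])

# Align labels by searching over all K! = 6 permutations.
best = min(permutations(range(3)),
           key=lambda p: np.sum((fitted[list(p)] - true_means) ** 2))
aligned = fitted[list(best)]

assert best == (1, 0, 2)  # fitted component 2 (1-based) matches true component 1
```

Brute force over permutations is fine for small K; for larger K one would use an assignment solver instead.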

Let's finish off with a neat GIF of the cluster assignments. In creating this GIF, we'll also see how you, the user, can write VI for Turing as a "training" loop, more similar to other machine learning techniques. The general idea is to

  1. Initialize the parameters 2019-06-25-variational-inference-part-1-cavi_a14d8487fae6909c522cdd676b01387e7c5d213d.png, 2019-06-25-variational-inference-part-1-cavi_8f44de754519a6b4b737f2e24c2083a1a7cb4a03.png, and 2019-06-25-variational-inference-part-1-cavi_b614273a2c6406d98daf1108e22aa446959c69eb.png, together with the initial mean-field distribution q.
  2. Perform an optimization step using optimize!
  3. Plot the inferred cluster assignments and ELBO.
# Let's do one where we plot the progress
q = MeanField(deepcopy(var_info.dists), var_info.ranges)

cavi = CAVI(1)  # perform 1 step for each `optimize!` call

φ = rand(n, K)  # cluster assignments
φ ./= sum(φ; dims = 2)
m = randn(K)
s² = ones(K)

# update distributions
for k = 1:K
    q.dists[k] = Normal(m[k], sqrt(s²[k]))
end

for i ∈ eachindex(x)
    q.dists[i + K] = Categorical(φ[i, :])
end

objectives = []  # history of objective evaluations

anim = @animate for step_idx = 1:20
    # compute empirical estimate of ELBO
    step_elbo = elbo(cavi, q, model, 100)
    push!(objectives, step_elbo)

    @info "[$step_idx]" step_elbo

    # plotting of cluster assignments and ELBO
    # sort `m` so that the cluster in the middle is coloured red for visibility
    color_map = Dict(zip(sortperm(m), [colorant"#41b8f4", colorant"#f44155", colorant"#13c142"]))
    colors = map(c -> color_map[c], Int.(rand(q)[K + 1:end]))

    plt1 = scatter(x, color = colors, label = "")
    title!("Step $step_idx with ELBO $step_elbo")
    xlabel!("Index of data point")

    plt2 = plot(max(1, step_idx - 10):step_idx, objectives[max(1, step_idx - 10):step_idx], label = "ELBO")
    xlabel!("Step index")

    # run CAVI for 1 iteration to update parameters m, s², φ
    optimize!(cavi, elbo, q, model, m, s², φ)

    # update distributions so we can evaluate the ELBO with updated parameters
    for k = 1:K
        q.dists[k] = Normal(m[k], sqrt(s²[k]))
    end

    for i ∈ eachindex(x)
        q.dists[i + K] = Categorical(φ[i, :])
    end

    plot(plt1, plt2, layout = grid(2, 1, heights = [0.7, 0.3]), size = (1000, 500))
end

gif(anim, ".2019-06-25-variational-inference-part-1-cavi/figures/gaussian_mixture_cluster_assignments.gif", fps = 5);
┌ Info: [1]
│   step_elbo = -14505.608414741417
└ @ Main In[35]:28
┌ Info: [2]
│   step_elbo = -4295.3039673855865
└ @ Main In[35]:28
┌ Info: [3]
│   step_elbo = -3322.581925954455
└ @ Main In[35]:28
┌ Info: [4]
│   step_elbo = -3302.4541140390593
└ @ Main In[35]:28
┌ Info: [5]
│   step_elbo = -3300.152447964722
└ @ Main In[35]:28
┌ Info: [6]
│   step_elbo = -3295.6891878053775
└ @ Main In[35]:28
┌ Info: [7]
│   step_elbo = -3289.9682776026852
└ @ Main In[35]:28
┌ Info: [8]
│   step_elbo = -3284.6262893228004
└ @ Main In[35]:28
┌ Info: [9]
│   step_elbo = -3273.6817405178094
└ @ Main In[35]:28
┌ Info: [10]
│   step_elbo = -3266.9336190510426
└ @ Main In[35]:28
┌ Info: [11]
│   step_elbo = -3259.8310113380035
└ @ Main In[35]:28
┌ Info: [12]
│   step_elbo = -3248.8725276001446
└ @ Main In[35]:28
┌ Info: [13]
│   step_elbo = -3243.003893619992
└ @ Main In[35]:28
┌ Info: [14]
│   step_elbo = -3235.485410286028
└ @ Main In[35]:28
┌ Info: [15]
│   step_elbo = -3226.4825612330606
└ @ Main In[35]:28
┌ Info: [16]
│   step_elbo = -3218.4085612846575
└ @ Main In[35]:28
┌ Info: [17]
│   step_elbo = -3209.639842972914
└ @ Main In[35]:28
┌ Info: [18]
│   step_elbo = -3202.432054017625
└ @ Main In[35]:28
┌ Info: [19]
│   step_elbo = -3193.836909448166
└ @ Main In[35]:28
┌ Info: [20]
│   step_elbo = -3184.2329537025767
└ @ Main In[35]:28
┌ Info: Saved animation to 
│   fn = /home/tor/org-blog/posts/.2019-06-25-variational-inference-part-1-cavi/figures/gaussian_mixture_cluster_assignments.gif
└ @ Plots /home/tor/.julia/packages/Plots/oiirH/src/animation.jl:90


In the future

Though we have indeed derived the algorithm for a particular example, and the current implementation is specialized to an instance of a particular mixture model, one could imagine generalizing this approach. We noted earlier that this is indeed an instance of a more general family of distributions, on which we can perform CAVI with closed-form updates blei16_variat_infer. One could imagine automatically detecting whether a Model consists of exponential-family distributions and whether the entire model is conditionally conjugate. Then we could simply construct a suitable mean-field approximation and run CAVI with these more general updates!

That would be neat.

Final remarks

Aight, so this was interesting and all, but we had to do a lot of manual labour to obtain the variational updates. In an ideal world you'd define some Model and just call vi(model, alg), maybe even make the choice of variational family yourself by calling vi(model, alg, q), and then Turing does the rest: no deriving closed-form updates, no worrying about whether the densities in your model are in the exponential family, etc. In the next post we will have a look at a variational inference algorithm which gets us further in that direction: automatic differentiation variational inference (ADVI).

The animation below is obtained simply by calling vi(model, ADVI(10, 1000)) on a simple generative model:


It's going to be neat.


  • [blei16_variat_infer] Blei, Kucukelbir & McAuliffe, Variational Inference: A Review for Statisticians, CoRR, (2016). link.
  • [pmlr-v80-chen18k] Chen, Tao, Zhang, Henao & Duke, Variational Inference and Model Selection with Generalized Evidence Bounds, 893-902, in: Proceedings of the 35th International Conference on Machine Learning, edited by Dy & Krause, PMLR (2018)
  • [liu16_stein_variat_gradien_descen] Liu & Wang, Stein Variational Gradient Descent: a General Purpose Bayesian Inference Algorithm, CoRR, (2016). link.



# Mean-field for arbitrary distributions
using Random: AbstractRNG
import Distributions: _rand!, _logpdf

struct MeanField{TDists <: AbstractVector{<: Distribution}} <: Distribution{Multivariate, Continuous}
    dists::TDists
    ranges::Vector{UnitRange{Int}}
end

MeanField(dists) = begin
    ranges::Vector{UnitRange{Int}} = []
    idx = 1
    for d ∈ dists
        push!(ranges, idx:idx + length(d) - 1)
        idx += length(d)
    end

    MeanField(dists, ranges)
end

# Base.length(mf::MeanField) = sum(length(d) for d ∈ mf.dists)
Base.length(mf::MeanField) = mf.ranges[end][end]
_rand!(rng::AbstractRNG, mf::MeanField, x::AbstractVector{T} where T <: Real) = begin
    for i ∈ eachindex(mf.ranges)
        d = mf.dists[i]
        r = mf.ranges[i]

        x[r] = rand(rng, d, 1)
    end

    return x
end


_logpdf(mf::MeanField, x::AbstractVector{T} where T <: Real) = begin
    tot = 0.0
    for i ∈ eachindex(mf.ranges)
        r = mf.ranges[i]
        if length(r) == 1
            tot += logpdf(mf.dists[i], x[r[1]])
        else
            tot += logpdf(mf.dists[i], x[r])
        end
    end

    return tot
end

import StatsBase: params, entropy
entropy(mf::MeanField) = sum(entropy(d) for d ∈ mf.dists)
logdensity(model::Turing.Model, z) = begin
    # kind of hacky; improved impl coming to Turing soon:

    # initialize VarInfo
    var_info = Turing.VarInfo()
    model(var_info, Turing.SampleFromUniform())

    # set the values to `z` and re-evaluate the model to accumulate the log-density
    var_info.vals .= z
    model(var_info)

    return var_info.logp
end

Derivation of variational updates

Recall the equation we found earlier for the updates when the model is conditionally conjugate:


Now we consider the different mean-field factors.

Mixture assignments

In this case we simply have


since the cluster assignment of 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png for 2019-06-25-variational-inference-part-1-cavi_c1a095bc430b73abaf63eae78d315295aaa78937.png is (conditionally) independent of the cluster assignments of the rest of the observations, denoted 2019-06-25-variational-inference-part-1-cavi_f49096d4f4e40dfc7417661094e875f0ea9957f3.png. The first term is simply


where 2019-06-25-variational-inference-part-1-cavi_9f656ddb23a905a01ee310dcfd4c1da139a1024a.png is the cluster-weight of the i-th cluster, or equivalently, the probability of assignment to the i-th cluster. Thus the interesting term is the expectation of the log-likelihood. Though in our implementation of this Gaussian mixture, 2019-06-25-variational-inference-part-1-cavi_6f5e48c2022723307aafa4cc44933aa17d28e91d.png, we can equivalently consider 2019-06-25-variational-inference-part-1-cavi_64a587acefac95d72fb90156daa7779bb32ab6d1.png as an indicator vector, with the k-th component given by



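To make the indicator-vector bookkeeping concrete, here's a tiny NumPy check (with made-up log-likelihood values) that the one-hot vector picks out exactly the assigned component's log-likelihood:

```python
import numpy as np

# Hypothetical per-component log-likelihoods log p(x_i | mu_k) for one point:
log_liks = np.array([-1.2, -0.3, -2.5])

# The one-hot indicator c_i selects the assigned component's term, so the
# sum over k of c_ik * log p(x_i | mu_k) is just the assigned log-likelihood.
c = np.array([0, 1, 0])
assert np.dot(c, log_liks) == log_liks[1]
```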

Then we can write the log-likelihood




where we have used the fact that the


to "move" the constant expression out of the sum, and, recalling that 2019-06-25-variational-inference-part-1-cavi_a500ea3ff82c9cea65d5902ccb313b429cc78963.png is simply a unit variance Gaussian with mean 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png,


Moreover, observe that


and that


Substituting back into the above expression


where we have absorbed the 2019-06-25-variational-inference-part-1-cavi_9ec8e7c68bf7b7b4b982acf6542c353e63f55062.png into the 2019-06-25-variational-inference-part-1-cavi_130bc2b12c9dd1fe8791a2571f820aaf6b06845c.png term. This is indeed equivalent to the expression derived in blei16_variat_infer, by simply observing that


This finally gives us the full update for 2019-06-25-variational-inference-part-1-cavi_e2b6cb82e2f2dbcf424839a7643e2ec3c238e5ac.png:


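The only non-obvious ingredient in this update is the second moment, which is what appears as s²[k] + m[k]^2 in the implementation. A quick Monte Carlo sanity check in NumPy (the values here are arbitrary) of the identity E[μ²] = Var[μ] + E[μ]²:

```python
import numpy as np

rng = np.random.default_rng(2)

# The update uses E[mu_k] = m_k and E[mu_k^2] = s2_k + m_k^2 under
# q(mu_k) = N(m_k, s2_k); the parameter values here are arbitrary.
m_k, s2_k = 1.5, 0.4
mu = m_k + np.sqrt(s2_k) * rng.standard_normal(1_000_000)

assert abs(mu.mean() - m_k) < 0.01
assert abs((mu**2).mean() - (s2_k + m_k**2)) < 0.02
```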
Means of mixture components

For the means 2019-06-25-variational-inference-part-1-cavi_daf4c98ec256e58b010370c12204e88945e82b90.png we have


Observe then that we can complete the square


Exponentiating this, and absorbing the term in red into a multiplicative constant (wrt. 2019-06-25-variational-inference-part-1-cavi_0e3a1d7a935bba2108a947eb6269ea0e4d824db7.png) 2019-06-25-variational-inference-part-1-cavi_a4680369a3334fa4e273bbd161e5ca60a5d7c8a5.png, leaves us with




Congratulations! It's a bo…Gaussian!

Actually, in blei16_variat_infer they insist on using 2019-06-25-variational-inference-part-1-cavi_983afa4cb71952f7c1e1111b04fbf529d4677059.png instead of just writing 2019-06-25-variational-inference-part-1-cavi_1f16af3cf4c8f31242256354a1103d47a6c017f6.png, etc. At first I didn't understand why, but now I realize: the above is basically deriving that the conditionally conjugate density is a Gaussian, rather than assuming so! Neato.
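If you'd rather not redo the algebra, the completed square can be sanity-checked numerically: for a quadratic log-density with precision 1/σ² + Σᵢφᵢ and linear coefficient Σᵢφᵢxᵢ, the maximizer should be the derived Gaussian mean. A small NumPy sketch with made-up responsibilities:

```python
import numpy as np

rng = np.random.default_rng(3)

# Made-up responsibilities and data for one component.
sigma2 = 10.0
phi = rng.random(50)
x = rng.standard_normal(50) + 3.0

# Completing the square: log q*(mu) = b * mu - 0.5 * a * mu^2 + const, with
a = 1 / sigma2 + phi.sum()  # the precision 1/sigma2 + sum_i phi_i
b = (phi * x).sum()         # the coefficient sum_i phi_i x_i
m, s2 = b / a, 1 / a        # so q*(mu) = N(b / a, 1 / a)

# Numerical check: the quadratic is maximized at mu = m.
grid = np.linspace(m - 1.0, m + 1.0, 2001)
logq = b * grid - 0.5 * a * grid**2
assert abs(grid[np.argmax(logq)] - m) < 1e-3
```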



It's at least famous in certain circles. Maybe few circles, but still, they exist… There are dozens of us! Dozens!


As we will see in later posts in this series, this is not always the case. E.g. the class of variational inference methods called Amortized VI uses a different latent variable 2019-06-25-variational-inference-part-1-cavi_f4d02d8b757b78d2991836b88709a29d2d6c38f9.png for each observation 2019-06-25-variational-inference-part-1-cavi_c1a095bc430b73abaf63eae78d315295aaa78937.png.


Usually the variational family 2019-06-25-variational-inference-part-1-cavi_dc8fda671dad6051d988d7dd31211745f3f81823.png is chosen s.t. 2019-06-25-variational-inference-part-1-cavi_34724c66b3e97a635a63f6916f5be50dd45cf282.png is cheap to sample from.


You noticed the v3? What happened to v1 and v2? Don't worry about it, they're fine.