Variational Inference (Part 0): What?

What is this and who are you?

I'm Tor. I'm 24 years old, Norwegian, and just finished my BSc in Mathematics. More importantly, I'm lucky enough to get to work on variational inference in Turing.jl!

And before you ask: no, I don't own a hammer (not even a small one), nor am I to blame for the awful rain and thunder you experienced yesterday. On the other hand, I do have luscious hair and bulging arms.[1]

"Part 0"? You planning on doing more of these?

Yes! I intend to publish a series of posts on variational inference (VI) and the different VI methods we implement in Turing.jl. In this series I will attempt to explain the theory behind the methods, describe how we went about implementing them, and finally give (hopefully) working examples using the implemented methods in Turing.jl.

Reader requirements

Hard requirements:

  • [ ] Familiarity with Bayesian inference. If "Bayesian inference" is completely unknown to you, I suggest having a look at that first. There are tons of good resources on the topic, suitable for different levels of mathematical familiarity.

Soft requirements:

  • [ ] Some familiarity with Julia. We will make use of a lot of the nice features Julia offers, e.g. macros and multiple dispatch, but I won't be going into detail about them. It's therefore useful if you're already a bit familiar with the language, or at the very least can read Julia code (there's a small taste right after this list).
  • [ ] Some familiarity with "exact" methods for Bayesian inference, e.g. MCMC. Otherwise it might be difficult to fully see why one would even bother with variational inference in the first place, though I will do my best to motivate in the next post.

It's worth noting that the focus of this series is on the variational inference aspect of Turing.jl, so I won't necessarily introduce you to all the nice "exact" samplers Turing.jl already provides. If you're not yet familiar with that part of Turing.jl, you should definitely check it out at https://turing.ml!

If all those boxes are ticked, you're already familiar with Turing.jl, and you just want to get to the variational inference bit, head straight to the next post in the series: Part 1: Variational Inference (VI).

Turing.jl

First I'll just quickly run you through what Turing.jl is.

"Turing.jl is a probabilistic programming library in Julia which automates probabilistic learning by expressing probabilistic models as Julia programs and then Bayesian inference is performed using generic inference engines."

In short: Turing.jl is a library for doing Bayesian modelling.

"Sure Tor, but there are tons of libraries out there which can do that in other languages! E.g. PyMC3, Tensorflow Probability and Pyro."

Sure, sure, but can they define models like THIS?!

using Turing

# define Turing model
@model demo(x) = begin
    s ~ InverseGamma(2,3)
    m ~ Normal(0.0, sqrt(s))  # `Normal(μ, σ)` has mean μ and variance σ², i.e. it's parametrized by the standard deviation, not the variance

    for i = 1:length(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end
demo (generic function with 2 methods)

Let's compare this to the mathematical notation used to describe such a model:

\begin{align*}
s &\sim \mathrm{InverseGamma}(2, 3) \\
m &\sim \mathcal{N}(0, s) \\
x_i &\sim \mathcal{N}(m, s), \quad i = 1, \dots, n
\end{align*}

I know what you're thinking: "But… but they look almost exactly the same?!" And if you weren't thinking that, you are now! It's awesome. Moreover, Turing.jl of course also provides several state-of-the-art samplers out of the box, e.g. NUTS:

x = randn(1, 1000);

# instantiates the model with the data
m = demo(x)
samples = sample(m, NUTS(2000, 0.65));
┌ Info: [Turing] looking for good initial eps...
└ @ Turing.Inference /home/tor/.julia/packages/Turing/RZOZ8/src/inference/support/hmc_core.jl:247
┌ Info: [Turing] found initial ϵ: 0.12589111328125002
└ @ Turing.Inference /home/tor/.julia/packages/Turing/RZOZ8/src/inference/support/hmc_core.jl:239
┌ Warning: Numerical error in gradients. Rejecting current proposal...
└ @ Turing.Core /home/tor/.julia/packages/Turing/RZOZ8/src/core/ad.jl:169
┌ Warning: grad = Real[NaN, NaN]
└ @ Turing.Core /home/tor/.julia/packages/Turing/RZOZ8/src/core/ad.jl:170
[NUTS] Sampling...  2%  ETA: 0:04:05
  ϵ:         0.053653608797217596
  α:         0.10495062744920733

[... progress output truncated ...]

┌ Info:  Adapted ϵ = 0.014890205235942436, std = [1.0, 1.0]; 1000 iterations is used for adaption.
└ @ Turing.Inference /home/tor/.julia/packages/Turing/RZOZ8/src/inference/adapt/adapt.jl:90

[... progress output truncated ...]

[NUTS] Sampling...100% Time: 0:01:33
[NUTS] Finished with
  Running time        = 92.82022274199993;
  #lf / sample        = 0.0;
  #evals / sample     = 0.0005;
  pre-cond. metric    = [1.0, 1.0].
samples
Object of type Chains, with data of type 2000×8×1 Array{Union{Missing, Float64},3}

Log evidence      = 0.0
Iterations        = 1:2000
Thinning interval = 1
Chains            = 1
Samples per chain = 2000
internals         = elapsed, epsilon, eval_num, lf_eps, lf_num, lp
parameters        = m, s

2-element Array{ChainDataFrame,1}

Summary Statistics
Omitted printing of 2 columns
│ Row │ parameters │ mean        │ std       │ naive_se    │ mcse        │
│     │ Symbol     │ Float64     │ Float64   │ Float64     │ Float64     │
├─────┼────────────┼─────────────┼───────────┼─────────────┼─────────────┤
│ 1   │ m          │ -0.00381341 │ 0.0407008 │ 0.000910098 │ 0.000835512 │
│ 2   │ s          │ 0.950122    │ 0.143686  │ 0.00321292  │ 0.00276883  │

Quantiles
Omitted printing of 1 columns
│ Row │ parameters │ 2.5%       │ 25.0%      │ 50.0%       │ 75.0%     │
│     │ Symbol     │ Float64    │ Float64    │ Float64     │ Float64   │
├─────┼────────────┼────────────┼────────────┼─────────────┼───────────┤
│ 1   │ m          │ -0.0620671 │ -0.0248262 │ -0.00371837 │ 0.0161158 │
│ 2   │ s          │ 0.872305   │ 0.919291   │ 0.947512    │ 0.97431   │
using Plots, StatsPlots

plot(samples)

(Figure: trace and density plots of the samples for m and s.)

The above is for sure very neat. But there is more!

turing_neatness_intesifies.gif

Figure 2: Turing.jl's logo with intensifying neatness.

Another feature of Turing.jl, which a lot of these other libraries lack, is the ability to execute (basically) arbitrary code inside the model. This gives the user a lot of flexibility, allowing for neatly customized models. We'll see more of that in later posts; for now, here's a rough idea of what I mean.
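
The following is an illustrative sketch of my own (the model and its `threshold` argument are made up for this example), showing plain Julia control flow mixed freely with `~` statements:

using Turing

# Hypothetical model: an ordinary `if` decides which likelihood a point gets.
@model noisy_demo(x, threshold) = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0.0, sqrt(s))

    for i = 1:length(x)
        if x[i] > threshold                # arbitrary code: plain control flow
            x[i] ~ Normal(m, 2 * sqrt(s))  # inflate the noise for large values
        else
            x[i] ~ Normal(m, sqrt(s))
        end
    end
end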

Hopefully I've gotten my point across: Turing.jl is neat.

But there is one neat feature that is missing from Turing.jl (as of <2019-06-24 Mon>): variational inference.[2] What is it? You'll find out in the next post!

I've read this stupid post; now what?

Woah! No need to be like that! Just go read Part 1: Variational Inference (VI), you jerk. There I'll go through the theory behind VI, how we're doing it in Turing.jl, and a simple example. To whet your appetite, here's an animation of applying coordinate ascent variational inference (CAVI) to a Gaussian mixture:

gaussian_mixture_cluster_assignments.gif
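
For reference, a two-component Gaussian mixture of the kind being fit there can be written in Turing.jl roughly as follows. This is a minimal sketch with priors of my own choosing, not the exact model behind the animation:

using Turing

# Minimal two-component Gaussian mixture (illustrative priors):
@model gmm(x) = begin
    μ1 ~ Normal(0.0, 10.0)             # mean of the first component
    μ2 ~ Normal(0.0, 10.0)             # mean of the second component

    k = Vector{Int}(undef, length(x))  # per-point cluster assignments
    for i = 1:length(x)
        k[i] ~ Categorical([0.5, 0.5])
        x[i] ~ Normal(k[i] == 1 ? μ1 : μ2, 1.0)
    end
end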

Footnotes:

[1] This may or may not be true.

[2] Well, maybe there are some other neat ones too. But VI is a super-neat one!