Expectation propagation
Notation
- $\mathrm{KL}(p \,\|\, q) = \int p(\mathbf{z}) \ln \frac{p(\mathbf{z})}{q(\mathbf{z})} \, \mathrm{d}\mathbf{z}$ denotes the Kullback-Leibler divergence, to be minimized wrt. $q$
- $q(\mathbf{z})$ is the approximate distribution, which is a member of the exponential family, i.e. can be written
  $$q(\mathbf{z}) = h(\mathbf{z}) \, g(\boldsymbol{\eta}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{z}) \right\}$$
- $p(\mathbf{z})$ is a fixed distribution (the one being approximated)
- $\mathcal{D}$ denotes the data
- $\boldsymbol{\theta}$ denotes the parameters / hidden variables
Exponential family
The exponential family of distributions over $\mathbf{x}$, given parameters $\boldsymbol{\eta}$, is defined to be the set of distributions of the form

$$p(\mathbf{x} \mid \boldsymbol{\eta}) = h(\mathbf{x}) \, g(\boldsymbol{\eta}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{x}) \right\}$$
where:
- $\mathbf{x}$ may be discrete or continuous
- $\mathbf{u}(\mathbf{x})$ is just some function of $\mathbf{x}$ (the sufficient statistic)
- $g(\boldsymbol{\eta})$ ensures the distribution is normalized
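For a concrete instance, the Bernoulli distribution fits this form (with $\mu$ the usual mean parameter and $\sigma$ the logistic sigmoid):

$$\mathrm{Bern}(x \mid \mu) = \mu^x (1 - \mu)^{1 - x} = (1 - \mu) \exp\left\{ x \ln \frac{\mu}{1 - \mu} \right\}$$

so that $\eta = \ln \frac{\mu}{1 - \mu}$, $u(x) = x$, $h(x) = 1$ and $g(\eta) = \sigma(-\eta) = \frac{1}{1 + e^{\eta}}$.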
MLE
If we wanted to estimate $\boldsymbol{\eta}$ in a density of the exponential family by maximum likelihood, we would simply take the derivative of the log-likelihood wrt. $\boldsymbol{\eta}$ and set it to zero. The key ingredient is the gradient of $\ln g(\boldsymbol{\eta})$, which we obtain by differentiating the normalization condition

$$g(\boldsymbol{\eta}) \int h(\mathbf{x}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{x}) \right\} \mathrm{d}\mathbf{x} = 1$$

wrt. $\boldsymbol{\eta}$:

$$\nabla g(\boldsymbol{\eta}) \int h(\mathbf{x}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{x}) \right\} \mathrm{d}\mathbf{x} + g(\boldsymbol{\eta}) \int h(\mathbf{x}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{x}) \right\} \mathbf{u}(\mathbf{x}) \, \mathrm{d}\mathbf{x} = 0$$

Rearranging, we get

$$-\frac{1}{g(\boldsymbol{\eta})} \nabla g(\boldsymbol{\eta}) = g(\boldsymbol{\eta}) \int h(\mathbf{x}) \exp\left\{ \boldsymbol{\eta}^\top \mathbf{u}(\mathbf{x}) \right\} \mathbf{u}(\mathbf{x}) \, \mathrm{d}\mathbf{x} = \mathbb{E}[\mathbf{u}(\mathbf{x})]$$

which is just

$$-\nabla \ln g(\boldsymbol{\eta}) = \mathbb{E}[\mathbf{u}(\mathbf{x})]$$

The covariance of $\mathbf{u}(\mathbf{x})$ can be expressed in terms of the second derivatives of $g(\boldsymbol{\eta})$, and similarly for higher order moments.
Thus, if we had a set of i.i.d. data denoted by $\mathbf{X} = \{\mathbf{x}_1, \dots, \mathbf{x}_N\}$, for which the likelihood function is given by

$$p(\mathbf{X} \mid \boldsymbol{\eta}) = \left( \prod_{n=1}^{N} h(\mathbf{x}_n) \right) g(\boldsymbol{\eta})^N \exp\left\{ \boldsymbol{\eta}^\top \sum_{n=1}^{N} \mathbf{u}(\mathbf{x}_n) \right\}$$

the log-likelihood is maximized wrt. $\boldsymbol{\eta}$ when its gradient vanishes,

$$N \nabla \ln g(\boldsymbol{\eta}) + \sum_{n=1}^{N} \mathbf{u}(\mathbf{x}_n) = 0$$

giving us the MLE estimator $\boldsymbol{\eta}_{\mathrm{ML}}$ as the solution of

$$-\nabla \ln g(\boldsymbol{\eta}_{\mathrm{ML}}) = \frac{1}{N} \sum_{n=1}^{N} \mathbf{u}(\mathbf{x}_n)$$

i.e. the MLE estimator depends on the data only through $\sum_n \mathbf{u}(\mathbf{x}_n)$, hence $\sum_n \mathbf{u}(\mathbf{x}_n)$ is a sufficient statistic.
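To make this concrete, here is a small numerical sketch (my own toy example, not from any particular source): for Bernoulli data written in the exponential-family form above, the moment-matching condition $\sigma(\eta_{\mathrm{ML}}) = \frac{1}{N}\sum_n x_n$ gives the MLE in closed form, and we can check it against direct numerical maximization of the log-likelihood.

```python
import numpy as np
from scipy.optimize import minimize_scalar

rng = np.random.default_rng(1)
x = (rng.random(1000) < 0.7).astype(float)      # i.i.d. Bernoulli(0.7) draws

# MLE via the sufficient statistic: sigmoid(eta_ML) = mean of u(x_n) = mean of x_n
mean_u = x.mean()
eta_closed = np.log(mean_u / (1.0 - mean_u))

# Direct numerical maximisation of the log-likelihood
#   ln p(X | eta) = eta * sum_n x_n + N * ln g(eta),  with ln g(eta) = -ln(1 + e^eta)
def neg_log_lik(eta):
    return -(eta * x.sum() - len(x) * np.log1p(np.exp(eta)))

eta_numeric = minimize_scalar(neg_log_lik, bounds=(-10.0, 10.0), method="bounded").x
print(eta_closed, eta_numeric)   # the two estimates agree (up to optimiser tolerance)
```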
Minimizing KL divergence with an exponential-family approximation
Suppose $q(\mathbf{z})$ is a member of the exponential family, as in the notation above. The KL divergence then becomes

$$\mathrm{KL}(p \,\|\, q) = -\ln g(\boldsymbol{\eta}) - \boldsymbol{\eta}^\top \mathbb{E}_{p(\mathbf{z})}[\mathbf{u}(\mathbf{z})] + \text{const}$$

where $\boldsymbol{\eta}$, $\mathbf{u}(\mathbf{z})$ and $g(\boldsymbol{\eta})$ define the approximating distribution $q(\mathbf{z})$, and the constant collects all terms independent of $\boldsymbol{\eta}$.
We are not making any assumptions regarding the form of the fixed distribution $p(\mathbf{z})$.
This is minimized wrt. $\boldsymbol{\eta}$ by setting the gradient to zero:

$$-\nabla \ln g(\boldsymbol{\eta}) = \mathbb{E}_{p(\mathbf{z})}[\mathbf{u}(\mathbf{z})]$$

From the identity $-\nabla \ln g(\boldsymbol{\eta}) = \mathbb{E}_{q(\mathbf{z})}[\mathbf{u}(\mathbf{z})]$ derived in the MLE section we then have

$$\mathbb{E}_{q(\mathbf{z})}[\mathbf{u}(\mathbf{z})] = \mathbb{E}_{p(\mathbf{z})}[\mathbf{u}(\mathbf{z})]$$

hence, the optimum solution corresponds to matching the expected sufficient statistics!
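A quick numerical illustration (again a toy example of my own, with $p$ a two-component Gaussian mixture and $q$ a Gaussian): matching the first two moments of $p$ minimizes $\mathrm{KL}(p \,\|\, q)$, and perturbing the matched parameters only increases it.

```python
import numpy as np
from scipy.stats import norm
from scipy.integrate import quad

# p(z): a fixed (non-Gaussian) distribution; q(z) = N(z | m, v) is the approximation.
def p(z):
    return 0.6 * norm.pdf(z, -2.0, 1.0) + 0.4 * norm.pdf(z, 3.0, 0.5)

# First and second moments of p, computed numerically.
mean_p = quad(lambda z: z * p(z), -20, 20)[0]
var_p = quad(lambda z: (z - mean_p) ** 2 * p(z), -20, 20)[0]

def objective(m, v):
    # KL(p || q) up to the constant entropy term of p, i.e. the cross-entropy -E_p[ln q].
    return quad(lambda z: -p(z) * norm.logpdf(z, m, np.sqrt(v)), -20, 20)[0]

best = objective(mean_p, var_p)
for dm, dv in [(0.5, 0.0), (-0.5, 0.0), (0.0, 1.0), (0.0, -1.0)]:
    assert objective(mean_p + dm, var_p + dv) > best   # moment matching is the minimum
print("moment-matched q: mean %.3f, var %.3f" % (mean_p, var_p))
```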
Expectation propagation
Joint probability of data and parameters is given by a product of factors

$$p(\mathcal{D}, \boldsymbol{\theta}) = \prod_i f_i(\boldsymbol{\theta})$$

We want to evaluate the posterior

$$p(\boldsymbol{\theta} \mid \mathcal{D}) = \frac{1}{p(\mathcal{D})} \prod_i f_i(\boldsymbol{\theta})$$

and, for model comparison, the model evidence:

$$p(\mathcal{D}) = \int \prod_i f_i(\boldsymbol{\theta}) \, \mathrm{d}\boldsymbol{\theta}$$
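For example (a standard choice, spelled out here for concreteness), for i.i.d. data with prior $p(\boldsymbol{\theta})$ the factors can be taken to be the prior and one likelihood term per data point:

$$p(\mathcal{D}, \boldsymbol{\theta}) = p(\boldsymbol{\theta}) \prod_{n=1}^{N} p(\mathbf{x}_n \mid \boldsymbol{\theta}), \qquad f_0(\boldsymbol{\theta}) = p(\boldsymbol{\theta}), \quad f_n(\boldsymbol{\theta}) = p(\mathbf{x}_n \mid \boldsymbol{\theta})$$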
Expectation propagation is based on approximating the posterior distribution by

$$q(\boldsymbol{\theta}) = \frac{1}{Z} \prod_i \tilde{f}_i(\boldsymbol{\theta})$$

where $\tilde{f}_i(\boldsymbol{\theta})$ is an approximation to the corresponding factor $f_i(\boldsymbol{\theta})$ in the true posterior.

- To be tractable, we need to constrain the approximating factors => assume the $\tilde{f}_i(\boldsymbol{\theta})$ to be members of the exponential family

Ideally we would want to determine the $\tilde{f}_i(\boldsymbol{\theta})$ by minimizing the KL divergence between the true and approximate posterior,

$$\mathrm{KL}\left( p(\boldsymbol{\theta} \mid \mathcal{D}) \,\Big\|\, \frac{1}{Z} \prod_i \tilde{f}_i(\boldsymbol{\theta}) \right)$$

but this requires expectations wrt. the true posterior and is therefore intractable in general.
One approach would be to minimize the KL divergences between the pairs $f_j(\boldsymbol{\theta})$ and $\tilde{f}_j(\boldsymbol{\theta})$ of factors, i.e. approximate each factor independently.
It turns out that this is no good; even though each of the factors is approximated well individually, the product can still give a poor approximation to the posterior.
(Also, if our assumptions do not hold, i.e. the true posterior is actually not of the exponential family, then clearly this approach would not be a good one.)
Expectation propagation optimizes each factor in turn, given the "context" of all remaining factors.
Suppose we wish to update our estimate for $\tilde{f}_j(\boldsymbol{\theta})$; then we would like the following to hold (recall that the approximating factors are assumed to be in the exponential family):

$$\tilde{f}_j(\boldsymbol{\theta}) \prod_{i \neq j} \tilde{f}_i(\boldsymbol{\theta}) \approx f_j(\boldsymbol{\theta}) \prod_{i \neq j} \tilde{f}_i(\boldsymbol{\theta})$$

i.e. we want to optimize the j-th factor conditioned on our current estimates for all the other factors. (Expectation maximization, anyone?)

This problem we can pose as minimizing the following KL divergence

$$\mathrm{KL}\left( \frac{f_j(\boldsymbol{\theta}) \, q^{\setminus j}(\boldsymbol{\theta})}{Z_j} \,\Big\|\, q^{\mathrm{new}}(\boldsymbol{\theta}) \right)$$

where:

- $q^{\mathrm{new}}(\boldsymbol{\theta})$ is the new approximate distribution, a member of the exponential family
- $q^{\setminus j}(\boldsymbol{\theta}) = \dfrac{q(\boldsymbol{\theta})}{\tilde{f}_j(\boldsymbol{\theta})}$ is the unnormalized "cavity" distribution, obtained by removing $\tilde{f}_j$ from $q$
- $Z_j = \displaystyle\int q^{\setminus j}(\boldsymbol{\theta}) \, f_j(\boldsymbol{\theta}) \, \mathrm{d}\boldsymbol{\theta}$ is the normalization constant, with $\tilde{f}_j$ replaced by the true factor $f_j$

Which we already know comes down to matching the expected sufficient statistics:

$$\mathbb{E}_{q^{\mathrm{new}}}[\mathbf{u}(\boldsymbol{\theta})] = \mathbb{E}_{\frac{1}{Z_j} f_j \, q^{\setminus j}}[\mathbf{u}(\boldsymbol{\theta})]$$
To summarize: we are given a joint distribution over observed data $\mathcal{D}$ and parameters $\boldsymbol{\theta}$ in the form of a product of factors $p(\mathcal{D}, \boldsymbol{\theta}) = \prod_i f_i(\boldsymbol{\theta})$, and we wish to approximate the posterior distribution by a distribution of the form

$$q(\boldsymbol{\theta}) = \frac{1}{Z} \prod_i \tilde{f}_i(\boldsymbol{\theta})$$

and the model evidence $p(\mathcal{D})$. The algorithm:
- Initialize all of the approximating factors $\tilde{f}_i(\boldsymbol{\theta})$
- Initialize the posterior approximation by setting
  $$q(\boldsymbol{\theta}) \propto \prod_i \tilde{f}_i(\boldsymbol{\theta})$$
- Until convergence:
  - Choose a factor $\tilde{f}_j(\boldsymbol{\theta})$ to improve
  - Remove $\tilde{f}_j(\boldsymbol{\theta})$ from the posterior by division:
    $$q^{\setminus j}(\boldsymbol{\theta}) = \frac{q(\boldsymbol{\theta})}{\tilde{f}_j(\boldsymbol{\theta})}$$
  - Let $q^{\mathrm{new}}(\boldsymbol{\theta})$ be such that
    $$\mathbb{E}_{q^{\mathrm{new}}}[\mathbf{u}(\boldsymbol{\theta})] = \mathbb{E}_{\frac{1}{Z_j} f_j \, q^{\setminus j}}[\mathbf{u}(\boldsymbol{\theta})]$$
    i.e. matching the sufficient statistics (moments) of $q^{\mathrm{new}}(\boldsymbol{\theta})$ with those of $q^{\setminus j}(\boldsymbol{\theta}) f_j(\boldsymbol{\theta})$, including evaluating the normalization constant
    $$Z_j = \int q^{\setminus j}(\boldsymbol{\theta}) \, f_j(\boldsymbol{\theta}) \, \mathrm{d}\boldsymbol{\theta}$$
  - Evaluate and store the new factor
    $$\tilde{f}_j(\boldsymbol{\theta}) = Z_j \frac{q^{\mathrm{new}}(\boldsymbol{\theta})}{q^{\setminus j}(\boldsymbol{\theta})}$$
- Evaluate the approximation to the model evidence:
  $$p(\mathcal{D}) \simeq \int \prod_i \tilde{f}_i(\boldsymbol{\theta}) \, \mathrm{d}\boldsymbol{\theta}$$
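To see all of the steps together, below is a minimal sketch of EP for the one-dimensional "clutter" problem (the standard illustrative example; the model, constants and variable names are my own choices, not taken from these notes). Each observation comes from $N(x \mid \theta, 1)$ with probability $1 - w$ and from a broad background $N(x \mid 0, a)$ otherwise; the sites $\tilde{f}_n(\theta)$ are unnormalized Gaussians stored as natural parameters, and each update removes a site, moment-matches against the exact factor, and divides the result back out.

```python
import numpy as np

rng = np.random.default_rng(0)

# Clutter problem (1D): x ~ (1 - w) N(theta, 1) + w N(0, a), prior theta ~ N(0, b).
w, a, b = 0.3, 10.0, 100.0
theta_true = 2.0
N = 20
clutter = rng.random(N) < w
x = np.where(clutter, rng.normal(0.0, np.sqrt(a), N), rng.normal(theta_true, 1.0, N))

def norm_pdf(y, mean, var):
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

# Sites f~_n(theta) ~ N(theta | m_n, v_n), stored as natural parameters
# r_n = 1/v_n and rm_n = m_n/v_n; initialised to 1 (i.e. r_n = rm_n = 0).
r_site = np.zeros(N)
rm_site = np.zeros(N)

# Global approximation q(theta) = N(theta | m, v): exact prior times all sites.
r_q = 1.0 / b + r_site.sum()
rm_q = 0.0 + rm_site.sum()

for sweep in range(10):
    for n in range(N):
        # 1) Remove site n from q by division -> cavity distribution q^{\n}.
        r_cav = r_q - r_site[n]
        rm_cav = rm_q - rm_site[n]
        if r_cav <= 0:          # guard against improper cavities (practical heuristic)
            continue
        v_cav, m_cav = 1.0 / r_cav, rm_cav / r_cav

        # 2) Moments of the hybrid q^{\n}(theta) * f_n(theta), a two-component mixture.
        Z = (1 - w) * norm_pdf(x[n], m_cav, v_cav + 1.0) + w * norm_pdf(x[n], 0.0, a)
        rho = (1 - w) * norm_pdf(x[n], m_cav, v_cav + 1.0) / Z    # P(not clutter)
        v1 = v_cav / (v_cav + 1.0)                                 # posterior var if not clutter
        m1 = m_cav + v_cav * (x[n] - m_cav) / (v_cav + 1.0)        # posterior mean if not clutter
        m_new = rho * m1 + (1 - rho) * m_cav
        second = rho * (v1 + m1 ** 2) + (1 - rho) * (v_cav + m_cav ** 2)
        v_new = second - m_new ** 2

        # 3) q^{new} matches these moments; recover the site by division q^{new}/q^{\n}.
        r_site[n] = 1.0 / v_new - r_cav
        rm_site[n] = m_new / v_new - rm_cav

        # 4) q <- q^{new}
        r_q, rm_q = 1.0 / v_new, m_new / v_new

print("EP posterior: mean %.3f, var %.4f" % (rm_q / r_q, 1.0 / r_q))
```

A single pass of this loop already gives a reasonable estimate; the extra sweeps let each site be revisited in the context of the others. The evidence approximation $p(\mathcal{D}) \approx \int \prod_i \tilde{f}_i(\theta)\, \mathrm{d}\theta$ is omitted here to keep the sketch short.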
Notes
- No guarantee that iterations will converge
- Fixed-points guaranteed to exist when the approximations are in the exponential family
- So at least then fixed points exist, but what the iterations converge to (if they converge at all) is less clear
- Can also have multiple fixed-points
- However, if the iterations do converge, the resulting solution will be a stationary point of a particular energy function (although each iteration is not guaranteed to actually decrease this energy function)
- Does not make sense to apply EP to mixtures as the approximation tries to capture all of the modes of the posterior distribution
Moment matching / Assumed density filtering (ADF): a special case of EP
- Initializes all $\tilde{f}_i(\boldsymbol{\theta})$ to unity and performs a single pass through the approximating factors, updating each of them once
- Processes the data sequentially rather than iterating over the full batch
- Highly dependent on the order of the data points, since each $\tilde{f}_i(\boldsymbol{\theta})$ is only updated once
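In the clutter-problem sketch above, ADF corresponds to keeping the site initialization $\tilde{f}_n = 1$ and making a single sweep over the data (`range(1)` instead of `range(10)`), so that each site is updated exactly once and the result depends on the order in which the $x_n$ are visited.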