Tuesday, July 10, 2018

Approximations of the measurement update and model likelihood for nonlinear spatiotemporal Cox processes

Epilogue: These notes were part of a project to infer unobserved ("latent") states in a neural field model, as well as the parameters of that model, from spiking observations. That work has since been published. Ultimately, for speed, we ended up selecting the Laplace approximation for the measurement update, solved via Newton-Raphson.

[get PDF]

We are interested in approximations to the measurement update for a spatially-extended latent-variable point-process, where the latent variables are also spatial fields that undergo dynamics similar to chemical reaction-diffusion systems. 

The latent variables or fields are concentrations, activations, or some similar physical quantity. They are therefore constrained to be non-negative, and typically must also obey conservation laws. Additionally, the observed point-process intensity field must be non-negative.

Such problems arise in chemical reaction-diffusion systems, epidemiological models, and neural field models, where the measurement is a point process coupled indirectly to the latent spatiotemporal system.

Problem statement

We will use a multivariate Gaussian approximation to model the joint distribution of latent variables. In the continuous field case, this is a Gaussian process. In the numerical implementation, we project this process onto a finite basis to give a finite-dimensional multivariate Gaussian distribution. The filtering update therefore requires finding a multivariate Gaussian approximation to the non-conjugate update of a Poisson observation and a multivariate Gaussian prior.

Estimating this posterior involves an integral that becomes intractable in high dimensions. There are numerous approximation methods to handle this, including moment-matching, variational Bayes, expected log-likelihoods, and the Laplace approximation. Often, the choice of link function further constrains which methods are applicable, as efficient algorithms may only be available for some classes. The update must also be constrained to prevent negative activity in the estimated latent states.

Multivariate Gaussian model

Let A(x) denote a latent vector of activity, defined over spatial coordinates x ∈ Ω on some (likely bounded) domain Ω. We approximate the prior distribution over the latent activity vector A(x) in terms of its first two moments. We interpret these moments to reflect the mean and variance of a multivariate Gaussian distribution, for the purposes of both moment closure and the measurement update. This assumption is an approximation, as in general the tails of the Gaussian that extend to negative activation values are unphysical.

$$
\Pr(A(x)) \approx \operatorname{Gaussian}\left(\mu_A(x), \Sigma_A(x)\right)
\tag{1}
$$

A(x) is a field defined over a continuous spatial region, and so Eq. 1 denotes a Gaussian process. In practice, we project this continuous process onto a finite basis and work with N discrete activation variables corresponding to spatial regions, i.e. A={A1,..,AN}. In this numerical implementation, Eq. 1 represents a finite-dimensional multivariate Gaussian distribution.

Latent field definition and discretization

This activation field is mapped to point-process intensities via a link function λ0 = f(A). Furthermore, our observation model may include heterogeneity in terms of the density of agents or the background level of activity, and we therefore incorporate a spatially inhomogeneous gain γ(x) and bias β(x) that adjust the observed point-process intensity. The intensity as a function of spatial coordinates x is then:

$$
\lambda(x) = f[A(x)]\,\gamma(x) + \beta(x).
\tag{1}
$$

In practice, we project this continuous, infinite-dimensional process onto a finite set of discrete basis elements B={b1,..,bN}, where the expected firing rate for the nth basis element is:

$$
\lambda_n = \int_{x \in \Omega} b_n(x)\,\lambda(x)\,dx,
\tag{2}
$$

where ∫_{x∈Ω} denotes integration over the spatial domain Ω parameterized by x. If the variations in the activity, gain, and bias are small relative to the scale of the basis elements, we may approximate this integral as:

$$
\lambda_n \approx \lambda(x_{b_n})\,v_n,
\qquad
v_n = \int_{x \in \Omega} b_n(x)\,dx,
\tag{3}
$$

where xbn is the centre of mass of the basis element bn, and vn is the volume of said basis element. We consider an especially simple case where basis functions all have identical volume v, so we may write the regional intensity as

$$
\lambda_n \approx v\left[\gamma_n f(A_n) + \beta_n\right],
\tag{4}
$$

where v is the (uniform) volume of the basis elements, e.g. v = Δx²Δt for a process with two spatial dimensions and one time dimension, with a fixed region size. The spatially-varying intensity is then represented as a vector of per-region intensities, λ = {λ1,..,λN}. Since the volume parameter v is redundant with the gain parameter γ, in the derivations that follow we assume that the bias and gain parameters have been premultiplied by the volume.

We introduce the gain and bias parameters to decouple inhomogeneity in the measurements from the underlying, latent spatiotemporal process. For example, in the case of retinal waves, different regions have differing densities of retinal ganglion cells, which amounts to a spatially inhomogeneous gain. Additionally, the amount of spontaneously-active background activity varies. In order to build up mathematical solutions that are immediately useful for numerical implementation, we carry these bias and gain parameters through the derivations that follow.
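
As a quick numerical reference, a minimal sketch of Eq. 4 follows. The function name and array layout here are illustrative rather than taken from any published implementation; it simply evaluates the per-region intensity from a vector of activations, a link function, and per-region gain and bias.

    import numpy as np

    def regional_intensity(A, gain, bias, f=lambda a: a, volume=1.0):
        """Per-region intensity of Eq. 4: lambda_n = v * [gamma_n * f(A_n) + beta_n].

        A, gain, bias : length-N arrays (one entry per basis element)
        f             : link function mapping activation to baseline intensity
        volume        : common basis-element volume v, e.g. dx**2 * dt
        """
        A = np.asarray(A, dtype=float)
        return volume * (gain * f(A) + bias)

Setting volume=1.0 corresponds to the convention used below, in which the gain and bias are assumed to have been premultiplied by the volume.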

Count observations and the measurement posterior

Given the Gaussian prior, the posterior estimate of the latent activations A is given by Bayes' rule:

$$
\Pr(A \mid Y) = \frac{\Pr(A)}{\Pr(Y)}\,\Pr(Y \mid A)
\tag{5}
$$

We observe event counts over a finite number of spatial regions, and these observed counts are independent conditioned on the per-region intensity, so we can write:

$$
\Pr(A \mid Y) = \frac{\Pr(A)}{\Pr(Y)} \prod_{n \in 1..N} \Pr(y_n \mid A)
\tag{6}
$$

The dependence of the counts on the latent activation can be expanded to include the regional intensity λn as:

$$
\Pr(y_n \mid A_n) = \Pr(y_n \mid \lambda_n)\,\Pr(\lambda_n \mid A)
\tag{7}
$$

Since the dependence of λn on A is deterministic (and vice-versa), we can therefore write

$$
\Pr(A \mid Y) = \frac{\Pr(A)}{\Pr(Y)} \prod_{n \in 1..N} \Pr(y_n \mid \lambda_n)
\tag{8}
$$

We observe regional counts Y={y1,..,yN}, which are Poisson distributed, with the observation likelihood

$$
\Pr(y_n \mid \lambda_n) = \frac{\lambda_n^{y_n}}{y_n!}\,e^{-\lambda_n}.
\tag{9}
$$

The log-posterior

Consider the logarithmic form of the measurement update:

$$
\log\Pr(A \mid Y) = \log\Pr(A) - \log\Pr(Y) + \sum_{n \in 1..N}\log\Pr(y_n \mid \lambda_n)
\tag{10}
$$

The prior on A is approximated by a Gaussian distribution N(μA, ΣA), and so the log-prior on A is:

$$
\log\Pr(A) = -\tfrac{1}{2}\left[\log\left|2\pi\Sigma_A\right| + (A - \mu_A)^\top \Sigma_A^{-1} (A - \mu_A)\right].
\tag{11}
$$

The conditional log-likelihood is given by the Poisson observation model with regional intensity λn=γnf(An)+βn:

$$
\begin{aligned}
\log\Pr(Y \mid A) = \sum_{n \in 1..N}\log\Pr(y_n \mid \lambda_n)
&= \sum_{n \in 1..N}\left[y_n\log(\lambda_n) - \lambda_n\right] \\
&= \sum_{n \in 1..N}\left[y_n\log\left(\gamma_n f(A_n) + \beta_n\right) - \left(\gamma_n f(A_n) + \beta_n\right)\right].
\end{aligned}
\tag{2}
$$

The marginal log-likelihood of the count observations, log Pr(Y), cannot be computed except via an intractable integral. Approximating this integral will be a major challenge for computing the model likelihood, which we will address later. However, for fixed count observations Y, the Pr(Y) term is constant, and so:

$$
\log\Pr(A \mid Y) = -\tfrac{1}{2}\left[\log\left|2\pi\Sigma_A\right| + (A - \mu_A)^\top \Sigma_A^{-1} (A - \mu_A)\right] + \sum_{n \in 1..N}\left[y_n\log(\lambda_n) - \lambda_n\right] + \text{constant}
\tag{12}
$$
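
For concreteness, the log-posterior of Eq. 12 can be evaluated directly (up to terms that do not depend on A). The sketch below uses illustrative names and assumes, as above, that the gain and bias have been premultiplied by the volume.

    import numpy as np

    def log_posterior(A, mu_A, Sigma_A, y, gain, bias, f=lambda a: a):
        """Eq. 12, dropping additive terms that do not depend on A:
        Gaussian log-prior plus sum_n [ y_n log(lambda_n) - lambda_n ]."""
        A = np.asarray(A, dtype=float)
        d = A - mu_A
        log_prior = -0.5 * (d @ np.linalg.solve(Sigma_A, d))
        lam = gain * f(A) + bias
        return log_prior + np.sum(y * np.log(lam) - lam)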

Approximation methods

In this section, we explore various approaches to obtaining a Gaussian approximation to the posterior, Pr(A|Y) ≈ Q(A) = N(μ^A, Σ^A). We examine three approaches: the Laplace approximation, the variational Bayes approach, and moment-matching.

Laplace approximation

For the Laplace approximation, we find the mode of the posterior and interpret the curvature at this mode as the inverse of the covariance matrix. Since we are dealing with spatiotemporal processes driven by physical agents (e.g. molecules, neurons, humans), we constrain the posterior mode to be non-negative. This departs slightly from the traditional Laplace approximation, in which the posterior mode is an interior local maximum with zero gradient. For this reason, the interpretation of the curvature at the mode as the inverse of the covariance must be treated with caution.

$$
\hat\mu_A = \operatorname*{argmax}_A\left[\log\Pr(A \mid Y)\right]
\tag{13}
$$

This can be solved with gradient descent or the Newton-Raphson method, which requires the gradient and Hessian of the log-posterior with respect to A.

We introduce some abbreviations to simplify the notation in the derivations that follow. Denote log-probabilities log Pr(·) as L(·), and denote the first and second derivatives of the log measurement likelihood with respect to an individual activation variable A ∈ {A1,..,AN} as L′_{y|A} and L″_{y|A}, respectively.

Note that λn is synonymous with the locally adjusted rate γn f(An) + βn, so that its derivative is λ′n = γn f′(An). We also omit indexing by the basis function number n when unambiguous. With these abbreviations, the gradient and Hessian of the log-posterior in A are:

$$
\begin{aligned}
\nabla_A \mathcal{L}_{A|Y} &= (\mu_A - A)^\top \Sigma_A^{-1} + \nabla_A \mathcal{L}_{Y|A} \\
\nabla_A^2 \mathcal{L}_{A|Y} &= -\Sigma_A^{-1} + \nabla_A^2 \mathcal{L}_{Y|A},
\end{aligned}
\tag{14}
$$

where

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= \frac{\gamma}{\lambda}\,f'(A)\,(y - \lambda) \\
\mathcal{L}''_{y|A} &= \frac{\gamma}{\lambda}\left[(y - \lambda)\,f''(A) - \frac{y\gamma}{\lambda}\,f'(A)^{2}\right]
\end{aligned}
\tag{15}
$$

The identity link function

In the case that f(An)=An, these gradients simplify to:

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= (y - \lambda)\,\frac{\gamma}{\lambda} \\
\mathcal{L}''_{y|A} &= -y\left(\frac{\gamma}{\lambda}\right)^{2}
\end{aligned}
\tag{16}
$$
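
As a sketch of how this update might be implemented for the identity link, the following Newton-Raphson iteration uses the gradient and Hessian of Eqs. 14 and 16. The names are illustrative, the non-negativity constraint is handled by a crude clip at a small floor (an actual implementation might do something more careful), and the gain and bias are assumed premultiplied by the volume.

    import numpy as np

    def laplace_update_identity(mu_A, Sigma_A, y, gain, bias,
                                n_iter=20, floor=1e-6):
        """Newton-Raphson search for the (clipped) posterior mode and the
        Laplace covariance, following Eqs. 13-16 for the identity link."""
        Pi_A = np.linalg.inv(Sigma_A)              # prior precision
        A = np.maximum(mu_A, floor)                # start from the (clipped) prior mean
        for _ in range(n_iter):
            lam = gain * A + bias                  # identity link: lambda = gamma*A + beta
            grad_like = (y - lam) * gain / lam     # Eq. 16, first derivative
            hess_like = -y * (gain / lam) ** 2     # Eq. 16, second derivative
            grad = Pi_A @ (mu_A - A) + grad_like   # Eq. 14, written as a column vector
            hess = -Pi_A + np.diag(hess_like)      # Eq. 14
            A = np.maximum(A - np.linalg.solve(hess, grad), floor)
        lam = gain * A + bias
        Sigma_post = np.linalg.inv(Pi_A + np.diag(y * (gain / lam) ** 2))
        return A, Sigma_post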

The exponential link function

In the case that f(A)=exp(A):

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= (y - \lambda)\left(\frac{\gamma e^{A}}{\lambda}\right) \\
\mathcal{L}''_{y|A} &= \left(\frac{\gamma e^{A}}{\lambda}\right)\left[y\left(1 - \frac{\gamma e^{A}}{\lambda}\right) - \lambda\right]
\end{aligned}
\tag{17}
$$

For more flexibility, one might add another gain parameter inside the exponentiation, i.e. f(An)=exp(δAn), which gives:

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= (y - \lambda)\left(\frac{\gamma\delta\,e^{\delta A}}{\lambda}\right) \\
\mathcal{L}''_{y|A} &= \gamma\delta^{2}\left[-\frac{y\gamma}{\lambda^{2}}e^{2\delta A} + \left(\frac{y}{\lambda} - 1\right)e^{\delta A}\right]
= \delta(y - \lambda)\left(\frac{\gamma\delta\,e^{\delta A}}{\lambda}\right) - y\left(\frac{\gamma\delta\,e^{\delta A}}{\lambda}\right)^{2}
\end{aligned}
\tag{18}
$$

The quadratic link function

Let λ = A²γ + β. That is, f(A) = A², f′(A) = 2A, and f″(A) = 2.

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= \frac{\gamma}{\lambda}\,2A\,(y - \lambda) \\
\mathcal{L}''_{y|A} &= \frac{\gamma}{\lambda}\left[2(y - \lambda) - \frac{y\gamma}{\lambda}\,4A^{2}\right]
\end{aligned}
\tag{19}
$$

A more flexible parameterization, λ = γ(A+b)² + β, gives:

$$
\begin{aligned}
\mathcal{L}'_{y|A} &= \frac{\gamma}{\lambda}\,2(A + b)\,(y - \lambda) \\
\mathcal{L}''_{y|A} &= \frac{\gamma}{\lambda}\left[2(y - \lambda) - \frac{y\gamma}{\lambda}\,4(A + b)^{2}\right]
\end{aligned}
\tag{20}
$$

Logistic link 1

We might want to consider the logistic link function, which maps the range (−∞,∞) in activation A to (0,1), which then may be further adjusted to span a given range using the gain/bias parameters:

$$
\begin{aligned}
f &= \frac{1}{1 + e^{-\delta A}} \\
f' &= \frac{\delta e^{-\delta A}}{\left(1 + e^{-\delta A}\right)^{2}} = \delta\,[1 - f(A)]\,f(A) \\
f'' &= \delta\,[1 - 2f]\,f'.
\end{aligned}
\tag{21}
$$

Logistic link 2

If the activation is bounded on [0,∞), it might make more sense to apply the logistic function to the log-activation, yielding the following link function:

$$
\begin{aligned}
f &= \frac{A}{\epsilon + A} \\
f' &= \frac{\epsilon}{(\epsilon + A)^{2}} \\
f'' &= \frac{-2\epsilon}{(\epsilon + A)^{3}}.
\end{aligned}
\tag{22}
$$

where ϵn is an additional free parameter that acts like an inverse gain.
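
For later reference, the link functions discussed above and their first two derivatives can be collected in one place. This is a sketch with illustrative names; each function returns the triple (f(A), f′(A), f″(A)) used in Eqs. 15-22.

    import numpy as np

    def link_identity(A):
        A = np.asarray(A, dtype=float)
        return A, np.ones_like(A), np.zeros_like(A)

    def link_exponential(A, delta=1.0):
        e = np.exp(delta * np.asarray(A, dtype=float))
        return e, delta * e, delta ** 2 * e

    def link_quadratic(A, b=0.0):
        A = np.asarray(A, dtype=float)
        return (A + b) ** 2, 2.0 * (A + b), 2.0 * np.ones_like(A)

    def link_logistic(A, delta=1.0):                      # Eq. 21
        f = 1.0 / (1.0 + np.exp(-delta * np.asarray(A, dtype=float)))
        fp = delta * (1.0 - f) * f
        return f, fp, delta * (1.0 - 2.0 * f) * fp

    def link_saturating(A, eps=1.0):                      # Eq. 22
        A = np.asarray(A, dtype=float)
        return A / (eps + A), eps / (eps + A) ** 2, -2.0 * eps / (eps + A) ** 3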

Variational approximation

In the variational approximation, we find a Gaussian distribution Q(A) = N(μ^A, Σ^A) that approximates the true posterior by minimizing the KL divergence of the approximating distribution Q from the true posterior. This is conceptually equivalent to jointly maximizing the entropy of Q while also maximizing the expected log-probability of the true posterior under Q:

$$
\begin{aligned}
\operatorname*{argmin}_{\mu_Q,\Sigma_Q} D_{\mathrm{KL}}(Q \| P)
&= \operatorname*{argmin}_{\mu_Q,\Sigma_Q} \int_A Q(A)\log\frac{Q(A)}{\Pr(A \mid Y)}\,dA \\
&= \operatorname*{argmax}_{\mu_Q,\Sigma_Q}\left[H(Q) + \left\langle\log\Pr(A \mid Y)\right\rangle_Q\right]
\end{aligned}
\tag{3}
$$

To obtain a tractable form of the above, first expand logPr(AY) using the logarithmic form of Bayes' rule:

$$
\log\Pr(A \mid Y) = \log\Pr(Y \mid A) + \log\Pr(A) - \log\Pr(Y)
\tag{23}
$$

This gives convenient simplifications, as the prior log Pr(A) is often Gaussian and has closed-form solutions, and the marginal data likelihood log Pr(Y) is constant and can be dropped from the optimization. Expanding ⟨log Pr(A|Y)⟩_Q in Eq. 3 gives:

$$
\operatorname*{argmax}_{\mu_Q,\Sigma_Q}\left[H(Q) + \left\langle\log\Pr(Y \mid A)\right\rangle_Q + \left\langle\log\Pr(A)\right\rangle_Q - \left\langle\log\Pr(Y)\right\rangle_Q\right]
\tag{24}
$$

Dropping the constant ⟨log Pr(Y)⟩_Q term and recognizing that the entropy and prior terms combine into the negative KL divergence of the approximating posterior from the prior, −DKL(Q‖Pr(A)), we get the following optimization problem:

$$
\operatorname*{argmax}_{\mu_Q,\Sigma_Q}\left[\left\langle\log\Pr(Y \mid A)\right\rangle_Q - D_{\mathrm{KL}}(Q \| \Pr(A))\right].
\tag{25}
$$

The objective function for variational Bayes amounts to maximizing the expected data log-likelihood ⟨log Pr(Y|A)⟩_Q under the approximation Q, while also minimizing the KL divergence of the approximating posterior from the prior. It can therefore be interpreted as a regularized maximum-likelihood approach. This form also connects to the objective functions often seen in variational autoencoders and to the variational free energy.

In this application both the prior and approximating posterior are Gaussian, and the KL divergence term DKL(QPr(A)) has a closed-form solution reflecting the KL divergence between two multivariate Gaussian distributions:

$$
D_{\mathrm{KL}}(Q \| \Pr(A)) = \tfrac{1}{2}\left[\log\frac{|\Sigma_A|}{|\hat\Sigma_A|} - D + \operatorname{tr}\left[\Sigma_A^{-1}\hat\Sigma_A\right] + (\mu_A - \hat\mu_A)^\top \Sigma_A^{-1} (\mu_A - \hat\mu_A)\right],
\tag{26}
$$

where D is the dimensionality of the multivariate Gaussian. It remains then to calculate the expected log-likelihood, logPr(YA)Q. As discussed in the next section, this integral is not always tractable.
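
The closed-form KL divergence of Eq. 26 is straightforward to evaluate numerically. A small sketch follows, with illustrative names (mu_p, Sigma_p for the prior, mu_q, Sigma_q for the approximating posterior):

    import numpy as np

    def kl_gaussian(mu_q, Sigma_q, mu_p, Sigma_p):
        """D_KL( N(mu_q, Sigma_q) || N(mu_p, Sigma_p) ), as in Eq. 26."""
        D = len(mu_q)
        d = mu_p - mu_q
        _, logdet_p = np.linalg.slogdet(Sigma_p)
        _, logdet_q = np.linalg.slogdet(Sigma_q)
        return 0.5 * (logdet_p - logdet_q - D
                      + np.trace(np.linalg.solve(Sigma_p, Sigma_q))
                      + d @ np.linalg.solve(Sigma_p, d))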

Challenges for the variational approximation in this application

The variational approximation integrates over the domain for A, which is truncated to [0,) since negative values for A are unphysical. Typically, this means that efficient algorithms are challenging to derive, as closed-form solutions for the relevant integrals do not exist, or at best involve the multivariate Gaussian cumulative distribution function, its inverses, and derivatives, which are numerically expensive to compute. 

One may relax the constraint that A be non-negative, extending the domain of integration to (,), but then one must constrain optimization to return only positive means for the variational posterior. However, unless a rectifying (e.g. exponential, quadratic) link function is used, the inclusion of negative rates in the domain will make the Poisson observation likelihood undefined. For this reason, the variational update has been explored only for the exponential link function [[PARK]]. Because small changes in activation can lead to large fluctuations in rate owing to the amplification of the exponential link function, we have found that the exponential link is numerically unstable. 

(An implementation of variational optimization using the rectifying quadratic link function may be more numerically stable, and remains to be explored.)

Moment-matching

Moment matching calculates or approximates the mean and covariance of the true posterior, and uses these moments to form a multivariate Gaussian approximation. When applied as a message-passing algorithm in a graphical model, moment matching is an important step of the expectation-propagation algorithm. Moment-matching can be performed explicitly by integrating the posterior moments, but in high dimensions there is no computationally tractable way to evaluate such integrals. Since spatial correlations are essential in spatiotemporal phenomena, we cannot discard this higher dimensional structure. 

Another approach to moment-matching to note is that the Gaussian distribution Q that minimizes the KL divergence from Q to the true posterior will also match the moments of the posterior. We can therefore perform moment matching by minimizing this KL divergence:

$$
\operatorname*{argmin}_{\mu_Q,\Sigma_Q} \int_A \Pr(A \mid Y)\log\frac{\Pr(A \mid Y)}{Q(A)}\,dA
= \operatorname*{argmax}_{\mu_Q,\Sigma_Q}\left[H\!\left[\Pr(A \mid Y)\right] + \left\langle\log Q\right\rangle_{\Pr(A \mid Y)}\right]
\tag{27}
$$

Note that the first term is the entropy of the (true) posterior distribution. It is constant for a given update, and therefore does not affect our optimization. We can focus on the second term, and optimize:

$$
\operatorname*{argmax}_{\mu_Q,\Sigma_Q}\left\langle\log Q\right\rangle_{\Pr(A \mid Y)}
\tag{28}
$$

The log-probability of a Gaussian approximation Q with mean μ^A and covariance matrix Σ^A is:

$$
\log Q(A) = -\tfrac{1}{2}\left[\log\left|2\pi\hat\Sigma_A\right| + (A - \hat\mu_A)^\top \hat\Sigma_A^{-1} (A - \hat\mu_A)\right]
\tag{29}
$$

We cannot calculate the true posterior Pr(A|Y), and so the integral ⟨log Q⟩_{Pr(A|Y)} cannot be computed directly. However, the normalization constant, although unknown, is constant with respect to this optimization, and it suffices to take the weighted expectation with respect to an un-normalized form of Pr(A|Y):

$$
\Pr(A \mid Y) \propto \left|2\pi\Sigma_A\right|^{-\frac{1}{2}} e^{-\frac{1}{2}(A - \mu_A)^\top \Sigma_A^{-1} (A - \mu_A)} \prod_{n \in 1..N}\left[\left(\gamma_n f(A_n) + \beta_n\right)^{y_n} e^{-\left(\gamma_n f(A_n) + \beta_n\right)}\right]
\tag{30}
$$

This integral, however, remains essentially intractable, due to the product of Gaussian and Poisson-like terms. In high dimensions there is no (to our knowledge) computationally efficient way to estimate this integral or its derivatives.

Expected log-likelihoods and variational Bayes

So far, we have explored three approaches to finding a Gaussian approximation to the measurement posterior: the Laplace approximation, variational Bayes, and moment matching. Moment matching is unsuitable because, to the best of our knowledge, there is no computationally tractable way to estimate the relevant moments in high-dimensions. The Laplace approximation and variational Bayes remain computationally tractable, with limitations.

The Laplace approximation suffers from errors arising from the non-negativity constraint on activity levels, and the high skewness of the distributions causes the mode to be far from the mean. Errors in estimating the covariance are especially severe, as the covariance controls the trade-off between propagating past information, and incorporating new measurements, during the filtering procedure.

Variational Bayes also has a number of challenges. First, an efficient way of evaluating the relevant integrals and their derivatives is needed. In practice, this is simplest when we can tolerate the approximation of integrating over the full domain of the prior, including the unphysical negative activations. We must also use a rectifying link function, because if the predicted point-process intensity is negative, then the Poisson likelihood for our observations is not defined. 

To our knowledge, the only rectifying link function that has been explored to-date is the exponential link function, which suffers from unacceptable numerical instability in our application. There are a few other link functions that we might explore, but we will address another approximate solution in this section based on Laplace-approximated expected log-likelihoods.

In order to minimize DKL in the variational Bayes approach, we must maximize the expected data log-likelihood under the approximating distribution Q, while simultaneously minimizing the KL divergence of this posterior approximation from the prior. Provided we interpret the multivariate Gaussian prior for A as having support over (−∞,∞)^D, the KL divergence term has a closed-form solution and well-defined derivatives. The challenge, then, is to calculate the expected log-likelihood term:

$$
\left\langle\log\Pr(Y \mid A)\right\rangle_Q
\tag{31}
$$

We now derive a general second-order approximation for the expected log-likelihood for a Gaussian latent-variable process with Poisson observations, where the latent variable may be linked to the Poisson intensity via an arbitrary link function. (See Zhou and Park, plus the expected log-likelihood papers, for more detail.)

In the case of an intractable ⟨log Pr(Y|A)⟩_Q, we approximate this integral via a Laplace-style (second-order) expansion. This yields an approximate variational inference method that is similar, but not identical, to the Laplace approximation.

Second-order approximations to the expected log-likelihood

To briefly review the notation, let A = {A1,..,AN} be a multivariate Gaussian latent variable reflecting our prior estimate for the distribution of latent activation, with mean μA and covariance ΣA. Let λn0 = f(An) be a link function mapping the latent activity to a baseline intensity λn0, which then might be further scaled and shifted due to spatially inhomogeneous gain γn or background activity βn.

We need to compute expected log-likelihoods under the approximating posterior distribution Q, that is:

$$
\left\langle\mathcal{L}(y_n \mid A_n)\right\rangle = \left\langle y_n\log(\lambda_n) - \lambda_n\right\rangle
\tag{32}
$$

where ⟨·⟩ denotes averaging over the posterior distribution Q(A) with mean μ^A and covariance Σ^A.

In certain cases, the expectations log(λn) and λn may have closed-form solutions, for example in the log-Gaussian instance (cite Rule, Zhou). Here, however, we explore a general approach based on second-order Taylor expansions, which is accurate for small variances. If An is normally distributed with mean μAn and variance σAn2, then out to second order:

$$
\left\langle\mathcal{L}(y_n \mid A_n)\right\rangle \approx \mathcal{L}(y_n \mid \hat\mu_{A_n}) + \frac{\hat\sigma_{A_n}^{2}}{2}\,\mathcal{L}''(y_n \mid \hat\mu_{A_n})
\tag{33}
$$

Gradient-based methods for optimizing the expected log-likelihood require derivatives of these approximated expectations. In general, derivatives with respect to the mean are:

$$
\frac{d^{m}}{d\hat\mu_{A_n}^{m}}\left\langle\mathcal{L}(y_n \mid A_n)\right\rangle \approx \mathcal{L}^{(m)}(y_n \mid \hat\mu_{A_n}) + \frac{\hat\sigma_{A_n}^{2}}{2}\,\mathcal{L}^{(m+2)}(y_n \mid \hat\mu_{A_n})
\tag{34}
$$

A Newton-Raphson solver for optimizing the mean μ^A requires the Hessian of the objective function. Since the approximated expectations include second derivatives, the Hessian involves derivatives out to fourth order. The chain rule for higher-order derivatives of the logarithm is too cumbersome to state for the general case. Instead, we derive the equations for three versions of f(A): A, A², and e^A.

Optimizing the likelihood may also involve optimizing the variance σ2 or in general, the covariance. We will address this in later sections.
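
As a sketch of how Eqs. 33-34 might be organized in code, the helper below evaluates the second-order expected log-likelihood and its first two derivatives with respect to the posterior mean. The names are illustrative; L_derivs(m, y, mu) is assumed to return the m-th derivative of the per-region log-likelihood evaluated at the mean, as derived for specific link functions in the subsections that follow.

    import numpy as np

    def expected_loglik_2nd_order(mu, var, y, L_derivs):
        """Second-order expected log-likelihood (Eq. 33) and its derivatives (Eq. 34).

        mu, var, y : per-region posterior means, variances, and observed counts
        L_derivs   : callable (m, y, mu) -> m-th derivative of L(y|A) at A = mu
        Returns the summed value plus the per-region gradient and curvature in mu."""
        value = L_derivs(0, y, mu) + 0.5 * var * L_derivs(2, y, mu)
        grad  = L_derivs(1, y, mu) + 0.5 * var * L_derivs(3, y, mu)
        curv  = L_derivs(2, y, mu) + 0.5 * var * L_derivs(4, y, mu)
        return np.sum(value), grad, curv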

For the case that f(A)=A

Interpreting the distribution of latent activations A as a multivariate Gaussian over (−∞,∞)^D allows closed-form estimation of the DKL contribution to the variational Bayes objective function. However, unless point-process intensities are artificially constrained to the domain [0,∞), the expected log-likelihood is undefined for the identity link function. This is because the Poisson measurement likelihood is not defined for negative intensities.

In this second-order approximation, we circumvent this issue by considering a locally-quadratic approximation of the likelihood function that continues the Poisson likelihood to negative intensities. Provided variance is small, and the posterior mean is constrained to be positive, this approximation may provide an accurate estimate of the expected log-likelihood.

If f(A) = A, then λ = vγ(A + β/γ). We compute derivatives out to fourth order:

$$
\begin{aligned}
\mathcal{L}(y \mid A) &= y\log[\lambda] - \lambda = y\left[\log(v\gamma) + \log(\lambda/v\gamma)\right] - v\gamma(A + \beta/\gamma) \\
&= y\log(\lambda/v\gamma) - v\gamma A + \text{constant} \\
\mathcal{L}'(y \mid A) &= y\,(\lambda/v\gamma)^{-1} - v\gamma \\
\mathcal{L}''(y \mid A) &= -y\,(\lambda/v\gamma)^{-2} \\
\mathcal{L}^{(3)}(y \mid A) &= 2y\,(\lambda/v\gamma)^{-3} \\
\mathcal{L}^{(4)}(y \mid A) &= -6y\,(\lambda/v\gamma)^{-4}
\end{aligned}
\tag{35}
$$

As λ → 0, the fourth derivative of the likelihood tends rapidly to infinity, which may create issues for numerical stability and accuracy. This behavior near λ = 0 is similar to the issues that plague the Laplace approximation. In particular, the distribution may become highly skewed, which means that third or higher moments may be needed, and the second-order approximation may not be accurate. However, I have reason to suspect that the issues might be less severe for the expected log-likelihood compared to the Laplace approximation. In this case, we are using a quadratic expansion about an estimate of the posterior mean, whereas the Laplace approximation seeks the posterior mode. I expect that this will have a stabilizing effect, encouraging λ toward more positive values.
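
A sketch of Eq. 35 in the derivative-callable form used by the helper above, with the volume absorbed into the gain and bias (names illustrative):

    import math
    import numpy as np

    def L_derivs_identity(m, y, mu, gain=1.0, bias=0.0):
        """m-th derivative of the Poisson log-likelihood for f(A) = A (Eq. 35),
        with the volume premultiplied into gain and bias."""
        lam = gain * mu + bias
        if m == 0:
            return y * np.log(lam) - lam
        if m == 1:
            return y * gain / lam - gain
        # For m >= 2, only the y*log(lam) term contributes:
        # d^m/dmu^m [ y log(gain*mu + bias) ] = (-1)^(m-1) (m-1)! y (gain/lam)^m
        return (-1.0) ** (m - 1) * math.factorial(m - 1) * y * (gain / lam) ** m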

For the case that f(A) = e^A

A closed-form solution for the expected log-likelihood exists under this link function, and closed-form expressions for the moments of log-Gaussian random variables are known (see Zhou and Park for application to log-Gaussian point processes). However, in this application the amplification of positive tails of the distribution by the exponential link function is numerically unstable and unphysical, indicating that the log-Gaussian model is inappropriate. In Rule et al. 2018, we noted that a second-order approximation to the expected likelihood was more stable and more accurate. 

$$
\begin{aligned}
\mathcal{L}(y \mid A) &= y\log[\lambda] - \lambda = y\log\left[v\gamma\left(e^{A} + \beta/\gamma\right)\right] - v\gamma\left(e^{A} + \beta/\gamma\right) \\
&= y\log\left[e^{A} + \beta/\gamma\right] - v\gamma e^{A} + \text{constant} \\
\mathcal{L}^{(n)}(y \mid A) &= y\,C^{(n-1)} - v\gamma e^{A},
\qquad C = \frac{e^{A}}{e^{A} + \beta/\gamma} = \frac{1}{1 + (\beta/\gamma)\,e^{-A}} \\
C' &= C(1 - C) \\
C'' &= C'(1 - 2C) \\
C^{(3)} &= C' - 6C'^{2}
\end{aligned}
\tag{36}
$$

For the case that f(A) = A²

A quadratic link function is rectifying, so the Poisson likelihood remains well-defined even if the domain of A is extended to (−∞,∞)^D. However, to my knowledge there is no tractable closed form for the expectation of the logarithm of a generalized noncentral χ²-distributed variable, and so the second-order approximation remains useful:

$$
\begin{aligned}
\mathcal{L}(y \mid A) &= y\log[\lambda] - \lambda = y\log\left[v\gamma\left(A^{2} + \beta/\gamma\right)\right] - v\gamma\left(A^{2} + \beta/\gamma\right) \\
&= y\log\left[A^{2} + \beta/\gamma\right] - v\gamma A^{2} + \text{constant} \\
\mathcal{L}'(y \mid A) &= y\,C - 2v\gamma A,
&\quad C &= \frac{2A}{A^{2} + \beta/\gamma} \\
\mathcal{L}''(y \mid A) &= y\,C' - 2v\gamma,
&\quad C' &= C\left(\frac{1}{A} - C\right) \\
\mathcal{L}^{(3)}(y \mid A) &= y\,C'',
&\quad C'' &= C'\left(\frac{1}{A} - 2C\right) - \frac{C}{A^{2}} \\
\mathcal{L}^{(4)}(y \mid A) &= y\,C^{(3)},
&\quad C^{(3)} &= 3\left(\frac{C}{A}\right)^{2} - 6C'^{2}
\end{aligned}
\tag{37}
$$

Incorporating the prior

So far we have focused on the expected log-likelihood contribution to the variational posterior objective function. We also need to derive the gradients of the KL-divergence term. For the most part, this is identical to the derivation in Zhou and Park, so I present only some quick notes here. We are interested in the gradients (and Hessians) of the KL divergence between two k-dimensional multivariate Gaussians, which is:

$$
D_{\mathrm{KL}}(\mathcal{N}_0 \| \mathcal{N}_1) = \tfrac{1}{2}\left[\operatorname{tr}(\Pi_1\Sigma_0) + (\mu_1 - \mu_0)^\top \Pi_1 (\mu_1 - \mu_0) - \ln|\Pi_1| - \ln|\Sigma_0| - k\right]
\tag{38}
$$

Note that I have chosen to denote this in terms of the precision matrix Π1 = Σ1⁻¹, as it makes some of the derivations below more straightforward.

The derivative with respect to μ0 is the negative of the derivative with respect to μ1, which is:

$$
\nabla_{\mu_1} D_{\mathrm{KL}}(\mathcal{N}_0 \| \mathcal{N}_1) = -\nabla_{\mu_0} D_{\mathrm{KL}}(\mathcal{N}_0 \| \mathcal{N}_1) = (\mu_1 - \mu_0)^\top \Pi_1
\tag{39}
$$

The derivative with respect to Σ0 is:

$$
\nabla_{\Sigma_0} D_{\mathrm{KL}}(\mathcal{N}_0 \| \mathcal{N}_1) = \tfrac{1}{2}\left[\Pi_1 - \Pi_0\right]
\tag{40}
$$

The derivative with respect to Π1 is:

$$
\nabla_{\Pi_1} D_{\mathrm{KL}}(\mathcal{N}_0 \| \mathcal{N}_1) = \tfrac{1}{2}\left[\Sigma_0 - \Sigma_1 + (\mu_1 - \mu_0)(\mu_1 - \mu_0)^\top\right]
\tag{41}
$$

For the variational interpretation we approximate the posterior P with the approximation Q and minimize DKL(Q‖P); in the notation above, Q plays the role of N0 and the Gaussian prior plays the role of N1, so optimizing Q involves the derivatives with respect to μ0 and Σ0 above.
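
For completeness, Eqs. 39-41 in code form (a sketch with illustrative names; N0 = (mu0, Sigma0) is the approximating posterior Q, and N1 = (mu1, Sigma1) is the Gaussian prior):

    import numpy as np

    def kl_gradients(mu0, Sigma0, mu1, Sigma1):
        """Gradients of D_KL(N0 || N1): Eq. 39 (means), Eq. 40 (Sigma0), Eq. 41 (Pi1)."""
        Pi0, Pi1 = np.linalg.inv(Sigma0), np.linalg.inv(Sigma1)
        d = mu1 - mu0
        grad_mu1 = Pi1 @ d                                    # Eq. 39; grad_mu0 = -grad_mu1
        grad_Sigma0 = 0.5 * (Pi1 - Pi0)                       # Eq. 40
        grad_Pi1 = 0.5 * (Sigma0 - Sigma1 + np.outer(d, d))   # Eq. 41
        return grad_mu1, grad_Sigma0, grad_Pi1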

Optimizing the (approximate) variational approximation

We have derived approximations for the expected log-likelihood contribution to the variational Bayesian objective function, which must be optimized jointly over the posterior mean μ^A and the posterior covariance Σ^A. The above derivations provide gradients and Hessians for optimizing μ^A for a fixed Σ^A. Zhou and Park explore the joint optimization of the (exact) objective function for a log-Gaussian variational approximation, and prove that Σ^A can be optimized using a fixed-point iteration.

(In numerical experiments, I extended this approach by interleaving one step of the Newton-Raphson optimization for μ^A with one step of the fixed-point update for Σ^A. In my experience this accelerated convergence. Does the fixed-point iteration for Σ^A converge for the second-order approximated expected log-likelihood? Could a similar approach be found for the other (non-exponential) link functions explored here?)

The covariance update

To complete the variational approximation, we also need to optimize the posterior covariance. This involves the derivative of the expected log-likelihood with respect to Σ0. In the second-order (Laplace-approximation-like) expected log-likelihood, the dependence on the covariance enters through the second-order terms, which are:

$$
\tfrac{1}{2}\operatorname{diag}(\Sigma_0)^\top\,\mathcal{L}''(y \mid \mu_0)
\tag{42}
$$

The derivative of the above with respect to Σ0 is:

$$
\tfrac{1}{2}\,\mathcal{L}''(y \mid \mu_0)
\tag{43}
$$

along the diagonal, and 0 elsewhere. The total gradient for Σ0, incorporating both the DKL and expected log-likelihood contributions, is: 

$$
\nabla_{\Sigma_0}\,\mathcal{L}_{A|Y} = \tfrac{1}{2}\left[\Pi_0 - \Pi_1 + \operatorname{diag}\left[\mathcal{L}''(y \mid \mu_0)\right]\right]
\tag{44}
$$

Setting this gradient to zero and solving for the posterior precision Π0 gives:

$$
\Pi_0 = \Pi_1 - \operatorname{diag}\left[\mathcal{L}''(y \mid \mu)\right]
\tag{45}
$$

This amounts to adding the negative curvature of the log-likelihood, −L″(y|μ) (evaluated at the current estimate of the posterior mean), to the prior precision matrix Π1. This is similar to the Hessian used in the Laplace update, with the exception that we use the curvature at the estimated posterior mean, rather than the posterior mode.
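
A sketch of this covariance update (Eq. 45). The second argument is the vector of per-region second derivatives L″(yn|μn) evaluated at the current estimate of the posterior mean (for example, from an L_derivs helper like the one sketched earlier); the names are illustrative.

    import numpy as np

    def covariance_fixed_point(Sigma_prior, Lpp_at_mean):
        """Posterior covariance from Eq. 45: add the observed information
        -L''(y|mu), which is non-negative for the concave log-likelihoods
        considered here, to the prior precision."""
        Pi_prior = np.linalg.inv(Sigma_prior)
        Pi_post = Pi_prior - np.diag(Lpp_at_mean)
        return np.linalg.inv(Pi_post)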

Computing the model likelihood

Once the Gaussian approximation is computed, how should we estimate the likelihood of the data given the model parameters θ, Pr(Y|θ) (or just Pr(Y) for short)? There are numerous approximation methods available, and it remains unclear to me which is best.

Integration via Laplace approximation

This likelihood is the integral of the prior Pr(A) times the measurement likelihood Pr(YA), and is also the normalization constant for the posterior distribution: 

$$
\Pr(Y) = \int_A \Pr(Y \mid A)\,\Pr(A)\,dA = \left\langle\Pr(Y \mid A)\right\rangle_{\Pr(A)}
\tag{46}
$$

Given a Gaussian approximation Q for the posterior Pr(A|Y), we can approximate this integral using the posterior mean μ^A and covariance Σ^A. This Gaussian approximation can be obtained from any of the approximation methods mentioned previously: the Laplace approximation, or exact or approximate variational inference.

Under this approximation, we evaluate Pr(Y|A)Pr(A) at μ^A, and also compute the curvature at this point (which should be consistent with Σ^A if everything has gone as planned!).

$$
\log\Pr(A \mid Y) \approx -\tfrac{1}{2}\left[\log\left|2\pi\hat\Sigma_A\right| + (A - \hat\mu_A)^\top \hat\Sigma_A^{-1} (A - \hat\mu_A)\right]
\tag{47}
$$

We use the following logarithmic relationship derived from Bayes' rule

$$
\log\Pr(Y) = \log\left[\Pr(A)\Pr(Y \mid A)\right] - \log\Pr(A \mid Y)
\tag{48}
$$

Substituting the forms above gives:

$$
\begin{aligned}
\log\Pr(Y) \approx
&-\tfrac{1}{2}\left[\log\left|2\pi\Sigma_A\right| + (A - \mu_A)^\top \Sigma_A^{-1} (A - \mu_A)\right] \\
&+ \sum_{n \in 1..N}\left[y_n\log\left(\gamma_n f(A_n) + \beta_n\right) - \left(\gamma_n f(A_n) + \beta_n\right)\right] \\
&+ \tfrac{1}{2}\left[\log\left|2\pi\hat\Sigma_A\right| + (A - \hat\mu_A)^\top \hat\Sigma_A^{-1} (A - \hat\mu_A)\right]
\end{aligned}
\tag{49}
$$

Evaluating the above at the posterior mean μ^A amounts to a Laplace approximation of the integral for the likelihood.

$$
\begin{aligned}
\log\Pr(Y) &\approx \log\left[\Pr(\hat\mu_A)\Pr(Y \mid \hat\mu_A)\right] - \log\Pr(\hat\mu_A \mid Y) \\
&= -\tfrac{1}{2}\left[\log\left|2\pi\Sigma_A\right| + (\hat\mu_A - \mu_A)^\top \Sigma_A^{-1} (\hat\mu_A - \mu_A)\right] \\
&\quad + \sum_{n \in 1..N}\left[y_n\log\left(\gamma_n f(\hat\mu_{A_n}) + \beta_n\right) - \left(\gamma_n f(\hat\mu_{A_n}) + \beta_n\right)\right] \\
&\quad + \tfrac{1}{2}\left[\log\left|2\pi\hat\Sigma_A\right| + (\hat\mu_A - \hat\mu_A)^\top \hat\Sigma_A^{-1} (\hat\mu_A - \hat\mu_A)\right]
\end{aligned}
\tag{50}
$$

Cleaning things up, and denoting λ^n=γnf(μ^An)+βn, we get:

$$
\log\Pr(Y) \approx -\tfrac{1}{2}\left[\log\frac{|\Sigma_A|}{|\hat\Sigma_A|} + (\hat\mu_A - \mu_A)^\top \Sigma_A^{-1} (\hat\mu_A - \mu_A)\right] + \sum_{n \in 1..N}\left[y_n\log(\hat\lambda_n) - \hat\lambda_n\right]
\tag{51}
$$
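
A sketch of Eq. 51, given any Gaussian approximation (mu_post, Sigma_post) of the posterior. As throughout, the gain and bias are assumed premultiplied by the volume, the constant involving the factorials of the counts is dropped, and the names are illustrative.

    import numpy as np

    def log_evidence_laplace(mu_prior, Sigma_prior, mu_post, Sigma_post,
                             y, gain, bias, f=lambda a: a):
        """Laplace-style estimate of the model log-likelihood log Pr(Y), Eq. 51."""
        d = mu_post - mu_prior
        _, logdet_prior = np.linalg.slogdet(Sigma_prior)
        _, logdet_post = np.linalg.slogdet(Sigma_post)
        lam_hat = gain * f(np.asarray(mu_post, dtype=float)) + bias
        quad = d @ np.linalg.solve(Sigma_prior, d)
        return (-0.5 * (logdet_prior - logdet_post + quad)
                + np.sum(y * np.log(lam_hat) - lam_hat))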

On the similarity between the Laplace-approximated likelihood and DKL

Note the similarity to KL divergence

$$
D_{\mathrm{KL}}(Q \| \Pr(A)) = \tfrac{1}{2}\left[\log\frac{|\Sigma_A|}{|\hat\Sigma_A|} - D + \operatorname{tr}\left[\Sigma_A^{-1}\hat\Sigma_A\right] + (\mu_A - \hat\mu_A)^\top \Sigma_A^{-1} (\mu_A - \hat\mu_A)\right]
\tag{52}
$$

Which gives the relation

$$
\tfrac{1}{2}\left[\log\frac{|\Sigma_A|}{|\hat\Sigma_A|} + (\mu_A - \hat\mu_A)^\top \Sigma_A^{-1} (\mu_A - \hat\mu_A)\right] = -\tfrac{1}{2}\left[\operatorname{tr}\left[\Sigma_A^{-1}\hat\Sigma_A\right] - D\right] + D_{\mathrm{KL}}(Q \| \Pr(A))
\tag{53}
$$

Which allows the likelihood to be written as

$$
\log\Pr(Y) \approx \tfrac{1}{2}\left[\operatorname{tr}\left[\Sigma_A^{-1}\hat\Sigma_A\right] - D\right] - D_{\mathrm{KL}}(Q \| \Pr(A)) + \sum_{n \in 1..N}\left[y_n\log(\hat\lambda_n) - \hat\lambda_n\right]
\tag{54}
$$

We can play with this further, bringing in the second-order expected log-likelihood:

$$
\sum_{n \in 1..N}\mathcal{L}(y_n \mid \hat\mu_{A_n}) = \sum_{n \in 1..N}\left[y_n\log(\hat\lambda_n) - \hat\lambda_n\right]
\tag{55}
$$

Which, from the second-order approximation, is:

$$
\mathcal{L}(y_n \mid \hat\mu_{A_n}) \approx \left\langle\mathcal{L}(y_n \mid A_n)\right\rangle - \frac{\hat\sigma_{A_n}^{2}}{2}\,\mathcal{L}''(y_n \mid \hat\mu_{A_n})
\tag{56}
$$

So:

$$
\begin{aligned}
\log\Pr(Y) &\approx \tfrac{1}{2}\left[\operatorname{tr}\left[\Sigma_A^{-1}\hat\Sigma_A\right] - D\right] - D_{\mathrm{KL}}(Q \| \Pr(A)) + \left\langle\mathcal{L}(Y \mid A)\right\rangle_Q - \sum_{n \in 1..N}\frac{\hat\sigma_{A_n}^{2}}{2}\,\mathcal{L}''(y_n \mid \hat\mu_{A_n}) \\
&= \log\Pr(Y) - D_{\mathrm{KL}}(Q \| \Pr(A \mid Y)) + \tfrac{1}{2}\operatorname{tr}\left(\Sigma_A^{-1}\hat\Sigma_A\right) - \sum_{n \in 1..N}\frac{\hat\sigma_{A_n}^{2}}{2}\,\mathcal{L}''(y_n \mid \hat\mu_{A_n}) + \text{constant}
\end{aligned}
\tag{57}
$$

Which is to say that this point estimate of the log-likelihood is very similar to the variational objective (the expected log-likelihood minus the KL-divergence penalty), differing only by the trace term and the curvature correction (and the dimensionality constant D).

Empirically, these terms are small and do not dominate the likelihood. As we shall see in the following sections, this similarity is not a coincidence: the Laplace approximation is connected to the Evidence Lower Bound (ELBO) in the variational Bayesian approach, especially when a second-order approximation is used to evaluate the expected log-likelihood.

The expected log-likelihood

One can consider the expected value of the log-likelihood relative to the prior distribution Pr(A). Starting from the logarithmic form of Bayes' rule, we have:

$$
\log\Pr(Y) = \log\Pr(A) + \log\Pr(Y \mid A) - \log\Pr(A \mid Y)
\tag{4}
$$

For the true posterior Pr(AY), this equality holds for all A, as we could recover the log-likelihood by evaluating this expression at any point. In the Laplace approximation, we evaluated this quantity at the posterior mean. 

For the variational approach (outlined below), we consider a bound based on taking the expectation of Eq. 4 with respect to the approximating posterior distribution Q.

For the expected log-likelihood approach here, we take the expectation with respect to the prior distribution Pr(A):

$$
\left\langle\log P_Y\right\rangle_{P_A} = \left\langle\log P_A\right\rangle_{P_A} + \left\langle\log P_{Y|A}\right\rangle_{P_A} - \left\langle\log P_{A|Y}\right\rangle_{P_A} = \left\langle\log P_{Y|A}\right\rangle_{P_A} + D_{\mathrm{KL}}(P_A \| P_{A|Y})
\tag{5}
$$

We cannot compute the above exactly, because we do not have access to the true posterior P_{A|Y}, but we do have access to an approximating posterior Q ≈ P_{A|Y}, which we can use to approximate the expected log-likelihood. Note that the additional DKL term increases the expected log-likelihood, the opposite of its role in the variational approach.

$$
\left\langle\log P_Y\right\rangle_{P_A} \approx \left\langle\log P_{Y|A}\right\rangle_{P_A} + D_{\mathrm{KL}}(P_A \| Q)
\tag{6}
$$

To estimate the expected log-likelihood term ⟨log P_{Y|A}⟩_{P_A}, we may either use the second-order approach that we derived for variational Bayes, or simply evaluate log P_{Y|A} at the prior mean for a faster approximation.

ELBO and variational Bayes

When deriving the variational update, we omitted the (constant) data likelihood term. The full form for DKL(QP) is:

$$
\begin{aligned}
D_{\mathrm{KL}}(Q \| P) &= \int_A Q_A \log\frac{Q_A}{P_{A|Y}}\,dA = \left\langle\log Q\right\rangle_Q - \left\langle\log P_{A|Y}\right\rangle_Q \\
&= -H_Q - \left\langle\log P_{Y,A}\right\rangle_Q + \log P_Y \\
&= -\left\langle\log P_{Y|A}\right\rangle_Q + D_{\mathrm{KL}}(Q \| P_A) + \log P_Y,
\end{aligned}
\tag{58}
$$

which implies the following identity for the log-likelihood:

$$
\log P_Y = D_{\mathrm{KL}}(Q \| P) + \left\langle\log P_{Y|A}\right\rangle_Q - D_{\mathrm{KL}}(Q \| P_A) = D_{\mathrm{KL}}(Q \| P) + H_Q + \left\langle\log P_{Y,A}\right\rangle_Q.
\tag{59}
$$

That is: the data likelihood is the expected log-likelihood under the variational posterior, minus the KL divergence of the posterior from the prior, plus the KL divergence of the variational posterior from the true posterior. Since this last term is always non-negative, the following bound holds:

$$
\log P_Y \ge \left\langle\log P_{Y|A}\right\rangle_Q - D_{\mathrm{KL}}(Q \| P_A) = H_Q + \left\langle\log P_{Y,A}\right\rangle_Q
\tag{7}
$$

The right-hand side of this inequality is sometimes called the Evidence Lower Bound (ELBO). The more accurately we can approximate the posterior as a Gaussian, the smaller DKL(Q‖P) becomes and the tighter this bound becomes.
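
For comparison with the Laplace-style estimate above, the ELBO of Eq. 7 can be evaluated from the same ingredients. The sketch below uses a plug-in estimate of the expected log-likelihood at the posterior mean, with an optional second-order correction (Eq. 33); lam_at_mean and Lpp_at_mean are per-region quantities evaluated at the posterior mean, and all names are illustrative.

    import numpy as np

    def elbo(mu_post, Sigma_post, mu_prior, Sigma_prior, y,
             lam_at_mean, Lpp_at_mean=None):
        """Evidence lower bound of Eq. 7: <log Pr(Y|A)>_Q - D_KL(Q || Pr(A))."""
        ell = np.sum(y * np.log(lam_at_mean) - lam_at_mean)
        if Lpp_at_mean is not None:
            ell += 0.5 * np.diag(Sigma_post) @ Lpp_at_mean    # second-order term, Eq. 33
        # D_KL(Q || prior) between two multivariate Gaussians, Eq. 26
        D = len(mu_post)
        d = mu_prior - mu_post
        _, logdet_prior = np.linalg.slogdet(Sigma_prior)
        _, logdet_post = np.linalg.slogdet(Sigma_post)
        kl = 0.5 * (logdet_prior - logdet_post - D
                    + np.trace(np.linalg.solve(Sigma_prior, Sigma_post))
                    + d @ np.linalg.solve(Sigma_prior, d))
        return ell - kl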

Going forward

In practice, we use the Laplace approximation for obtaining the approximate posterior Q. This avoids needing to jointly optimize the posterior covariance, trading accuracy for speed. We have explored the Laplace, ELBO, and expected log-likelihood approaches to estimating the likelihood, and found them to be broadly similar. Note, however, that the accuracy of expressions involving second-order expansions is expected to break down if the variance of the latent state becomes large.
