These notes contain some derivations for variational inference in Poisson and probit Generalized Linear Models (GLMs) with a Gaussian prior and an approximate Gaussian posterior (see also here).
Problem statement
Consider a population of neurons with firing intensities $\boldsymbol\lambda =\rho(\boldsymbol\theta)$, where $\rho(\cdot)$ is a firing-rate nonlinearity and $\boldsymbol\theta$ is a vector of synaptic activations (amount of input drive to each neuron). For stochastic models of spiking $\Pr(y|\theta)$ in the canonical exponential family, the probability of observing spikes $\mathbf y$ given $\boldsymbol\theta$ can be written as \begin{equation}\begin{aligned}\ln\Pr(\mathbf y|\boldsymbol\theta) &=\mathbf y^\top\boldsymbol\theta-\mathbf 1^\top A(\boldsymbol\theta)+\text{constant},\end{aligned}\end{equation}
where $A(\cdot)$ is a known function (applied elementwise) whose derivative equals the firing-rate nonlinearity, i.e. $A'(\cdot) =\rho(\cdot)$.
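As a concrete check of the canonical form, here is a minimal Python sketch for the Poisson case, where $A=\exp$ and the constant is $-\ln y!$. The function names are illustrative only.

```python
import math

def canonical_loglik(y: int, theta: float) -> float:
    # Canonical form with A = exp: y*theta - A(theta), plus the constant -ln(y!)
    return y * theta - math.exp(theta) - math.lgamma(y + 1)

def poisson_logpmf(y: int, lam: float) -> float:
    # Textbook Poisson log-probability ln(lam^y * e^(-lam) / y!)
    return math.log(lam ** y * math.exp(-lam) / math.factorial(y))

theta = 0.3
lam = math.exp(theta)  # firing intensity lambda = rho(theta) = exp(theta)
for y in range(6):
    print(y, canonical_loglik(y, theta), poisson_logpmf(y, lam))
```

Both columns agree, confirming that $y\theta-e^\theta$ is the Poisson log-likelihood up to a term that does not depend on $\theta$.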
Assume that the synaptic activations $\boldsymbol\theta$ are driven by shared latent variables $\mathbf z$ with a Gaussian prior $\mathbf z\sim\mathcal N(\boldsymbol\mu_z,\boldsymbol\Sigma_z)$. Let $\boldsymbol\theta=\mathbf B\mathbf z$, where "$\mathbf B$" is a matrix of coupling coefficients which determine how the latent factors $\mathbf z$ drive each neuron.
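The generative model can be sketched in a few lines of NumPy, assuming the Poisson case ($\rho=\exp$). The dimensions, the prior, and the coupling matrix `B` below are arbitrary illustrative choices, not values from these notes.

```python
import numpy as np

rng = np.random.default_rng(0)
n_latent, n_neurons = 2, 5
mu_z = np.zeros(n_latent)                            # prior mean of z
Sigma_z = np.eye(n_latent)                           # prior covariance of z
B = rng.normal(scale=0.5, size=(n_neurons, n_latent))  # hypothetical coupling matrix

z = rng.multivariate_normal(mu_z, Sigma_z)  # shared latent factors
theta = B @ z                               # synaptic activations
lam = np.exp(theta)                         # firing intensities, rho = exp
y = rng.poisson(lam)                        # observed spike counts
print(z, y)
```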
We want to infer the distribution of $\mathbf z$ from observed spikes $\mathbf y$. The posterior is given by Bayes' rule, $\Pr(\mathbf z |\mathbf y) =\Pr(\mathbf y |\mathbf z)\Pr(\mathbf z)/\Pr(\mathbf y)$. However, this posterior generally has no closed form unless the firing-rate nonlinearity $\rho(\cdot)$ is linear (i.e. $A(\cdot)$ is quadratic, the Gaussian case). Instead, one can use a variational Bayesian approach to obtain an approximate posterior.
Variational Bayes
In variational Bayes, the posterior on $\mathbf z$ is approximated as Gaussian, i.e. $\Pr(\mathbf z |\mathbf y)\approx Q(\mathbf z)$, where $Q(\mathbf z) =\mathcal N(\boldsymbol\mu_q,\boldsymbol\Sigma_q)$. We then optimize $\boldsymbol\mu_q$ and $\boldsymbol\Sigma_q$ to minimize the Kullback-Leibler (KL) divergence from the true posterior $\Pr(\mathbf z |\mathbf y)$ to $Q(\mathbf z)$. This is equivalent to minimizing the KL divergence $D_{\text{KL}}\left[ Q(\mathbf z)\|\Pr(\mathbf z)\right]$ from the prior to the approximate posterior, while maximizing the expected log-likelihood $\left<\ln\Pr(\mathbf y |\mathbf z)\right>$: \begin{equation}\begin{aligned}D_{\text{KL}}\left[ Q(\mathbf z)\|\Pr(\mathbf z |\mathbf y)\right]&=D_{\text{KL}}\left[ Q(\mathbf z)\|\Pr(\mathbf z)\right]-\left<\ln\Pr(\mathbf y |\mathbf z)\right>+\text{constant},\end{aligned}\end{equation}
where the constant is $\ln\Pr(\mathbf y)$, which does not depend on $Q$.
(In these notes, all expectations $\langle\cdot\rangle$ are taken with respect to the approximating posterior distribution.)
Since both $Q(\mathbf z)$ and $\Pr(\mathbf z)$ are multivariate Gaussian, the KL divergence $D_{\text{KL}}\left[ Q(\mathbf z)\|\Pr(\mathbf z)\right]$ has the closed form: \begin{equation}\begin{aligned} D_{\text{KL}}\left[ Q(\mathbf z)\|\Pr(\mathbf z)\right] &= \tfrac 1 2\left\{ (\boldsymbol\mu_z-\boldsymbol\mu_q)^\top \boldsymbol\Sigma_z^{-1} (\boldsymbol\mu_z-\boldsymbol\mu_q) + \operatorname{tr}\left( \boldsymbol\Sigma_z^{-1} \boldsymbol\Sigma_q \right) - \ln| \boldsymbol\Sigma_z^{-1} \boldsymbol\Sigma_q | \right\} +\text{constant}, \end{aligned}\end{equation}
where the constant is $-k/2$ for a $k$-dimensional latent $\mathbf z$.
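A minimal NumPy sketch of this Gaussian KL divergence follows. It includes the $-k/2$ constant, so the result is exactly zero when the two distributions coincide; the function name is hypothetical. Note that the log-determinant term enters with a minus sign.

```python
import numpy as np

def gaussian_kl(mu_q, Sigma_q, mu_z, Sigma_z):
    """D_KL[N(mu_q, Sigma_q) || N(mu_z, Sigma_z)], including the -k/2 constant."""
    k = len(mu_q)
    Sz_inv = np.linalg.inv(Sigma_z)
    d = mu_z - mu_q
    # slogdet returns (sign, ln|det|); the product of two SPD matrices has det > 0
    logdet = np.linalg.slogdet(Sz_inv @ Sigma_q)[1]
    return 0.5 * (d @ Sz_inv @ d + np.trace(Sz_inv @ Sigma_q) - logdet - k)
```

For example, `gaussian_kl(np.zeros(1), np.eye(1), np.ones(1), np.eye(1))` evaluates $D_{\text{KL}}[\mathcal N(0,1)\|\mathcal N(1,1)]=\tfrac12$.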
For our choice of the canonically-parameterized natural exponential family, the expected negative log-likelihood can be written as: \begin{equation}\begin{aligned}-\langle\ln\Pr(\mathbf y|\mathbf z)\rangle&=\mathbf 1^\top\langle A(\boldsymbol\theta)\rangle-\mathbf y^\top\mathbf B\boldsymbol\mu_q+\text{constant}.\end{aligned}\end{equation}
Neglecting constants and terms that do not depend on $(\boldsymbol\mu_q,\boldsymbol\Sigma_q)$, the overall loss function to be minimized is: \begin{equation}\begin{aligned} \mathcal L(\boldsymbol\mu_q,\boldsymbol\Sigma_q) &= \tfrac 1 2\left\{ (\boldsymbol\mu_z-\boldsymbol\mu_q)^\top \boldsymbol\Sigma_z^{-1} (\boldsymbol\mu_z-\boldsymbol\mu_q) + \operatorname{tr}\left( \boldsymbol\Sigma_z^{-1} \boldsymbol\Sigma_q \right) - \ln|\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q|\right\} +\mathbf 1^\top\langle A(\boldsymbol\theta)\rangle -\mathbf y^\top\mathbf B\boldsymbol\mu_q. \end{aligned} \label{loss} \end{equation}
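For the Poisson case ($\rho=\exp$), the expectation $\langle A(\theta_k)\rangle=\langle e^{\theta_k}\rangle=\exp(\mu_{\theta_k}+\sigma^2_{\theta_k}/2)$ is available in closed form (see the next section), so the loss can be evaluated directly. A minimal NumPy sketch, using the standard Gaussian KL in which the log-determinant enters as $-\tfrac12\ln|\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q|$; the function name and argument order are hypothetical.

```python
import numpy as np

def vb_loss_poisson(mu_q, Sigma_q, mu_z, Sigma_z, B, y):
    """Variational loss for rho = exp, dropping additive constants."""
    Sz_inv = np.linalg.inv(Sigma_z)
    d = mu_z - mu_q
    logdet = np.linalg.slogdet(Sz_inv @ Sigma_q)[1]          # ln|Sigma_z^{-1} Sigma_q|
    prior_term = 0.5 * (d @ Sz_inv @ d + np.trace(Sz_inv @ Sigma_q) - logdet)
    mu_theta = B @ mu_q
    s2_theta = np.einsum("ki,ij,kj->k", B, Sigma_q, B)       # diag(B Sigma_q B^T)
    expected_A = np.exp(mu_theta + 0.5 * s2_theta)           # <exp(theta_k)>
    return prior_term + expected_A.sum() - y @ B @ mu_q
```

As a sanity check, with $\mathbf B=\mathbf 0$ and $Q=\Pr(\mathbf z)$ the loss reduces to $k/2$ (from the trace) plus one unit of $\langle e^{0}\rangle$ per neuron.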
Closed-form expectations
To optimize $\eqref{loss}$, we need to differentiate it in $\boldsymbol\mu_q$ and $\boldsymbol\Sigma_q$. These derivatives are mostly straightforward, but the expectation $\langle A(\boldsymbol\theta)\rangle$ poses difficulties when $A(\cdot)$ is nonlinear. We'll consider some choices of firing-rate nonlinearity for which the derivatives of $\langle A(\boldsymbol\theta)\rangle$ have closed-form expressions when $\boldsymbol\theta$ is Gaussian.
Because we've assumed a Gaussian posterior on our latent state $\mathbf z$, and since $\boldsymbol\theta =\mathbf B\mathbf z$, the synaptic activations $\boldsymbol\theta$ are also Gaussian. The vectors $\boldsymbol\mu_\theta$ and $\boldsymbol\sigma^2_\theta$ for the mean and variance of $\boldsymbol\theta$, respectively, are: \begin{equation}\begin{aligned}\boldsymbol\mu_\theta &=\mathbf B\boldsymbol\mu_q\\\boldsymbol\sigma^2_\theta &=\operatorname{diag}\left[\mathbf B\boldsymbol\Sigma_q\mathbf B^\top\right]\end{aligned}\end{equation}
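In code, $\operatorname{diag}[\mathbf B\boldsymbol\Sigma_q\mathbf B^\top]$ need not be formed as a full $n\times n$ matrix: each entry is a row-wise quadratic form $\mathbf b_k^\top\boldsymbol\Sigma_q\mathbf b_k$. A NumPy sketch with arbitrary illustrative sizes and a randomly generated (hypothetical) $\mathbf B$ and $\boldsymbol\Sigma_q$:

```python
import numpy as np

rng = np.random.default_rng(1)
n_neurons, n_latent = 6, 3
B = rng.normal(size=(n_neurons, n_latent))       # hypothetical coupling matrix
mu_q = rng.normal(size=n_latent)
L = rng.normal(size=(n_latent, n_latent))
Sigma_q = L @ L.T + n_latent * np.eye(n_latent)  # an arbitrary SPD covariance

mu_theta = B @ mu_q
# Row-wise dot products give diag(B Sigma_q B^T) in O(n k^2) instead of O(n^2 k).
sigma2_theta = np.sum((B @ Sigma_q) * B, axis=1)
```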
Consider a single, scalar $\theta\sim\mathcal N(\mu,\sigma^2)$, and write $\theta=\mu+\sigma\epsilon$ with $\epsilon\sim\mathcal N(0,1)$. Differentiating under the expectation with the chain rule (using $\partial_{\sigma^2}\sigma=\tfrac1{2\sigma}$ and $\epsilon=(\theta-\mu)/\sigma$) and linearity of expectation, one can show that the partial derivatives $\partial_{\mu}\langle A(\theta)\rangle$ and $\partial_{\sigma^2}\langle A(\theta)\rangle$, with respect to $\mu$ and $\sigma^2$ respectively, are: \begin{equation}\begin{aligned}\partial_{\mu}\langle A(\theta)\rangle&=\langle A'(\theta) \rangle= \langle \rho(\theta) \rangle \\\partial_{\sigma^2}\langle A(\theta)\rangle&= \tfrac1{2\sigma^2}\left<(\theta-\mu) A'(\theta)\right>=\tfrac1{2\sigma^2} \left<(\theta-\mu)\rho(\theta)\right>.\end{aligned}\label{dexpect}\end{equation}
For more compact notation, denote the expected firing rate as $\bar\lambda=\langle\rho(\theta)\rangle$, and the expected derivative of the firing rate in $\theta$ as $\bar\lambda'=\langle\rho'(\theta)\rangle$. Note that $\bar\lambda=\partial_{\mu}\langle A(\theta)\rangle$ and $\tfrac 1 2\bar\lambda'=\partial_{\sigma^2}\langle A(\theta)\rangle$; the latter follows from $\eqref{dexpect}$ and Stein's lemma, $\langle(\theta-\mu)\rho(\theta)\rangle=\sigma^2\langle\rho'(\theta)\rangle$.
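These identities can be sanity-checked numerically by computing Gaussian expectations with Gauss-Hermite quadrature (NumPy's `hermegauss`) and comparing against finite differences. This sketch deliberately uses softplus, $A(\theta)=\ln(1+e^\theta)$ with $\rho$ the logistic sigmoid, as an illustrative nonlinearity that is not one of the two considered below, to show that the identities hold for a general smooth $A$. The quadrature order and step size are assumptions of the sketch.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

# Nodes/weights for the weight function exp(-x^2/2); the weights sum to sqrt(2*pi).
X, W = hermegauss(80)

def gauss_expect(f, mu, s2):
    """<f(theta)> for theta ~ N(mu, s2) by Gauss-Hermite quadrature."""
    return np.sum(W * f(mu + np.sqrt(s2) * X)) / np.sqrt(2.0 * np.pi)

def A(t):     # softplus, an illustrative choice of A
    return np.logaddexp(0.0, t)

def rho(t):   # A' = logistic sigmoid
    return 1.0 / (1.0 + np.exp(-t))

def drho(t):  # rho'
    return rho(t) * (1.0 - rho(t))

mu, s2, h = 0.4, 0.7, 1e-5
dA_dmu = (gauss_expect(A, mu + h, s2) - gauss_expect(A, mu - h, s2)) / (2 * h)
dA_ds2 = (gauss_expect(A, mu, s2 + h) - gauss_expect(A, mu, s2 - h)) / (2 * h)
lam_bar = gauss_expect(rho, mu, s2)          # should match dA_dmu
lam_bar_prime = gauss_expect(drho, mu, s2)   # half of it should match dA_ds2
stein = gauss_expect(lambda t: (t - mu) * rho(t), mu, s2) / (2 * s2)
print(dA_dmu, lam_bar, dA_ds2, 0.5 * lam_bar_prime, stein)
```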
Closed-form expressions for $\bar\lambda$ and $\bar\lambda'$ exist only in some special cases, for example if the firing-rate function $\rho(\cdot)$ is a (rectified) polynomial. We consider two choices of firing-rate nonlinearity which admit closed-form expressions, "exponential" and "probit".
- Choosing $\rho =\exp$ corresponds to a Poisson GLM. In this case, $\bar\lambda =\bar\lambda' =\exp(\mu +\sigma^2/2)$.
- Let $\phi(\cdot)$ and $\Phi(\cdot)$ denote the probability density and cumulative distribution function, respectively, of the standard normal distribution. Choosing $\rho =\Phi$ corresponds to a probit GLM. In this case, $\bar\lambda =\Phi(\gamma\mu)$ and $\bar\lambda' =\gamma\phi(\gamma\mu)$, where $\gamma = (1+\sigma^2)^{-1/2}$.
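Both pairs of closed forms can be checked against Gauss-Hermite quadrature. In this NumPy sketch $\Phi$ is built from `math.erf`; the helper names and the test points are hypothetical.

```python
import math
import numpy as np
from numpy.polynomial.hermite_e import hermegauss

_erf = np.vectorize(math.erf)

def Phi(t):
    """Standard normal CDF."""
    return 0.5 * (1.0 + _erf(np.asarray(t, dtype=float) / math.sqrt(2.0)))

def phi(t):
    """Standard normal density."""
    t = np.asarray(t, dtype=float)
    return np.exp(-0.5 * t * t) / math.sqrt(2.0 * math.pi)

def lam_bars_exp(mu, s2):
    """(lambda_bar, lambda_bar') for rho = exp."""
    v = math.exp(mu + 0.5 * s2)
    return v, v

def lam_bars_probit(mu, s2):
    """(lambda_bar, lambda_bar') for rho = Phi, with gamma = (1 + s2)^(-1/2)."""
    gamma = 1.0 / math.sqrt(1.0 + s2)
    return float(Phi(gamma * mu)), gamma * float(phi(gamma * mu))

X, W = hermegauss(80)  # weight exp(-x^2/2); weights sum to sqrt(2*pi)

def gauss_expect(f, mu, s2):
    """<f(theta)> for theta ~ N(mu, s2) by Gauss-Hermite quadrature."""
    return float(np.sum(W * f(mu + math.sqrt(s2) * X)) / math.sqrt(2.0 * math.pi))

print(lam_bars_probit(0.3, 0.5), gauss_expect(Phi, 0.3, 0.5), gauss_expect(phi, 0.3, 0.5))
```

For the probit case, $\rho'=\phi$, so $\bar\lambda'$ is compared with the quadrature estimate of $\langle\phi(\theta)\rangle$.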
For the probit firing-rate nonlinearity, we will also need $\partial_{\sigma^2}\langle\rho'(\theta)\rangle$ to calculate the Hessian-vector product. In this case, $\rho'=\phi$, and since $\phi'(\theta)=-\theta\phi(\theta)$, applying $\eqref{dexpect}$ with $\phi$ in place of $A$ gives $\partial_{\sigma^2}\langle\phi(\theta)\rangle=\tfrac1{2\sigma^2}\left<\theta (\mu-\theta)\phi(\theta)\right>$. This can be solved by writing the expectation as an integral and completing the square in the resulting Gaussian integral (equivalently, by differentiating $\langle\phi(\theta)\rangle=\phi\big(\mu/\sqrt{1+\sigma^2}\big)/\sqrt{1+\sigma^2}$ in $\sigma^2$), yielding: \begin{equation}\begin{aligned} \partial_{\sigma^2}\langle\phi(\theta)\rangle &= \frac {u-1} {\sqrt{8\pi e^u(1+\sigma^2)^3}} ,\text{ where }u=\frac{\mu^2}{\sigma^2+1}. \end{aligned} \label{probitrhoprimeexpectgradient} \end{equation}
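Since $\langle\phi(\theta)\rangle$ is the $\mathcal N(0,1+\sigma^2)$ density evaluated at $\mu$, the closed form above can be checked with a central finite difference in $\sigma^2$. A standard-library Python sketch, with hypothetical function names and test points:

```python
import math

def expected_phi(mu, s2):
    """<phi(theta)> for theta ~ N(mu, s2): the N(0, 1 + s2) density at mu."""
    s = 1.0 + s2
    return math.exp(-0.5 * mu * mu / s) / math.sqrt(2.0 * math.pi * s)

def d_expected_phi_ds2(mu, s2):
    """Closed-form derivative of <phi(theta)> with respect to s2."""
    u = mu * mu / (s2 + 1.0)
    return (u - 1.0) / math.sqrt(8.0 * math.pi * math.exp(u) * (1.0 + s2) ** 3)

h = 1e-6
for mu, s2 in [(0.0, 0.5), (1.3, 0.2), (-2.0, 1.5)]:
    fd = (expected_phi(mu, s2 + h) - expected_phi(mu, s2 - h)) / (2 * h)
    print(mu, s2, fd, d_expected_phi_ds2(mu, s2))
```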
Derivatives of the loss function
With these preliminaries out of the way, we can now consider the derivatives of $\eqref{loss}$ in terms of $\boldsymbol\mu_q$ and $\boldsymbol\Sigma_q$.
Derivatives in $\boldsymbol\mu_q$
The gradient and Hessian of $\mathcal L$ with respect to $\boldsymbol\mu_q$ are: \begin{equation}\begin{aligned} \partial_{\boldsymbol\mu_q} \mathcal L &= \boldsymbol\Sigma_z^{-1} (\boldsymbol\mu_q-\boldsymbol\mu_z) + \mathbf B^\top\left( \bar{\boldsymbol\lambda}-\mathbf y\right) \\ \operatorname H_{\boldsymbol\mu_q} \mathcal L &= \boldsymbol\Sigma_z^{-1} + \mathbf B^\top \operatorname{diag}[\bar{\boldsymbol\lambda}'] \mathbf B \end{aligned}\end{equation}
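For the Poisson case, where $\bar{\boldsymbol\lambda}'=\bar{\boldsymbol\lambda}$, these expressions can drive Newton updates of $\boldsymbol\mu_q$ with $\boldsymbol\Sigma_q$ held fixed. A minimal NumPy sketch follows; the step-halving safeguard, the random test problem, and all function names are choices of this sketch, not part of the notes.

```python
import numpy as np

def lam_bar(mu_q, Sigma_q, B):
    # rho = exp: lambda_bar = lambda_bar' = exp(mu_theta + sigma2_theta / 2)
    s2_theta = np.sum((B @ Sigma_q) * B, axis=1)  # diag(B Sigma_q B^T)
    return np.exp(B @ mu_q + 0.5 * s2_theta)

def loss(mu_q, Sigma_q, mu_z, Sz_inv, B, y):
    d = mu_z - mu_q
    logdet = np.linalg.slogdet(Sz_inv @ Sigma_q)[1]
    prior = 0.5 * (d @ Sz_inv @ d + np.trace(Sz_inv @ Sigma_q) - logdet)
    return prior + lam_bar(mu_q, Sigma_q, B).sum() - y @ B @ mu_q

def grad_hess_mu(mu_q, Sigma_q, mu_z, Sz_inv, B, y):
    lam = lam_bar(mu_q, Sigma_q, B)
    g = Sz_inv @ (mu_q - mu_z) + B.T @ (lam - y)
    H = Sz_inv + B.T @ (lam[:, None] * B)  # B^T diag(lambda_bar') B, lambda_bar' = lambda_bar
    return g, H

rng = np.random.default_rng(2)
n, k = 8, 2
B = rng.normal(scale=0.5, size=(n, k))       # hypothetical couplings
y = rng.poisson(1.0, size=n).astype(float)   # hypothetical spike counts
mu_z, Sz_inv, Sigma_q = np.zeros(k), np.eye(k), 0.1 * np.eye(k)

mu_q = np.zeros(k)
for _ in range(30):
    g, H = grad_hess_mu(mu_q, Sigma_q, mu_z, Sz_inv, B, y)
    step = np.linalg.solve(H, g)
    t, f0 = 1.0, loss(mu_q, Sigma_q, mu_z, Sz_inv, B, y)
    # Halve the step if the loss increases; the small slack keeps round-off
    # from rejecting steps near the optimum.
    while loss(mu_q - t * step, Sigma_q, mu_z, Sz_inv, B, y) > f0 + 1e-12 and t > 1e-10:
        t *= 0.5
    mu_q = mu_q - t * step
```

Because the Hessian is positive definite, the loss is strictly convex in $\boldsymbol\mu_q$, and the Newton iterates settle at the unique minimizer for the given $\boldsymbol\Sigma_q$.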
Gradient in $\boldsymbol\Sigma_q$
The gradient of $\eqref{loss}$ in $\boldsymbol\Sigma_q$ is more involved. The derivative of the term $\tfrac 1 2\{ \operatorname{tr}(\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q)-\ln|\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q|\}$ can be obtained using identities provided in The Matrix Cookbook, and is $\tfrac12\{\boldsymbol\Sigma_z^{-1}-\boldsymbol\Sigma_q^{-\top}\}$. The derivative of $\mathbf 1^\top\langle A (\boldsymbol\theta)\rangle$ can be obtained by considering derivatives with respect to individual elements of $\boldsymbol\Sigma_q$: since $\partial_{\sigma^2}\langle A(\theta)\rangle=\tfrac12\bar\lambda'$ and $\partial\sigma^2_{\theta_k}/\partial\Sigma_{q,ij}=\mathbf B_{ki}\mathbf B_{kj}$, it is $\tfrac12\mathbf B^\top\operatorname{diag}[\bar{\boldsymbol\lambda}']\mathbf B$. Overall, we find that: \begin{equation}\begin{aligned} \partial_{\boldsymbol\Sigma_q}{\mathcal L}= \tfrac 1 2 \left\{ \boldsymbol\Sigma_z^{-1}- \boldsymbol\Sigma_q^{-\top}+ \mathbf B^\top\operatorname{diag}[\bar{\boldsymbol\lambda}']\mathbf B \right\}. \end{aligned} \label{sigmagrad} \end{equation}
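With the log-determinant term entering as $-\tfrac12\ln|\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q|$ (gradient $-\tfrac12\boldsymbol\Sigma_q^{-\top}$), setting the gradient to zero gives $\boldsymbol\Sigma_q^{-1}=\boldsymbol\Sigma_z^{-1}+\mathbf B^\top\operatorname{diag}[\bar{\boldsymbol\lambda}']\mathbf B$. Because $\bar{\boldsymbol\lambda}'$ itself depends on $\boldsymbol\Sigma_q$, this is a fixed-point equation. The NumPy sketch below (Poisson case, hypothetical names) iterates it with $\boldsymbol\mu_q$ held fixed and checks the analytic gradient against elementwise finite differences of the loss. Convergence of the plain fixed-point iteration is assumed here for weak couplings; it is not guaranteed in general.

```python
import numpy as np

def lam_bar(mu_q, Sigma_q, B):
    # rho = exp: lambda_bar = lambda_bar' = exp(mu_theta + sigma2_theta / 2)
    return np.exp(B @ mu_q + 0.5 * np.sum((B @ Sigma_q) * B, axis=1))

def loss(mu_q, Sigma_q, mu_z, Sz_inv, B, y):
    d = mu_z - mu_q
    logdet = np.linalg.slogdet(Sz_inv @ Sigma_q)[1]
    prior = 0.5 * (d @ Sz_inv @ d + np.trace(Sz_inv @ Sigma_q) - logdet)
    return prior + lam_bar(mu_q, Sigma_q, B).sum() - y @ B @ mu_q

def grad_Sigma(mu_q, Sigma_q, Sz_inv, B):
    # 1/2 { Sigma_z^{-1} - Sigma_q^{-T} + B^T diag[lambda_bar'] B }
    lam_p = lam_bar(mu_q, Sigma_q, B)
    return 0.5 * (Sz_inv - np.linalg.inv(Sigma_q).T + B.T @ (lam_p[:, None] * B))

rng = np.random.default_rng(3)
n, k = 6, 2
B = rng.normal(scale=0.3, size=(n, k))      # hypothetical (weak) couplings
y = rng.poisson(1.0, size=n).astype(float)  # hypothetical spike counts
mu_z, Sz_inv, mu_q = np.zeros(k), np.eye(k), np.array([0.1, -0.2])

# Fixed-point iteration Sigma_q <- (Sigma_z^{-1} + B^T diag[lambda_bar'] B)^{-1}
Sigma_q = np.eye(k)
for _ in range(100):
    lam_p = lam_bar(mu_q, Sigma_q, B)
    Sigma_q = np.linalg.inv(Sz_inv + B.T @ (lam_p[:, None] * B))
```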
Hessian-vector product in $\boldsymbol\Sigma_q$
Since $\boldsymbol\Sigma_q$ is a matrix, the Hessian of $\eqref{loss}$ in $\boldsymbol\Sigma_q$ is a fourth-order tensor. It is simpler to work with the Hessian-vector product. Here, the "vector" is a matrix $\mathbf M$ with the same shape as $\boldsymbol\Sigma_q$ (for example, a search direction in the space of covariance matrices), and $\mathbf J_{\boldsymbol\Sigma_q}=\partial_{\boldsymbol\Sigma_q}\mathcal L$ is the gradient $\eqref{sigmagrad}$. The Hessian-vector product is given by the following identity: \begin{equation}\begin{aligned} \langle\mathbf H_{\boldsymbol\Sigma_q},\mathbf M\rangle&= \partial_{\boldsymbol\Sigma_q}\langle \mathbf J_{\boldsymbol\Sigma_q},\mathbf M\rangle= \partial_{\boldsymbol\Sigma_q} \operatorname{tr}\left[ \mathbf J_{\boldsymbol\Sigma_q}^\top\mathbf M \right] \end{aligned}\end{equation}
where $\langle\cdot,\cdot\rangle$ denotes the scalar (Frobenius) product. The Hessian-vector product for the terms $\boldsymbol\Sigma_z^{-1}-\boldsymbol\Sigma_q^{-\top}$ in $\eqref{sigmagrad}$ (the minus sign comes from the $-\tfrac12\ln|\boldsymbol\Sigma_z^{-1}\boldsymbol\Sigma_q|$ term of the KL divergence) can be obtained using identities provided in The Matrix Cookbook, assuming $\boldsymbol\Sigma_q$ is symmetric: \begin{equation}\begin{aligned} \partial_{\boldsymbol\Sigma_q} \operatorname{tr}\left[\left\{\boldsymbol\Sigma_z^{-1}-\boldsymbol\Sigma_q^{-\top}\right\}^\top\mathbf M\right]=\boldsymbol\Sigma_q^{-1} \mathbf M^\top \boldsymbol\Sigma_q^{-1}. \end{aligned}\end{equation}
The Hessian-vector product for the term $\mathbf B^\top\operatorname{diag}[\bar\lambda']\mathbf B$ in $\eqref{sigmagrad}$ is more complicated. We can write \begin{equation}\begin{aligned} \partial_{\boldsymbol\Sigma_q} \operatorname{tr}\left[ \left\{ \mathbf B^\top\operatorname{diag}[\bar\lambda']\mathbf B \right\}^\top\mathbf M \right]&= \partial_{\boldsymbol\Sigma_q} \operatorname{tr}\left[ \mathbf B\mathbf M\mathbf B^\top \operatorname{diag}[\bar\lambda'] \right]\\&= \mathbf B^\top \operatorname{diag}[\mathbf B\mathbf M\mathbf B^\top] \operatorname{diag} \left[ \partial_{\sigma^2_{\boldsymbol\theta}}\langle\rho'(\boldsymbol\theta)\rangle \right] \mathbf B. \end{aligned} \label{hvp2} \end{equation}
The first step in $\eqref{hvp2}$ uses the invariance of the trace under cyclic permutations. The second step follows from Lemma 1 (Appendix, below), with $\mathbf C=\mathbf B\mathbf M\mathbf B^\top$ and $f=\rho'$, using the fact that $\bar\lambda'=\langle\rho'(\theta)\rangle$. Overall, the Hessian-vector product in $\boldsymbol\Sigma_q$ is \begin{equation}\begin{aligned}\langle\mathbf H_{\boldsymbol\Sigma_q},\mathbf M\rangle&= \tfrac12\left\{\boldsymbol\Sigma_q^{-1}\mathbf M^\top \boldsymbol\Sigma_q^{-1}+\mathbf B^\top\operatorname{diag}[\mathbf B\mathbf M\mathbf B^\top] \operatorname{diag}\left[\partial_{\sigma^2_{\boldsymbol\theta}}\langle\rho'(\boldsymbol\theta)\rangle \right]\mathbf B\right\}, \end{aligned}\end{equation}
where $\operatorname{diag}[\mathbf B\mathbf M\mathbf B^\top]$ denotes the diagonal matrix formed from the diagonal of $\mathbf B\mathbf M\mathbf B^\top$.
For the exponential firing-rate nonlinearity, $\partial_{\sigma^2_{\boldsymbol\theta}}\langle\rho'(\boldsymbol\theta)\rangle=\tfrac 1 2\bar\lambda$. The solution for the probit firing-rate nonlinearity is given in $\eqref{probitrhoprimeexpectgradient}$.
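For the Poisson case, the Hessian-vector product can be verified by a central finite difference of the gradient along $\mathbf M$. This NumPy sketch assumes a symmetric $\boldsymbol\Sigma_q$ at the evaluation point (needed for the $\boldsymbol\Sigma_q^{-1}\mathbf M^\top\boldsymbol\Sigma_q^{-1}$ form) and uses the standard sign of the log-determinant term, so the inverse term enters with a plus sign. The random test point and all names are hypothetical.

```python
import numpy as np

def lam_bar(mu_q, Sigma_q, B):
    # rho = exp: lambda_bar = lambda_bar' = exp(mu_theta + sigma2_theta / 2)
    return np.exp(B @ mu_q + 0.5 * np.sum((B @ Sigma_q) * B, axis=1))

def grad_Sigma(mu_q, Sigma_q, Sz_inv, B):
    lam_p = lam_bar(mu_q, Sigma_q, B)
    return 0.5 * (Sz_inv - np.linalg.inv(Sigma_q).T + B.T @ (lam_p[:, None] * B))

def hvp_Sigma(mu_q, Sigma_q, B, M):
    """<H, M> for rho = exp, where d<rho'>/d sigma^2 = lambda_bar / 2."""
    Sq_inv = np.linalg.inv(Sigma_q)
    dlamp_ds2 = 0.5 * lam_bar(mu_q, Sigma_q, B)
    c = np.sum((B @ M) * B, axis=1)  # diag(B M B^T)
    return 0.5 * (Sq_inv @ M.T @ Sq_inv + B.T @ ((c * dlamp_ds2)[:, None] * B))

rng = np.random.default_rng(4)
n, k = 7, 3
B = rng.normal(scale=0.5, size=(n, k))  # hypothetical couplings
mu_q = rng.normal(scale=0.3, size=k)
L = rng.normal(size=(k, k))
Sigma_q = 0.3 * (L @ L.T) + np.eye(k)   # arbitrary symmetric positive-definite point
Sz_inv = np.eye(k)
M = rng.normal(size=(k, k))             # arbitrary direction (need not be symmetric)
```

The comparison works because the Hessian is symmetric, so the directional derivative of the gradient along $\mathbf M$ equals $\partial_{\boldsymbol\Sigma_q}\langle\mathbf J_{\boldsymbol\Sigma_q},\mathbf M\rangle$.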
Conclusion
That's all for now! I'll need to integrate these with the various other derivations (see, e.g., here).
Appendix
Lemma 1
Let $\mathbf C$ be a fixed matrix that does not depend on $\boldsymbol\Sigma_q$, and let $f$ be a scalar function applied elementwise. Because $\langle f(\theta_k)\rangle$ depends on $\boldsymbol\Sigma_q$ only through $\sigma^2_{\theta_k}=\mathbf B_{ki}\Sigma_{q,ij}\mathbf B_{kj}$ (no sum over $k$), we have $\partial\sigma^2_{\theta_k}/\partial\Sigma_{q,ij}=\mathbf B_{ki}\mathbf B_{kj}$. (We use Einstein summation over repeated indices to simplify the notation; in the lines containing $\mathbf C_{kk}$, the sum over $k$ is implied even though $k$ appears more than twice. $\operatorname{diag}[\mathbf C]$ denotes the diagonal matrix formed from the diagonal of $\mathbf C$.) \begin{equation}\begin{aligned} \partial_{\Sigma_{q,ij}} \operatorname{tr} \left[ \mathbf C \operatorname{diag}\left[ \langle f(\boldsymbol\theta)\rangle \right] \right]&= \partial_{\Sigma_{q,ij}} \left[ \mathbf C \operatorname{diag}\left[ \langle f(\boldsymbol\theta)\rangle \right] \right]_{kk}\\&= \partial_{\Sigma_{q,ij}} \left[ \mathbf C_{km} \operatorname{diag}\left[ \langle f(\boldsymbol\theta)\rangle \right]_{mk} \right]\\&= \partial_{\Sigma_{q,ij}} \left[ \mathbf C_{kk} \langle f(\theta_k)\rangle \right]\\&= \mathbf C_{kk}\, \partial_{\sigma^2_{\theta}}\langle f(\theta_k)\rangle\, \frac{\partial\sigma^2_{\theta_k}}{\partial\Sigma_{q,ij}}\\&= \mathbf C_{kk}\, \partial_{\sigma^2_{\theta}}\langle f(\theta_k)\rangle\, \mathbf B_{ki}\mathbf B_{kj}\\&= \mathbf B_{ik}^\top\, \mathbf C_{kk}\, \partial_{\sigma^2_{\theta}}\langle f(\theta_k)\rangle\, \mathbf B_{kj}\\&= \left\{ \mathbf B^\top \operatorname{diag}[\mathbf C] \operatorname{diag} \left[ \partial_{\sigma^2_{\boldsymbol\theta}}\langle f(\boldsymbol\theta)\rangle \right] \mathbf B \right\}_{ij} \\ \\ \partial_{\boldsymbol\Sigma_q} \operatorname{tr} \left[ \mathbf C \operatorname{diag}\left[ \langle f(\boldsymbol\theta)\rangle \right] \right]&= \mathbf B^\top \operatorname{diag}[\mathbf C] \operatorname{diag} \left[ \partial_{\sigma^2_{\boldsymbol\theta}}\langle f(\boldsymbol\theta)\rangle \right] \mathbf B \end{aligned}\end{equation}