Gaussian Processes (GPs) generalize the idea of multivariate Gaussian distributions to distributions over functions. In neuroscience, they can be used to estimate how the firing rate of a neuron varies as a function of other variables (e.g. to track retinal waves). Lately, we've been using Gaussian processes to describe the firing rate map of hippocampal grid cells.
We review Bayesian inference and Gaussian processes, explore applications of Gaussian Processes to analyzing grid cell data, and finally construct a GP model of the log-rate that accounts for the Poisson noise in spike count data. Along the way, we discuss fast approximations for these methods, like kernel density estimation, or approximating GP inference using convolutions.
Introduction
First, we briefly review Bayesian inference for multivariate Gaussian variables and Gaussian processes. Then, we construct some synthetic spike-count observations, similar to what one might see in hippocampal grid cells. We then review how to estimate the underlying firing rate map using kernel density estimation, and discuss some regularization choices when data are limited.
Loosely, Gaussian processes can be viewed as "really big" multivariate Gaussian distributions, with infinitely many variables. It's helpful to review Bayesian inference for multivariate Gaussian variables before continuing.
Consider estimating some jointly Gaussian variables
Consider a case where both
We can estimate
This product of two multivariate Gaussian distributions is also a multivariate Gaussian distribution,
In other textbooks or tutorials, you might also see this written as
Both forms are equivalent, and are related to each other by applying the Sherman–Morrison–Woodbury matrix inversion lemma.
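As a quick sanity check, here is a small self-contained numerical example (with arbitrary illustrative values, not part of the grid-cell analysis) verifying that the two standard forms of the Gaussian posterior mean agree:
# Numerical check that the two forms of the Gaussian posterior mean agree;
# values here are arbitrary and purely illustrative.
import numpy as np
rng = np.random.default_rng(0)
A   = rng.standard_normal((4,4)); Σ0 = A@A.T + np.eye(4)   # prior covariance
B   = rng.standard_normal((4,4)); Σe = B@B.T + np.eye(4)   # measurement covariance
μ0  = rng.standard_normal(4)                               # prior mean
y   = rng.standard_normal(4)                               # observation
# Form 1: μ = μ0 + Σ0 (Σ0+Σe)⁻¹ (y−μ0)
μ1 = μ0 + Σ0@np.linalg.solve(Σ0+Σe, y-μ0)
# Form 2: μ = (Σ0⁻¹+Σe⁻¹)⁻¹ (Σ0⁻¹μ0 + Σe⁻¹y)
μ2 = np.linalg.solve(np.linalg.inv(Σ0)+np.linalg.inv(Σe),
                     np.linalg.solve(Σ0,μ0)+np.linalg.solve(Σe,y))
print(np.allclose(μ1,μ2))   # True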
Gaussian processes are commonly used to estimate a smooth underlying trend from noisy observations. Peter Roelants' notes on Gaussian processes are a clear and detailed introduction.
Consider a GP regression problem for learning
For the regression problem, we'd like learn a model of
GP regression builds a posterior distribution over possible functions
For any finite collection of points
where the means and covariances are computed according to the prior mean and kernel,
To connect GP regression to the Bayesian update for multivariate normal variables, consider sampling both the data and the posterior over the same set of points
In this case,
This is identical to the posterior distribution for a multivariate Gaussian model we discussed earlier. Indeed, if the data consist of Gaussian observations over a set of points, and you evaluate the posterior at these same locations, there is no difference between Gaussian Process regression and Bayesian inference using multivariate Gaussian variables.
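To make this connection concrete, here is a minimal, self-contained 1D sketch of GP regression; the kernel width, noise level, and data below are arbitrary illustrative choices, independent of the notebook environment set up next:
# Minimal 1D GP regression sketch (illustrative values, not grid-cell data)
import numpy as np
x  = np.linspace(0,1,8)                                   # observation locations
y  = np.sin(2*np.pi*x) + 0.1*np.random.randn(8)           # noisy observations
xs = np.linspace(0,1,100)                                 # points to evaluate the posterior at
k  = lambda a,b: np.exp(-(a[:,None]-b[None,:])**2/(2*0.15**2))  # squared-exponential kernel
σ2 = 0.1**2                                               # measurement noise variance
Kxx, Ksx, Kss = k(x,x), k(xs,x), k(xs,xs)
α  = np.linalg.solve(Kxx + σ2*np.eye(len(x)), y)          # (K+σ²I)⁻¹ y
μs = Ksx@α                                                # posterior mean at xs
Σs = Kss - Ksx@np.linalg.solve(Kxx + σ2*np.eye(len(x)), Ksx.T)  # posterior covariance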
First, let's set up our Python environment in the notebook.
%matplotlib inline
# First, set up environment
from pylab import *
mpl.rcParams['figure.figsize'] = (8,2.5)
mpl.rcParams['figure.dpi'] = 200
mpl.rcParams['image.origin'] = 'lower'
mpl.rcParams['image.cmap']='magma'
np.seterr(divide='ignore', invalid='ignore');
Simulating some data
Let's generate some fake grid cell data. We'll simulate a grid cell with a hexagonal firing-rate map, discretized on an L×L grid of spatial bins.
L = 128 # Grid size
P = L/10 # Grid spacing
α = 0.5 # Grid "sharpness"
μ = 1500/L**2 # Mean firing rate (spikes per sample)
# 2D grid coordinates as complex numbers
c = arange(L)-L//2
coords = 1j*c[:,None]+c[None,:]
def ideal_hex_grid(L,P):
# Build a hexagonal grid by summing three cosine waves
θs = exp(1j*array([0,pi/3,2*pi/3]))
return sum([cos((θ*coords).real*2*pi/P) for θ in θs],0)
# Generate intensity map: Exponentiate and scale mean rate
λ0 = exp(ideal_hex_grid(L,P)*α)
λ0 = λ0*μ/mean(λ0)
We also add a bit of zero padding around the data. This allows us to apply convolution kernels using circular convolution, without mixing up data from opposite sides. This will be useful later, since circular convolution can be computed very quickly using the Fast Fourier Transform (FFT). More generally, we might also want to mask out parts of the space if e.g. the rat was exploring an arena with something other than a square shape.
# Zero pad edges
pad = L*1//10
mask = zeros((L,L),dtype='bool')
mask[pad:-pad,pad:-pad]=1
# Simulate oddly shaped arena
mask[:-L*4//10,L*3//10:L*4//10] = False
λ0 = λ0*mask
# For realism, add some background rate changes
λ0 = λ0*(1-abs(coords/(L-2*pad)+0.1))
We summarize the data in terms of two arrays: N, the number of times the rat visited each spatial bin, and K, the number of spikes observed in each bin.
# Simulated a random number of visits to each location
# as well as Poisson spike counts at each location
N = poisson(2*(1-abs(coords/L-0.2j)),size=(L,L))*mask
K = poisson(λ0*N)
Let's plot things.
def pscale(x,q1=0.5,q2=99.5,domask=True):
# Plot helper: Scale data by percentiles
u = x[mask] if domask else x
p1 = percentile(u,q1)
p2 = percentile(u,q2)
x = clip((x-p1)/(p2-p1),0,1)
return x*mask if domask else x
def showim(x,t='',**kwargs):
# Plot helper: Show image with title, no axes
if len(x.shape)==1: x=x.reshape(L,L)
imshow(pscale(x,**kwargs));
axis('off');
title(t);
subplot(131); showim(mask,'Environment')
subplot(132); showim(λ0,'True rate')
subplot(133); showim(K,'Binned Spikes');
Estimating rate in each bin
The simplest way to estimate the rate at each location is to simply divide the number of observed spikes K by the number of visits N.
It's tempting to add a little ad-hoc regularization to handle the unvisited bins, where K/N is undefined.
We assume that each time the rat visits a location with intensity
This gives us a likelihood for estimating
We can write the likelihood of observing a count observation
To combine
To assign a rate estimate to locations with missing data, we can define a Bayesian prior for
This gives a gamma-distributed posterior with
This is biased toward higher rates due to the +1 in the numerator. Using the mode
One can interpolate between these mean-based and mode-based regularizers with another parameter
We use γ = 0.5 and ρ = 1.3 here.
def regλ(N,K,ρ=1.3,γ=0.5):
# Regularized rate estimate
return (K+ρ*(sum(K)/sum(N)-γ)+γ)/(N+ρ)
Even with regularization, estimating the rate directly in each bin is far too noisy to be useful. Why go through all this trouble to define a principled way to regularize counts for single bins then? These regularized rate estimators provide a principled way to define how a rate estimator should behave when data are limited, and can be incorporated into better estimators that pool data from adjacent bins. Next, we explore a simple way to pool data from adjacent bins using kernel density smoothing.
from scipy.stats import pearsonr
def printstats(a,b,message=''):
# Print RMSE and correlation between two rate maps
a,b = a.reshape(L,L)[mask],b.reshape(L,L)[mask]
NMSE = mean((a-b)**2)/sqrt(mean(a**2)*mean(b**2))
print(message+':')
print('∙ Normalized MSE: %0.1f%%'%(100*NMSE))
print('∙ Pearson correlation: %0.2f'%pearsonr(a,b)[0])
# Rate per bin using naive and regularized estimators
λhat1 = nan_to_num(K/N)
λhat2 = regλ(N,K)
printstats(λ0,λhat1,'K/N Estimator')
printstats(λ0,λhat2,'Regularized Estimator')
# Effect of regularization on error
ρs,γs = linspace(1e-2,2,51),linspace(0,1,51)
MSE = array([[mean(abs(λ0-regλ(N,K,ρ,γ))**2) for ρ in ρs] for γ in γs])
subplot(121); showim(λhat2,'$\hat\lambda$, γ=0.5, ρ=1.3')
subplot(122); imshow(-log(MSE),extent=(0,2,0,1),aspect=2)
xticks([0,1,2]); yticks([0,.5,1]); xlabel('ρ'); ylabel('γ');
title('Regularized $\hat\lambda$ Error')
colorbar(label='$-\log(\operatorname{MSE})$');
K/N Estimator: ∙ Normalized MSE: 236.6% ∙ Pearson correlation: 0.17 Regularized Estimator: ∙ Normalized MSE: 206.4% ∙ Pearson correlation: 0.19
Estimating rate via Kernel Density Estimation (KDE)
The simplest way to estimate rate is to average the spike counts over nearby regions. We'll use a Gaussian blur here. The 2D Gaussian blur is a separable filter, so we can compute it using two 1D Gaussian blurs in each direction. This can also be done quickly using the Fast Fourier Transform (FFT). This amounts to Kernel Density Estimation (KDE).
def blurkernel(L,σ,normalize=False):
# Gaussian kernel
k = exp(-(arange(-L//2,L//2)/σ)**2)
if normalize:
k /= sum(k)
return fftshift(k)
def conv(x,K):
# Compute circular 2D convolution using FFT
# Kernel K should already be fourier-transformed
return real(ifft2(fft2(x.reshape(K.shape))*K))
def blur(x,σ,**kwargs):
# 2D Gaussian blur via fft
kern = fft(blurkernel(x.shape[0],σ,**kwargs))
return conv(x,outer(kern,kern))
In our case, we must also account for the nonuniform sampling of space. The rat visits some locations more than others. The solution is to smooth the spike counts and the visit counts separately, and then take their ratio.
def kdeλ(N,K,σ,**kwargs):
# Estimate rate using Gaussian KDE
return regλ(blur(N,σ),blur(K,σ),**kwargs)
If we want to use the regularized rate estimator defined earlier, we should normalize our smoothing kernel so that it sums to one; this keeps the smoothed spike and visit counts on the same scale as the raw per-bin counts.
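As a quick check (using the blurkernel and blur helpers above, with an arbitrary bandwidth of 3 bins), a normalized kernel sums to one, so circular blurring preserves the total spike count:
# Sanity check: a normalized kernel sums to one, and circular convolution
# with it preserves totals, so blurred N and K stay on the same scale.
k = blurkernel(L,3,normalize=True)
print(sum(k))                                   # ≈ 1.0
print(sum(blur(K,3,normalize=True)), sum(K))    # totals agree up to rounding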
For analyzing the underlying grid, we might also want to remove large-scale variations in rate across the arena. We can estimate a background rate also via Gaussian smoothing, and divide out this rate to get a normalized estimate of how rate changes with location.
fgσ = 4 # Kernel smoothing radius
bgσ = L/15 # Background kernel radius
λhat = kdeλ(N,K,fgσ)
λbg = kdeλ(N,K,bgσ)
λbar = λhat/λbg
printstats(λ0,λhat,'KDE Error')
subplot(131); showim(λhat,'Rate, KDE, σ=%d'%fgσ);
subplot(132); showim(λbg ,'Background Rate');
subplot(133); showim(λbar,'Normalized Rate');
KDE Error: ∙ Normalized MSE: 30.5% ∙ Pearson correlation: 0.59
Inspecting the data
Kernel density smoothing yields a good estimate of the rate map, but we need to know how much to blur the spike count data. We can also pick the kernel bandwidth automatically, by estimating the spatial scale of the grid from the autocorrelation of the data.
We can calculate the 2D autocorrelation efficiently using the FFT. To focus on fluctuations around the mean rate, we should first subtract any constant component.
def zeromean(x):
# Mean-center data, accounting for masked-out regions
x = x.reshape(mask.shape)
return (x-mean(x[mask]))*mask
def fft_acorr(x):
# Zero-lag normalized to match signal variance
x = zeromean(x)
# Window attenuates boundary artefacts
win = hanning(L)
win = outer(win,win)
# Calculate autocorrelation using FFT
psd = (abs(fft2(x*win))/L)**2
acr = fftshift(real(ifft2(psd)))
# Adjust peak for effects of mask, window
return acr*var(x[mask])/acr[L//2,L//2]
We can collapse this 2D autocorrelation down to 1D by averaging it as a function of radial distance. This radial autocorrelation has a large peak at zero lag, but also several smaller peaks due to the periodic tuning curve.
def radial_average(y):
# Get radial autocorrelation by averaging 2D autocorrelogram
i = int32(abs(coords)) # Radial distance
a = array([mean(y[i==j]) for j in range(L//2+1)])
return concatenate([a[::-1],a[1:-1]])
def radial_acorr(y):
# Autocorrelation as a function of distance
return radial_average(fft_acorr(y))
We can estimate the grid spacing based on the location of the first non-zero-lag peak. Here, we use sinc interpolation computed via FFT to find the location of the peak that corresponds to the grid spacing.
def fft_upsample_1D(x,factor=4):
'''
Upsample 1D array using the FFT
'''
n = len(x)
n2 = n*factor
f = fftshift(fft(x))*hanning(n)
f2 = np.complex128(np.zeros(n2))
r0 = (n2+1)//2-(n+0)//2
f2[r0:r0+n] = f
return np.real(ifft(fftshift(f2)))*factor
from scipy.signal import find_peaks
def acorr_peak(r,F=6):
# sinc upsample at ×F resolution to get distance to first peak
r2 = fft_upsample_1D(r,F)
return min(find_peaks(r2[len(r2)//2:])[0])/F-1,r2
For grid cells, the 2D autocorrelation should show a hexagon, which reflects the three sinusoidal components that make up the periodic grid tiling (below, left).
λhat = kdeλ(N,K,L/75) # Small blur for initial estimate
acorr2 = fft_acorr(λhat) # Get 2D autocorrelation
acorrR = radial_average(acorr2) # Get radial autocorrelation
res = 5 # Subsampling resolution
P,acup = acorr_peak(acorrR,res) # Distance to first peak in bins
figure(figsize=(8,2))
subplot(121); showim(acorr2,'Autocorrelation',domask=False)
subplot(122); plot(linspace(-L/2,L/2,L*res)-.5/res,acup)
[gca().spines[s].set_visible(0) for s in ['top','right','bottom','left']]
axhline(0,color='k',lw=.8); xticks([0]); xlabel('Distance');
axvline(0,color='k',lw=.8); yticks([0]); ylabel(' '*9+'Correlation',labelpad=-9)
axvline((0+1+P),color='y',lw=.8);
title('Radial Autocorrelation');
Once we have the grid spacing P, we can choose the kernel bandwidth in proportion to it (here σ = P/π) and recompute the KDE rate estimate.
fgσ = P/pi
bgσ = fgσ*2.5
λhat = kdeλ(N,K,fgσ)
λbg = kdeλ(N,K,bgσ)
printstats(λ0,λhat,'KDE')
subplot(131); showim(λhat,'Rate, KDE, σ=%d'%fgσ);
subplot(132); showim(λbg ,'Background Rate');
subplot(133); showim(λbar,'Normalized Rate');
KDE: ∙ Normalized MSE: 30.3% ∙ Pearson correlation: 0.60
Let's start by implementing smoothing using GP regression. Recall the formula for the GP posterior mean:
If we set the prior means to zero, this simplifies to:
To set up our GP regression problem, we need to define the prior covariance and a model of the measurement noise.
We work with the binned spike counts, treating the per-bin estimate K/N as a noisy observation of the rate with per-bin precision N/ε0, where ε0 is the measurement variance of a single visit (estimated here from the average spike count per visit).
# Prepare error model for GP
ε0 = mean(K)/mean(N) # variance per measurement
τe = N.ravel()/ε0 # precision per bin
We construct the prior covariance matrix from a Gaussian kernel, scaling its height to match the variance of the observed rates.
# Build 2D kernel for the prior
# Scale kernel height to match data variance (heuristic)
k1 = blurkernel(L,fgσ*2)
y = nan_to_num(K/N)
kern = outer(k1,k1)*var(y[mask])
from scipy.linalg import circulant
def kernel_to_covariance(kern):
# Covariance is a doubly block-circulant matrix
# Use scipy.linalg.circulant to build blocks, then copy
# with shift to make 2D block-circulant matrix
assert(argmax(kern.ravel())==0)
L = kern.shape[0]
b = array([circulant(r) for r in kern])
b = b.reshape(L**2,L).T
s = array([roll(b,i*L,1) for i in range(L)])
return s.reshape(L**2,L**2)
Here, we explore a grid size of L = 128, so the full prior covariance is an L² × L² (16384 × 16384) matrix.
The numerical stability of our GP regression will be poor if our prior covariance has small eigenvalues. The eigenvalues of our covariance correspond to the coefficients of the Fourier transform of our kernel, so we can "repair" our kernel by replacing too-small eigenvalues with a small positive value.
def repair_small_eigenvalues(kern,mineig=1e-6):
# Kernel must be positive; fix small eigenvalues
assert(argmax(kern.ravel())==0)
kfft = fft2(kern)
keig = abs(kfft)
υmin = mineig*np.max(keig)
zero = keig<υmin
kfft[zero] = υmin
kern = real(ifft2(maximum(υmin,kfft)))
return kern
To solve our GP regression problem, we use the following form for the posterior mean. This form avoids inverting the prior covariance, so it's more numerically stable.
We apply a few optimizations for speed.
-
can be evaluated as , where is element-wise multiplication. -
can be evaluated with row-wise multiplication . -
can be evaluated via convolution using the FFT.
Finally, we use scipy.sparse.linalg.minres to solve the resulting linear system.
Minres stands for "minimum residual", and is a type of Krylov subspace solver. It works by re-phrasing the linear system as an error to be minimized, and then minimizing this error (the "residual"). We do not need to explicitly construct the matrix; it is enough to supply a routine that computes matrix-vector products.
import time
ttic = None
def tic(msg=''):
# Timer routine to track performance
global ttic
t = time.time()*1000
if ttic and msg:
print(('Δt = %d ms'%(t-ttic)).ljust(14)\
+'elapsed for '+msg)
ttic = t
def showkn(k,t):
# Plot helper; Shift convolution kernel to plot
imshow(fftshift(k)); axis('off'); title(t);
from scipy.sparse.linalg import minres,LinearOperator
def solveGP(kern,y,τe,tol=1e-4,reg=1e-5):
# Minimum residual solver is fast
kern = repair_small_eigenvalues(kern,reg)
knft = fft2(kern)
τy = τe*zeromean(y).ravel()
Στy = conv(τy,knft).ravel()
Hv = lambda v:conv(τe*v,knft).ravel() + v
ΣτεI = LinearOperator((L**2,L**2),Hv,Hv,dtype=np.float64)
μ = minres(ΣτεI,Στy,tol=tol)[0]
return μ.reshape(L,L) + mean(y[mask])
λGP1 = solveGP(kern,y,τe.ravel())
printstats(λ0,λGP1,'GP regression error')
subplot(131); showkn(kern,'Prior Kernel');
subplot(132); showim(y,'Observations');
subplot(133); showim(λGP1/λbg,'Posterior Rate');
GP regression error: ∙ Normalized MSE: 25.4% ∙ Pearson correlation: 0.68
Sometimes GP regression reduces to convolution
It seems like GP regression yields similar results to kernel density estimation. Can we relate these two operations? Recall the solution for the GP posterior:
The prior
where
In the special case that all measurements have the same noise level, the GP posterior mean reduces to a convolution of the (mean-subtracted) observations with a fixed kernel.
When the GP is evaluated on a regularly-spaced grid, the eigenspace
For large measurement error
This highlights that sometimes filtering the observations with a convolution kernel gives you something almost as good as a GP regression. This is much simpler, and is often good enough.
def mirrorpad(y,pad):
# Reflected boundary for convolution
y[:pad, :]=flipud(y[ pad: pad*2,:])
y[:, :pad]=fliplr(y[:, pad: pad*2])
y[-pad:,:]=flipud(y[-pad*2:-pad,:])
y[:,-pad:]=fliplr(y[:,-pad*2:-pad])
return y
# Uniform measurement error ⇒ GP = convolution
μτ = mean((N/ε0)[mask])
kft = fft2(kern)
gft = (kft*μτ)/(kft*μτ+1)
y = mirrorpad(nan_to_num(K/N),pad)
μy = mean(y[mask])
λcnv = conv(y-μy,gft)+μy
printstats(λcnv,λGP1,'Error between GP regression and convolution')
subplot(121); showkn(real(ifft2(gft)),'Convolution Kernel');
subplot(122); showim(λcnv/λbg,'Convolution Approximation');
Error between GP regression and convolution: ∙ Normalized MSE: 17.4% ∙ Pearson correlation: 0.92
Better priors
So far, we've only used GP regression with a Gaussian prior. When analyzing data from grid cells, the real power of GP regression lies in being able to encode the knowledge that the grid should be periodic into the GP prior kernel.
To construct a periodic prior, we estimate the autocorrelation from a perfect grid. To avoid assuming any particular orientation, we make the kernel radially symmetric. To avoid inferring long-range interactions where none exist, we taper the kernel to look only at the local neighborhood.
from scipy.interpolate import interp1d
def radial_kernel(rk):
# Make radially symmetric 2D kernel from 1D radial kernel
r = abs(coords)
kern = interp1d(arange(L//2),rk[L//2:],
fill_value=0,bounds_error=0)(r)
return fftshift(kern)
# Make symmetric kernel from autocorrelation of ideal grid
acgrd = fft_acorr(ideal_hex_grid(L,P))
kernR = radial_kernel(radial_average(acgrd))
# Restrict kernel to local neighborhood and normalize
window = abs(coords)<P*sqrt(2)
kern0 = blur(kernR*fftshift(window),P/pi)
kern0 = kern0/np.max(kern0)
subplot(131); showim(acgrd,'Ideal Autocorrelation',domask=False);
subplot(132); showkn(kernR,'Radial Kernel');
subplot(133); showkn(kern0,'Windowed');
We adapt the kernel to the observed statistics of the spike count data by scaling the zero-lag peak in the kernel to match an estimate of the variance in the rate.
The zero-lag autocorrelation of the data reflects the sum of the true variance in the underlying rates, plus the average measurement noise.
To remove the contribution from the measurement noise, we estimate the zero-lag variance by fitting a quadratic polynomial to the correlation at nearby, nonzero lags.
This prior encodes the assumption that the observed spike counts have a periodic underlying structure, and leads to better recovery of the grid fields.
def zerolag(ac,r=3):
# Estimate true zero-lag variance via quadratic interpolation.
z = array(ac[L//2-r:L//2+r+1])
v = arange(r*2+1)
return polyfit(v[v!=r],z[v!=r],2)@[r**2,r,1]
# Estimate zero-lag variance and scale kernel
acorrR1 = radial_acorr(regλ(N,K))
acorrR2 = copy(acorrR1)
v0 = zerolag(acorrR1)
kern = kern0*v0
acorrR2[L//2] = v0
ε0 = mean((K/N)[N>0])
λGP2 = solveGP(kern,y,N.ravel()/ε0)
printstats(λ0,λGP2,'GP with periodic kernel')
subplot(121)
axhline(0,color='k',lw=.8)
plot(acorrR1[L//2:],label='Autocorrelation')
plot(kern[0,:L//2] ,label='Kernel')
xticks([0]); xlabel('Distance'); xlim(0,L//4)
yticks([0]); ylabel('Correlation'); ylim(ylim()[0],v0*4)
[gca().spines[s].set_visible(0) for s in ['top','right','bottom']];
legend(); title('Height Calibration')
subplot(122); showim(λGP2,t='Posterior Rate');
GP with periodic kernel: ∙ Normalized MSE: 27.4% ∙ Pearson correlation: 0.79
Heuristic approximation of Poisson noise
Neuronal spiking is typically treated as conditionally Poisson, which means its variance should be proportional to the firing rate. Let's explore a heuristic way to incorporate a Poisson noise assumption into our GP regressions. Earlier, we discussed how the gamma distribution could serve as a conjugate prior for Poisson count data. We can also use a gamma distribution to model the measurement uncertainty from a collection of Poisson observations, and incorporate this model of uncertainty into our GP regression.
The variance of a
Performance for this model of the error is mixed: it can work better than assuming constant error when data are limited, but sometimes performs worse than simply assuming uniform variance equal to the neuron's average firing rate. We discuss a more principled way to handle Poisson noise in the next section.
# Use estimated rate as measurement error variance
ve = kdeλ(N,K,fgσ,ρ=1,γ=.5)
y = nan_to_num(K/N)
λGP3 = solveGP(kern,y,(N/ve).ravel())
printstats(λ0,λGP3,'GP')
subplot(121); showim(1/ve,q2=95,t='Precision ($1/\sigma^2_\epsilon$) Estimate');
subplot(122); showim(λGP3,'Posterior Rate');
GP: ∙ Normalized MSE: 48.6% ∙ Pearson correlation: 0.73
We can get an even better model of the data by fitting a log-Gaussian Cox process model to the binned count observations. This places a Gaussian process prior on the logarithm of the intensity,
Above,
Recall that the probability of observing spike count
We work with log-probability for numerical stability. The log-probability of observing spike count
We estimate a posterior distribution on
The maximum a posteriori estimate
We find
The negative log-posterior
We bin the data into
The negative log-posterior can then be written as:
Written as a sum over bins like this,
We find the MAP by minimizing the above as a function of the log-rate weights. In our experiments, Scipy's general-purpose minimize routines performed poorly, either crashing or failing to terminate. Scipy's conjugate gradient method performed the best, but achieved poor error tolerance. Instead, we can build our own Newton-Raphson solver.
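Before building the solver, it is handy to have the objective itself available, e.g. to monitor convergence. This is a sketch of the un-normalized negative log-posterior; it assumes the globals n, y, kift, and conv defined in the code below, and drops constants that do not depend on the weights:
def neg_log_posterior(w):
    # Sketch: un-normalized negative log-posterior for the LGCP model.
    # Assumes globals n (visits), y (K/N), kift (FFT of inverse prior kernel),
    # and conv() defined later in this notebook.
    w = w.ravel()
    prior      = 0.5*w@conv(w,kift).ravel()       # ½ wᵀ Σ₀⁻¹ w, evaluated via FFT
    likelihood = sum(n*(exp(w) - w*y.ravel()))    # Σ n·e^w − K·w   (since K = n·y)
    return prior + likelihood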
Finding the maximum a posteriori using Newton-Raphson
Newton-Raphson solves a linear system on each iteration. Each iteration takes about the same amount of time as solving a single GP regression problem.
(Indeed, one can view each stage of Newton-Raphson as its own GP regression problem. This is the idea behind the Iteratively Reweighted Least Squares (IRLS) approach to fitting Generalized Linear Models (GLMs). The Gaussian process model used here can be viewed as a Poisson GLM with the GP prior acting as a regularizer. Lieven Clement has a good introduction on IRLS.)
Each iteration of Newton-Raphson updates the parameters as
where
To apply Newton-Raphson we need to calculate the Hessian matrix and Jacobian vector. We can express these as a sum of a contribution from the log-prior and log-likelihood.
The negative log-prior is
For the negative log-likelihood
These can be written in vector form as
where
The Jacobian and Hessian can be written as:
The Newton-Raphson update is then given by
Note: (pre) conditioning
As in GP regression, this problem can be numerically unstable if
However, when minres
.
Note: when to use a separate bias term?
Sometimes, you might want to separate out the average log-rate, and parameterize the LGCP as
Since GP regression is linear, the average firing rate is the maximum likelihood estimate of the average of the posterior mean, and it suffices to subtract the average rate before inferring the firing rate map. The average log-rate is less straightforward to estimate in LGCP regression, and it must be inferred along with the weights during optimization.
Here, we limit small eigenvalues of
However, in other applications
Note: Iteratively Reweighted Least-Squares (IRLS)
The Iteratively Reweighted Least-Squares (IRLS) approach recasts the Newton-Raphson iteration as solving a new GP regression problem. Rewrite the Newton-Raphson iteration as:
Recall the formula for the GP posterior is
This confirms that estimating the LGCP posterior has similar complexity to GP regression.
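For concreteness, here is a minimal sketch of one such IRLS step, using the same FFT-based operators as the solver below. It assumes the globals n, y, kift, conv, L, LinearOperator, and minres from this notebook; the working-response rearrangement is a standard identity, not code from the original post.
def irls_step(w, mtol=1e-5):
    # One IRLS / Newton step for the LGCP model, phrased as a GP-style solve
    # on a working response. Assumes globals n, y, kift, L, conv, minres,
    # and LinearOperator defined elsewhere in this notebook.
    λ   = exp(w)
    τ   = n*λ                                    # effective per-bin precision
    rhs = τ*w - n*(λ - y.ravel())                # equals τ·(working response)
    Hv  = lambda u: conv(u,kift).ravel() + τ*u   # (Σ₀⁻¹ + diag(τ))·u via FFT
    H   = LinearOperator((L**2,)*2, Hv, Hv, dtype=np.float64)
    return minres(H, rhs, tol=mtol)[0]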
Note: initializing a prior for log-Gaussian inference
So far, we have defined a Poisson observation model and log-posterior loss function for a log-Gaussian point-process model of the grid cell. We also need to initialize a sensible prior for the weights
# define a "safe log" function
minrate = 1e-2
slog = lambda x:log(maximum(minrate,x))
# Precompute variables; Passed as globals to jac/hess
n = N.ravel()
y = nan_to_num(K/N)
lλh = slog(kdeλ(N,K,fgσ))
kern = kern0*zerolag(radial_acorr(lλh))
kern = repair_small_eigenvalues(kern,1e-5)
knft = fft2(kern)
kift = 1/knft
kift[0,0]=0
# preconditioner given by prior covariance
Mv = lambda v:conv(v,knft).ravel()
M = LinearOperator((L**2,)*2,Mv,Mv,dtype=np.float32)
def jacobian(w):
J0 = conv(w,kift).ravel()
Jl = n*(exp(w)-y.ravel())
return J0+Jl
def hessian(w):
# Hessian as linear operator to use with minres
nλ = n*exp(w)
Hv = lambda u:conv(u,kift).ravel()+u*nλ
return LinearOperator((L**2,)*2,Hv,Hv,dtype=np.float64)
def newton_raphson(lλh,J,H,tol=1e-3,mtol=1e-5):
u = lλh.ravel()
for i in range(10):
Δ = -minres(H(u),J(u),tol=mtol,M=M)[0]
u += Δ
if max(abs(Δ))<tol: return u
print('Iteration did not converge')
return u
w1 = newton_raphson(lλh,jacobian,hessian)
LGCP1 = w1.reshape(L,L)
printstats(slog(λ0), LGCP1,'LGCP, log-rate')
subplot(131); showkn(kern ,'Kernel');
subplot(132); showim(y ,'Observations');
subplot(133); showim(LGCP1,'Log-Rate');
LGCP, log-rate: ∙ Normalized MSE: 2.1% ∙ Pearson correlation: 0.75
λhat = kdeλ(N,K,fgσ) # Foreground rate
λbg = kdeλ(N,K,bgσ) # Background rate
lλh = slog(λhat) # Log rate
lλb = slog(λbg) # Log background
# Precompute variables; Passed as globals to jac/hess
kern = kern0*zerolag(radial_acorr(lλh-lλb))
kern = repair_small_eigenvalues(kern,1e-5)
knft = fft2(kern)
kift = 1.0/knft
Mv = lambda v:conv(v,knft).ravel()
M = LinearOperator((L**2,)*2,Mv,Mv,dtype=np.float32)
def jacobian(w):
J0 = conv(w,kift).ravel()
Jl = n*(exp(w+lλb.ravel())-y.ravel())
return J0+Jl
def hessian(w):
nλ = n*exp(w+lλb.ravel())
Hv = lambda u:conv(u,kift).ravel()+u*nλ
return LinearOperator((L**2,)*2,Hv,Hv,dtype=np.float64)
# Fit model and unpack result
w2 = newton_raphson(lλh-lλb,jacobian,hessian)
LGCP2 = w2.reshape(L,L) + lλb
printstats(slog(λ0),LGCP2,'LGCP, log-rate')
subplot(131); showkn(kern,'Kernel');
subplot(132); showim(y,'Observations');
subplot(133); showim(w2,'Normalized Log-Rate');
LGCP, log-rate: ∙ Normalized MSE: 2.4% ∙ Pearson correlation: 0.73
Convolution approximation
We can also approximate the log-Gaussian model as a convolution. This amounts to considering only the first iteration of Newton-Raphson as if it were a GP regression problem, and replacing the per-bin measurement noise with its average. Consider a single iteration of the Newton-Raphson iteration, for the weights alone. This is:
where
This implies an approximate solution in terms of two convolutions:
where
def LGCP_convolutional(N,K,fgσ,bgσ,kern,pad):
# Evaluate via convolution
kern = repair_small_eigenvalues(kern)
y = mirrorpad(nan_to_num(K/N),pad)
λ = kdeλ(N,K,fgσ)
lb = slog(kdeλ(N,K,bgσ))
w = slog(λ)-lb
β = mean(w[N>0])
w = w-β
c = mean(1/(N*λ)[N>0])
Σf = fft2(kern)
Gf = c/(c+Σf)
w -= conv(w+conv(N*(λ-y),Σf),Gf)
return w+β+lb
LGCP3 = LGCP_convolutional(N,K,fgσ,bgσ,kern,pad)
printstats(LGCP2,LGCP3,'Error between Newton-Raphson and convolution')
subplot(121); showim(LGCP2-lλb,'LGCP Log-Rate')
subplot(122); showim(LGCP3-lλb,'Convolution');
Error between Newton-Raphson and convolution: ∙ Normalized MSE: 0.6% ∙ Pearson correlation: 0.97
- Introduction:
- Reviewed Bayesian inference for Gaussian processes
- Simulated spiking observations from a hippocampal grid cell
- Discussed estimators of spike rate
- Smoothing approaches:
- Inferred grid rate maps using KDE
- Illustrated smoothing using GP regression
- Showed that GP smoothing is similar to KDE
- Developed a periodic kernel for inferring grid maps with GP regression
- Log-Gaussian Cox processes (LGCP)
- Introduced log-Gaussian Cox Processes for inferring rate maps
- Derived the Jacobian and Hessian for the log-posterior of the LGCP
- Solved for the MAP estimator of the LGCP using Newton-Raphson
- Illustrated that the LGCP MAP estimate can be approximated via convolution
- Showed how to subtract background rate variations in the LGCP model
Almost done!
Estimating confidence intervals around peaks
One question that might be nagging you at this point is: should we believe the inferred rate maps? When we see a peak (a "grid field"), is this real, or just a noisy fluctuation? For this, it is useful to generate some sort of confidence bounds or other summary of uncertainty in the inferred grid map.
First, let's find our peaks
def findpeaks(q,th=-inf,r=1):
# Local maxima > th in square neighborhood radius r.
L = q.shape[0]
D = 2*r
Δ = range(D+1)
q0 = q[r:-r,r:-r,...]
p = q0>th
for i,j in {(i,j) for i in Δ for j in Δ if i!=r or j!=r}:
p &= q0>=q[i:L+i-D,j:L+j-D,...]
p2 = zeros(q.shape,bool)
p2[r:-r,r:-r,...] = p
return p2
pxy = array(where((findpeaks(LGCP2)*mask).T))
figure(figsize=(4,3));
showim(w2,'Peaks');
scatter(*pxy,s=5,facecolor='k',edgecolor='w',lw=0.4);
Both GP regression and the LGCP model provide estimates of the posterior covariance, which encodes the uncertainty in our posterior mean (or mode).
For GP regression, the covariance is
For the LGCP model, we can use the Laplace approximation to model the uncertainty in our MAP estimate
Intuitively, directions with higher curvature in our loss function are more constrained, and so have lower posterior variance. Conversely, unconstrained directions have low curvature, and therefore large posterior variance.
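As a toy illustration of this idea (with made-up numbers, unrelated to the grid-cell model), consider a single Poisson count with a Gaussian prior on its log-rate; the Laplace approximation takes the posterior variance to be the inverse curvature of the negative log-posterior at its mode:
# 1D Laplace approximation sketch (illustrative values): Poisson count k with
# a Gaussian prior on the log-rate w; posterior std ≈ 1/sqrt(curvature at mode).
from scipy.optimize import minimize_scalar
k, σ0 = 7, 1.0
nlp   = lambda w: exp(w) - k*w + w**2/(2*σ0**2)   # negative log-posterior (up to constants)
wmap  = minimize_scalar(nlp).x                    # posterior mode
curv  = exp(wmap) + 1/σ0**2                       # second derivative at the mode
print('mode %.2f, Laplace std %.2f'%(wmap, curv**-0.5))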
Incidentally, the curvature of the observation likelihood in the LGCP model,
In these notes, GP regression inferred a posterior distribution on
Using the Laplace approximation to calculate uncertainty in peak location
The posterior covariance describes a distribution over different possible rate maps. We can denote this as the posterior mean (or mode)
The fluctuations
If
One can calculate
Above,
We are interested in summarizing the overall uncertainty in the location of a peak. This is captured by the covariance
The term
Building a low-rank model of the posterior variance
A full posterior covariance over all L² bins is an L² × L² matrix, which is too large to compute with directly.
The solution is to approximate the posterior covariance in a low-rank subspace
where
Selecting a random basis is common. In our case, we use the discrete cosine transform of the 2D spatial domain. This can be computed rapidly using the FFT. Additionally, the prior
def mirror(x):
# Mirror LxL data up to 2L+1 x 2L+1
x = x.reshape(L,L)
return block([[x,fliplr(x[:,1:])],[flipud(x[1:,:]),fliplr(flipud(x[1:,1:]))]])
def padout(kern):
# Zero-pad LxL kernel up to 2L+1 x 2L+1
k2 = zeros((L*2-1,L*2-1))
k2[L//2:L//2+L,L//2:L//2+L] = fftshift(kern)
return fftshift(k2)
# Why this DCT implementation?
# - This implementation can be used directly to evaluate convolution with reflected
# boundary conditions via pointwise multiplication (convolution theorem)
# - It's based on the FFT of real symmetric data, so the data packing and interpretation
# of the coefficient matrix is the same as that of an FFT of twice the size
# - The eigenvalues are real-valued, so they can be used directly with
# linear algebra routines that require real-valued input
normalization = 1/(L*2+1)
def dct2v(x):
# DCT Option 1: reflect data to create symmetry
x = x.reshape(L,L)
return real(fft2(mirror(x)))[:L,:L]*normalization
def dct2k(k):
# DCT Option 2: if kernel already symmetric, zero pad
return real(fft2(padout(k.reshape(L,L))))[:L,:L]
def idct2(x):
# Inverse DCT
return real(fft2(mirror(x)))[:L,:L]*normalization
def dctconv(v,kct):
# Apply convolution operator via DCT
xct = dct2v(v)
return idct2(xct*kct).ravel()
# DCT inverse should work and DCT
# and FFT convolution should be similar
x = randn(L,L)
x1 = idct2(dct2v(x))
printstats(x,x1,'DCT inverse')
x1 = conv(x,fft2(kern))
x2 = dctconv(x,dct2k(kern))
printstats(x1,x2,'DCT convolution')
DCT inverse: ∙ Normalized MSE: 0.0% ∙ Pearson correlation: 1.00 DCT convolution: ∙ Normalized MSE: 1.0% ∙ Pearson correlation: 0.99
# Low-rank approximation in frequency space using the DCT
from scipy.sparse import coo_matrix
keig = abs(dct2k(kern0))
print('minimum eigenvalue magnitude %e'%np.min(keig))
print('maximum eigenvalue magnitude %e'%np.max(keig))
mine = 0.005*np.max(keig)
use2 = keig>=mine
use1 = any(use2,0)
use3 = use2[:,use1][use1,:].ravel()
M2 = sum(use2)
M1 = sum(use1)
down = coo_matrix(eye(L*L)[use2.ravel()])
print('Using %d components'%M2)
minimum eigenvalue magnitude 0.000000e+00 maximum eigenvalue magnitude 1.251435e+02 Using 931 components
def dct2lr(v):
# send vector into low-rank representation
if np.all(v==0): return zeros(M2)
v = v.reshape(L,L)
for i in range(2):
v = block([v,fliplr(v[:,1:])])
v = real(fft(v)).T[:L][use1]*normalization
return v.ravel()[use3]
def idct2lr(u):
# expand vector from subspace
u = u.ravel()
return idct2(down.T@u).ravel()
def dct2Alr(A):
# collapse L²×L² matrix to subspace
A = array([dct2lr(a) for a in A.reshape(L*L,L,L)]).T
A = array([dct2lr(a) for a in A.reshape(M2,L,L)]).T
return A
# Expand matrix on left side from subspace
def idct2Alr_left(A):
# Expand compressed representation on the left
# A is MxM, we return NxM, N=L*L
return array([idct2lr(a) for a in A.T]).T
def dct2klr(k):
# Restrict kernel eigenvalues to the retained low-rank components
return dct2k(k)[use2].ravel()
def dctconvlr(v,klr):
# Apply convolution operator via DCT
xlr = dct2lr(v)
return idct2lr(xlr*klr).ravel()
# low-rank DCT inverse should work and low-rank DCT
# and FFT convolution should be similar if input is
# well-approximated by low-rank.
x = randn(L,L)
x1 = conv(x,fft2(kern))
x2 = dctconvlr(x,dct2klr(kern))
printstats(x1,x2,'low-rank convolution')
print(mean(x1),mean(x2))
x1 = idct2lr(dct2lr(x2))
printstats(x1,x2,'low-rank inverse')
print(mean(x1),mean(x2))
low-rank convolution: ∙ Normalized MSE: 25956.5% ∙ Pearson correlation: 1.00 -0.04962543031971363 -0.00018115893491703268 low-rank inverse: ∙ Normalized MSE: 25905.1% ∙ Pearson correlation: 1.00 -6.939700847602088e-07 -0.00018115893491703268
from scipy.linalg import cholesky as chol
from scipy.linalg.lapack import dtrtri
from scipy.linalg import solve_triangular as stri
# Calculate low-rank posterior covariance
klr = maximum(dct2klr(kern),1e-5)
kilr = 1/klr
v = eye(L)
v = block([v,fliplr(v[:,1:])])
v = real(fft(v)).T[:L][use1]*normalization
G = einsum('ml,ML->mMlL',v,v).reshape(M1*M1,L*L)[use3]
Hlr = diag(kilr*(L*2+1)**-2) + (G*(n*exp(LGCP2).ravel()))@G.T
Clr = chol(Hlr)
Dlr = dtrtri(Clr)[0]
Qlr = G.T@Dlr
Armed with this low-rank approximation, we can now calculate
The derivatives
def dx_op(L):
# 2D difference operator in the 1st coordinate
dx = zeros((L,L))
dx[0, 1]=-.5
dx[0,-1]= .5
return dx
def hessian_2D(q):
# Get Hessian at all points
dx = dx_op(q.shape[0])
f1 = fft2(dx)
f2 = fft2(dx.T)
d11 = conv(q,f1*f1)
d12 = conv(q,f2*f1)
d22 = conv(q,f2*f2)
return array([[d11,d12],[d12,d22]]).transpose(2,3,0,1)
q = w2.reshape(L,L)
dx = dx_op(L)
Hx = hessian_2D(q)
Dx = det(Hx)
Once we have calculated
from scipy.stats import chi2
def covariance_crosshairs(S,p=0.8):
# Generate a collection of (x,y) lines denoting the confidence
# bound for p fraction of data from 2D covariance matrix S
sigma = sqrt(chi2.isf(1-p,df=2))
e,v = eigh(S)
lines = list(exp(1j*linspace(0,2*pi,181)))
lines += [nan]+list( linspace(-1,-.2,5))
lines += [nan]+list(1j*linspace(-1,-.2,5))
lines += [nan]+list( linspace(.2,.95,5))
lines += [nan]+list(1j*linspace(.2,.95,5))
lines = array(lines)
lines = array([lines.real,lines.imag])*sigma*(e**0.5)[:,None]
return solve(v,lines)
Overall, this enables reasonably fast approximate confidence intervals on the peak locations.
def cinv(X,repair=False):
# Invert matrix via Cholesky factorization
ch = chol(X)
ich = dtrtri(ch)[0]
return ich.dot(ich.T)
def csolve(H,J):
# Solve PSD linear system x = H^{-1}J via Cholesky factorization
C = chol(H)
return stri(C,stri(C.T,J,lower=True))
def plot_peakbounds(pxy,P):
# Plot confidence ellipses around peak locations (uses globals dx, Hx, G, Dlr)
lx,ly = [],[]
for x2,x1 in pxy.T:
# Jacobian at x0
Δx1 = roll(dx ,(x1,x2),(0,1))
Δx2 = roll(dx.T,(x1,x2),(0,1))
J = array([Δx1,Δx2]).reshape(2,L**2)
# Peak location confidence
ΣxJD = csolve(-Hx[x1,x2],J@G.T@Dlr)
Σx0 = ΣxJD@ΣxJD.T
# Plot if peak is acceptably localized
if max(eigh(Σx0)[0])<P*2:
cx,cy = covariance_crosshairs(Σx0,p=0.9)
lx += [nan] + list(cx+x2)
ly += [nan] + list(cy+x1)
plot(lx,ly,color='w',lw=1.6)
plot(lx,ly,color='k',lw=0.4)
axis('off'); title('90% Confidence');
xlim(0,L); ylim(0,L)
figure(figsize=(4,3));
showim(q);
plot_peakbounds(pxy,P);

softmask = blur(mask,5,normalize=True)
def peak_density(w,Niter=1000):
# w: posterior mean or mode vector
# Samples drawn from the low-rank Laplace posterior using the factor Qlr (global)
q = Qlr@randn(M2,Niter)
q = (q+w.ravel()[:,None]).reshape(L,L,Niter)
q = (q-mean(mean(q,2)[mask]))*softmask[:,:,None]
peaks = findpeaks(q,th=std(q))
dnsty = mean(peaks,axis=2) + 1/Niter
μhght = nan_to_num(sum(q*peaks,2)/sum(peaks,2))
return dnsty,μhght
dnsty,μhght = peak_density(w2)
subplot(121); showim((dnsty),'Peak Density')
colorbar(label='$\log\,\Pr(\operatorname{peak})$')
subplot(122); showim(μhght*dnsty,'Height$\cdot\Pr$(peak)')
tight_layout()
In sum, we have illustrated the following workflow for analyzing firing rate maps for hippocampal grid cells:
- Bin spikes into a histogram of the total number of spikes and visits to each region
- Use autocorrelation to estimate the grid scale
- Use an idealized grid to set a prior for the log rate
- Infer log-rate using kernel density estimation (KDE)
- Use this KDE estimate as initializer for LGCP regression
- Heuristically fit a log-Gaussian Cox process model using convolution
- Identify grid field centers as local maxima in the inferred rate map
- Use the Laplace approximation to fit confidence regions for grid field centers
- Sample from the GP posterior to estimate the probability of a grid field center in each region
Histogram-based estimators are easy, and work if low spatial resolution is acceptable. However, they need a lot of data to return a meaningful result. It's also hard to define notions of confidence for histograms.
Kernel density estimators (KDEs) pool data from nearby regions. They are more efficient than histograms, especially if one optimizes the kernel bandwidth to match the underlying variations in neuronal tuning. For moderate amounts of data, KDE estimators are fast, and return a reasonable estimate of the rate map.
Gaussian process (GP) regression is more statistically efficient than KDE, and also estimates posterior covariance. This enables one to estimate confidence bounds for the inferred rate maps. GP regression requires solving a large linear system, but this can be accelerated using the minimum residual algorithm and calculating matrix products using the FFT. Sometimes, GP regression can be approximated by a convolution.
Log-Gaussian Cox process (LGCP) regression is a generalization of GP regression. It infers the log-firing rate under a Poisson noise assumption, and returns good estimates from limited data. The computational cost of LGCP regression is only slightly higher than GP regression.
Overall, we illustrated several approaches to applying Gaussian-process regression to hippocampal grid cell data. We covered kernel density estimation, Gaussian process regression, and log-Gaussian Cox process regression. Throughout, we discussed practical issues necessary to achieve good performance, like using the FFT when possible, choosing numerically stable forms of the equations, and fast approximations based on convolution. Ultimately, we derived an FFT-based approximation to log-Gaussian Cox process regression. This provides an approach to analyzing grid cell data that is both statistically and computationally efficient.
Matrix inversion and linear system solving using Cholesky factorization
If a covariance matrix is non-singular (has no zero eigenvalues), then it is positive definite. This means it has a Cholesky factorization. The default behavior of scipy.linalg.cholesky is to return an upper-triangular matrix U such that X = UᵀU.
The routine scipy.linalg.lapack.dtrtri
will invert an upper-triangular matrix quickly. We can leverage this to invert the full matrix: if X = UᵀU, then X⁻¹ = U⁻¹U⁻ᵀ.
We can also use Cholesky factorization to quickly solve the linear system
The routine scipy.linalg.solve_triangular can solve triangular linear systems quickly. Passing the lower=True argument to solve_triangular tells it the factor is lower-triangular, which lets us solve against the transposed Cholesky factor and then the factor itself to calculate the full solution in two triangular solves.
Cholesky decomposition and inverting a triangular matrix both have cubic time complexity, but with a smaller constant factor than general-purpose matrix inversion.
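As a quick check, the cinv and csolve helpers defined earlier can be compared against general-purpose routines on a random positive-definite matrix (illustrative only):
# Check cinv/csolve (defined earlier) against general-purpose inverse/solve
# on a random positive-definite matrix.
A = randn(50,50)
X = A@A.T + 50*eye(50)                   # random positive-definite matrix
b = randn(50)
print(allclose(cinv(X), inv(X)))         # matrix inverse via Cholesky
print(allclose(csolve(X,b), solve(X,b))) # linear solve via two triangular solves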
Multiplication using the Fast Fourier Transform (FFT)
Typically, our prior covariance kernel
The Fourier transform coefficients
The inverse
The product
where
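The sketch below checks this equivalence numerically on a small grid, using the blurkernel, kernel_to_covariance, and conv helpers defined earlier (a 16×16 kernel keeps the dense matrix manageable):
# Check that multiplying by the dense block-circulant covariance matrix
# matches FFT-based circular convolution with its kernel (small grid only,
# since the dense matrix is Ls²×Ls²).
Ls = 16
ks = outer(blurkernel(Ls,2),blurkernel(Ls,2))   # small symmetric 2D kernel
Σs = kernel_to_covariance(ks)                   # dense Ls²×Ls² covariance
v  = randn(Ls*Ls)
u1 = Σs@v                                       # dense matrix-vector product
u2 = conv(v,fft2(ks)).ravel()                   # same product via the FFT
print(allclose(u1,u2))                          # True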
Form (a)
In most textbooks or tutorials you'll see the posterior mean written as:
where
When observations are sparse, this form is computationally efficient, because
However, in our application we have an extended time-series where a rat visits each location many times. There are many more observations than output points, and
where
This form has the following useful properties:
and can be singular, provided is not. In GP regression, this allows measurements with zero error, and also allows priors with zero eigenvalues. For example, one might set a prior that has zeros for high-frequency components, to encode a strong assumption that the posterior function should be smooth. The update is low-rank and fast when observations are limited
If
is nonsingular, then it is positive definite. This allows one to compute quickly via Cholesky factorization. (However, it is even faster to use a Krylov subspace solver, if the prior is in a form that supports fast matrix-vector product via the FFT.) The final matrix multiplications by
can be computed via FFT. (The data need to be zero padded to do this, but this padding can be stripped away after performing the convolution, so there is no added complexity cost. Contrast this to form b, in which the zero-padding must be retained before solving a linear system, which increases the complexity).
Problem: This form is unsuitable if
Form (b)
It's also common to encounter the following form, when deriving the posterior mean for the product of two multivariate Gaussian distributions:
This form has the following useful properties
-
can be singular, so we can include "null" observations when calculating the regression over a regular grid -
is diagonal, so , which is trivial to compute - If
is nonsingular, then exists and can be computed quickly via the FFT -
is positive definite. can therefore be solved efficiently using Cholesky factorization. - Priors that assume correlations arise from nearest-neighbor interactions can be represented as a prior precision
that is nonzero only for pairs of adjacent regions. This sparsity, when combined with Krylov subspace algorithms for solving the linear system, allows for fast solutions on arbitrary topologies, for which spectral methods might not be possible.
Problem: This form is unsuitable if
Form (c)
We can also pull out
This form has the following useful properties:
- Both
and can be singular -
is trivial, and can be calculated via FFT - The product
is trivial. -
is well conditioned, so this calculation is fairly numerically stable, requiring less regularization. - If
contains many zero eigenvalues, then it is low rank. In the special case that is circulant, the FFT offers a fast conversion to/from the eigenspace of , and it is possible to use the nonzero Fourier components of as a low-rank basis for calculations.
Problem: The matrix
Additionally, for translationally-invariant kernels on a regular grid, the matrix-vector product
Form (d)
In the special case that the measurement error covariance is constant,
where
Problem: The regression must be on a periodic domain in order to compute
Note: if
Use form (a) when
- Measurement variance
and prior covariance are well defined. - Observations are sparse compared to the number of output points.
- This is as fast or faster than (b,c) for any size problem, but the extra matrix multiplications can be expensive for large systems, or when observations are dense.
Use form (b) when
- The measurements and outputs are evaluated at the same set of points.
-
is well defined, but is not. -
exists and well-conditioned. - For medium-sized systems,
-
can be solved via Cholesky factorization.
-
- For large systems, use a Krylov subspace solver and leverage special properties of
:-
is circulant, so we can use the FFT, or -
is sparse, arising from a nearest-neighbor model.
-
Use form (c) when
- The measurements and outputs are taken at the same set of points
-
is well defined, but is not. -
is low-rank so does not exist. - For large systems, use a Krylov subspace solver and leverage special properties of
:- If
is circulant, use the FFT - Or, use a low-rank approximation
- If
Use form (d) when
- The measurements and outputs are evaluated on a regular grid.
- The problem is too large to calculate using ordinary matrix operations
- Measurement error can be approximated as constant
- Artifacts from zero or mirrored boundary conditions are acceptable, or the kernel is local and it is acceptable to discard the boundary regions.