The Matrix Square Root and its Gradient

Author: Subhransu Maji

# Background

In the improved B-CNN paper published at BMVC 2017, we showed that the matrix square root is an effective way to normalize covariance matrices used for classification tasks. The square root and its gradient can be computed via an SVD, but the SVD is not efficiently implemented in existing GPU libraries. The BMVC paper therefore presented some GPU-friendly routines for computing the matrix square root and its gradient. Here we discuss two extensions that allow simpler and faster gradients: automatic differentiation through the square root iterations, and iterative methods for solving the Lyapunov equation.

Given a positive semidefinite (PSD) matrix $A$, its square root is the matrix $Z$ such that $ZZ=A$. It can be computed by first computing the SVD of the matrix, $A=U\Sigma U^T$, after which the square root is $Z=U\Sigma^{1/2}U^T$. However, SVD currently has limited GPU support, so instead we proposed to use Newton-Schulz iterations after scaling the matrix. The iterations converge when $\|I - A\| < 1$, which holds for a positive definite $A$ scaled so that $\|A\| \le 1$. The square root can then be obtained by initializing $Y_0=A$ and $Z_0=I$ and iterating: $$Y_{k+1} = \frac{1}{2}Y_k(3I - Z_k Y_k), \quad Z_{k+1} = \frac{1}{2}(3I - Z_k Y_k)Z_k.$$ Then $Y_k \rightarrow A^{1/2}$ and $Z_k \rightarrow A^{-1/2}$. We found that 5 iterations were sufficient to match the classification accuracy obtained with the exact (SVD-based) square root, and 10 iterations were sufficient to match the square root itself to floating point precision on a GPU. Depending on the GPU and software, this can be 5-20x faster than computing the SVD, especially since the iterations can be implemented in batch mode in languages like PyTorch.
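For reference, the SVD route can be sketched in a few lines of PyTorch. Because $A$ is symmetric PSD, its SVD coincides with its eigendecomposition, so this minimal sketch assumes PyTorch 1.8 or later and uses `torch.linalg.eigh`. The name `sqrt_eig` and the clamping of tiny negative eigenvalues are illustrative choices, not part of the original code.

```python
import torch


def sqrt_eig(A):
    """Hypothetical reference square root of a batch of symmetric PSD matrices.

    A has shape (batchSize, dim, dim); returns Z with Z @ Z ~= A.
    """
    # For symmetric PSD A, A = U diag(s) U^T is both its SVD and its eigendecomposition
    s, U = torch.linalg.eigh(A)
    # Round-off can produce tiny negative eigenvalues; clamp them before the sqrt
    s = s.clamp(min=0).sqrt()
    # Z = U diag(s^{1/2}) U^T
    return U @ torch.diag_embed(s) @ U.transpose(-2, -1)


# Usage: build a batch of random PSD matrices and check that Z Z reproduces A
torch.manual_seed(0)
X = torch.randn(4, 6, 6, dtype=torch.float64)
A = X @ X.transpose(-2, -1)
Z = sqrt_eig(A)
print(torch.dist(Z @ Z, A).item())  # should be on the order of round-off
```

The Newton-Schulz iterations avoid exactly this decomposition. The sketch is still handy as a ground truth when checking how many iterations are needed.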

Training requires computing gradients. In the paper we obtained them by solving a Lyapunov equation: $$A^{1/2} \left( \frac{\partial L}{\partial A}\right) + \left( \frac{\partial L}{\partial A}\right) A^{1/2} = \frac{\partial L}{\partial Z}.$$ Given the SVD of $A$ and the gradient ${\partial L}/{\partial Z}$, the Lyapunov equation can be solved in closed form. The Lyapunov gradients are also more numerically stable than matrix backpropagation gradients through the SVD (e.g., Ionescu et al.). The former involve terms of the form $1/(\lambda_i + \lambda_j)$, where $\lambda_i$ is the $i^{th}$ eigenvalue of $Z$, while the latter involve terms of the form $1/(\lambda_i - \lambda_j)$, which blow up when two eigenvalues are nearly equal. Because the closed-form solution requires the SVD, during training the SVD is used for both the forward and backward passes, while the iterative method can be used for the forward pass at test time. This makes training slower. While training speed is less critical than testing speed, it would be nice to also be able to train faster. Below are two approaches that compute the gradients efficiently.
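The closed-form solution can be sketched as follows. $Z = A^{1/2}$ shares its eigenvectors with $A$, and its eigenvalues are the square roots of those of $A$, so one decomposition of $A$ suffices. This sketch decomposes $Z$ directly. Writing $Z = U\,\mathrm{diag}(\lambda)\,U^T$, $\tilde{C} = U^T (\partial L/\partial Z)\, U$ and $\tilde{X} = U^T (\partial L/\partial A)\, U$, the equation decouples entrywise into $(\lambda_i + \lambda_j)\tilde{X}_{ij} = \tilde{C}_{ij}$. The helper name `lyap_eig` is hypothetical, and it assumes $A$ is positive definite so that every $\lambda_i + \lambda_j > 0$.

```python
import torch


def lyap_eig(z, dldz):
    """Hypothetical helper: solve z X + X z = dldz for X = dL/dA, given z = A^{1/2}.

    z is a batch of symmetric positive definite matrices (batchSize, dim, dim).
    """
    lam, U = torch.linalg.eigh(z)
    Ut = U.transpose(-2, -1)
    # Rotate the right-hand side into the eigenbasis of z: C~ = U^T (dL/dZ) U
    Ct = Ut @ dldz @ U
    # In that basis the equation reads (lam_i + lam_j) X~_ij = C~_ij
    denom = lam.unsqueeze(-1) + lam.unsqueeze(-2)
    Xt = Ct / denom
    # Rotate back: dL/dA = U X~ U^T
    return U @ Xt @ Ut


# Usage: a random symmetric PD "square root" z and an upstream gradient dL/dZ
torch.manual_seed(0)
B = torch.randn(2, 5, 5, dtype=torch.float64)
z = B @ B.transpose(-2, -1) + torch.eye(5, dtype=torch.float64)
dldz = torch.randn(2, 5, 5, dtype=torch.float64)
dlda = lyap_eig(z, dldz)
residual = z @ dlda + dlda @ z - dldz
print(residual.abs().max().item())  # should be close to zero
```

The denominators $\lambda_i + \lambda_j$ are bounded below by $2\lambda_{\min}$, which is the stability argument above in code form. Nothing here divides by an eigenvalue gap.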

# Gradients by automatic differentiation

The gradients can be obtained automatically if the Newton-Schulz iterations are implemented in a framework that supports automatic differentiation. For example, here is a code snippet in PyTorch. The first few lines scale the matrix by its Frobenius norm, since the iterations are only locally convergent.

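```python
import torch


def compute_error(A, sA):
    # Not defined in the original snippet; one plausible choice: mean relative residual ||A - sA sA||_F / ||A||_F
    normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
    res = A - sA.bmm(sA)
    return (res.mul(res).sum(dim=1).sum(dim=1).sqrt() / normA).mean()


def sqrt_newton_schulz_autograd(A, numIters, dtype):
    # A: (batchSize, dim, dim) PSD matrices; dtype: tensor type matching A, e.g. torch.cuda.FloatTensor
    batchSize = A.shape[0]
    dim = A.shape[1]
    # Scale by the Frobenius norm since the iterations are only locally convergent
    normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()
    Y = A.div(normA.view(batchSize, 1, 1).expand_as(A))
    I = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
    Z = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)

    for i in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)  # Y_k -> (A / ||A||_F)^{1/2}
        Z = T.bmm(Z)  # Z_k -> (A / ||A||_F)^{-1/2}
    # Undo the scaling: A^{1/2} = sqrt(||A||_F) * (A / ||A||_F)^{1/2}
    sA = Y * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
    error = compute_error(A, sA)
    return sA, error
```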


This works quite well. For example, on my GPU with a batch size of 32 and 512×512 random matrices, computing the gradient with 5 iterations takes about 5-10% of the time taken by the forward pass. Thus the gradients are nearly free!
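Besides speed, one can check numerically that autograd through the iterations gives the same gradient as the Lyapunov equation once the iterations have converged. The sketch below is an illustration under stated assumptions, not the repository's code:

* `sqrt_ns` is a compact re-implementation of the snippet above. It runs 30 iterations in double precision so that they converge.
* `lyap_eig` is a hypothetical closed-form solver based on `torch.linalg.eigh`.
* The loss $L = \langle W, A^{1/2}\rangle$ with a random $W$ is chosen only so that $\partial L/\partial Z = W$.

```python
import torch


def sqrt_ns(A, numIters):
    # Compact batched coupled Newton-Schulz square root with Frobenius-norm scaling
    batchSize, dim, _ = A.shape
    normA = A.mul(A).sum(dim=(1, 2)).sqrt().view(batchSize, 1, 1)
    Y = A / normA
    I = torch.eye(dim, dtype=A.dtype, device=A.device).repeat(batchSize, 1, 1)
    Z = I
    for _ in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    return Y * normA.sqrt()


def lyap_eig(z, dldz):
    # Hypothetical closed-form solver of z X + X z = dldz via the eigendecomposition of z
    lam, U = torch.linalg.eigh(z)
    Ut = U.transpose(-2, -1)
    Xt = (Ut @ dldz @ U) / (lam.unsqueeze(-1) + lam.unsqueeze(-2))
    return U @ Xt @ Ut


torch.manual_seed(0)
dim = 5
B = torch.randn(3, dim, dim, dtype=torch.float64)
A = (B @ B.transpose(1, 2) / dim + torch.eye(dim, dtype=torch.float64)).requires_grad_(True)
W = torch.randn(3, dim, dim, dtype=torch.float64)  # plays the role of dL/dZ

# L = <W, sqrt(A)>, so dL/dZ = W; autograd backpropagates through all 30 iterations
loss = (W * sqrt_ns(A, numIters=30)).sum()
(grad_autograd,) = torch.autograd.grad(loss, A)

# Lyapunov gradient computed from the exact square root z = A^{1/2}
with torch.no_grad():
    s, U = torch.linalg.eigh(A)
    z = U @ torch.diag_embed(s.sqrt()) @ U.transpose(1, 2)
    grad_lyap = lyap_eig(z, W)

print((grad_autograd - grad_lyap).abs().max().item())  # expected to be tiny once converged
```

With only a few iterations the two gradients differ. In that regime autograd returns the exact gradient of the truncated iteration rather than that of the true square root.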

# Gradients by iterative Lyapunov solver

The drawback of the autograd approach is that a naive implementation stores all the intermediate results, so the memory overhead grows linearly with the number of iterations, which is problematic for large matrices. In general one can trade off memory for computation during backpropagation. For example, with $T$ iterations one can store the intermediate values $Y_k$ and $Z_k$ only every $\sqrt{T}$ iterations and recompute the remaining ones from the previous checkpoint during backpropagation. This strategy requires $O(\sqrt{T})$ memory but is about $2\times$ slower. With no additional memory at all, one can recompute the intermediate values from the input at each step of the backward pass, but that takes $O(T^2)$ iteration steps instead of $O(T)$. Can we obtain gradients with no memory overhead without taking a hit in speed? It turns out we can, by using an iterative method for solving the Lyapunov equation, implemented in the following code snippet:

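```python
import torch


def lyap_newton_schulz(z, dldz, numIters, dtype):
    # z: batch of square roots A^{1/2}; dldz: upstream gradient dL/dZ, both (batchSize, dim, dim)
    batchSize = z.shape[0]
    dim = z.shape[1]
    # Scale z and dL/dZ by the same factor; the solution X of z X + X z = dL/dZ is unchanged
    normz = z.mul(z).sum(dim=1).sum(dim=1).sqrt()
    a = z.div(normz.view(batchSize, 1, 1).expand_as(z))
    I = torch.eye(dim, dim).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
    q = dldz.div(normz.view(batchSize, 1, 1).expand_as(z))
    for i in range(numIters):
        # Coupled updates (q uses the old a): a -> I and q -> 2 dL/dA
        q = 0.5 * (q.bmm(3.0 * I - a.bmm(a)) - a.transpose(1, 2).bmm(a.transpose(1, 2).bmm(q) - q.bmm(a)))
        a = 0.5 * a.bmm(3.0 * I - a.bmm(a))
    dlda = 0.5 * q
    return dlda
```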
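The coupled updates for `a` and `q` in the snippet above can be read as the Newton-Schulz iteration for the matrix sign function, $M_{k+1} = \frac{1}{2}M_k(3I - M_k^2)$, applied to a block upper-triangular matrix. The sign-function approach to Sylvester and Lyapunov equations is discussed in Higham's book; the derivation below is our own annotation of the snippet rather than a quote from the paper. Write $C = \partial L/\partial Z$ and $X = \partial L/\partial A$, so that $ZX + XZ = C$. Then

$$M_0 = \begin{bmatrix} Z & C \\ 0 & -Z \end{bmatrix} = S \begin{bmatrix} Z & 0 \\ 0 & -Z \end{bmatrix} S^{-1}, \qquad S = \begin{bmatrix} I & -X \\ 0 & I \end{bmatrix},$$

and because $Z$ has positive eigenvalues,

$$\operatorname{sign}(M_0) = S \begin{bmatrix} I & 0 \\ 0 & -I \end{bmatrix} S^{-1} = \begin{bmatrix} I & 2X \\ 0 & -I \end{bmatrix}.$$

Every iterate keeps the block form $M_k = \begin{bmatrix} a_k & q_k \\ 0 & -a_k \end{bmatrix}$, and multiplying out $\frac{1}{2}M_k(3I - M_k^2)$ gives

$$a_{k+1} = \tfrac{1}{2}a_k(3I - a_k^2), \qquad q_{k+1} = \tfrac{1}{2}\left(q_k(3I - a_k^2) - a_k(a_k q_k - q_k a_k)\right).$$

These are the two lines inside the loop, using $a_k^T = a_k$ since every $a_k$ is a polynomial in the symmetric $Z$. Hence $a_k \to I$ and $q_k \to 2X$, which is why the snippet returns `0.5*q`.

Dividing both $Z$ and $C$ by $\|Z\|_F$ leaves $X$ unchanged, because $\operatorname{sign}(cM) = \operatorname{sign}(M)$ for $c > 0$. The scaling only places the eigenvalues of $a_0$ in $(0, 1]$, where the scalar iteration $x \mapsto \frac{1}{2}x(3 - x^2)$ increases monotonically to $1$.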


In practice the iterative solver is as fast as the autograd approach but requires no checkpointing. This is great if both the forward and backward iterations are run to convergence. However, I have not tested what happens in real problems when the method is run for only a few iterations (e.g., only once!). In that case the Lyapunov gradient, which is the gradient of the exact square root, need not match the gradient of the truncated forward iterations. The experiments in the BMVC paper showed that even a single iteration is useful.
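The caveat about truncated iterations can be probed directly, at least for the backward solve. The sketch below assumes random, well-conditioned inputs in double precision rather than a real network. `lyap_ns` re-implements the solver above, except that the identity takes `z`'s dtype and device instead of a `dtype` argument. `lyap_residual` is a hypothetical helper. The script prints the relative residual $\|ZX + XZ - C\|_F / \|C\|_F$ after different numbers of iterations. It only measures how well the Lyapunov equation is solved, not the effect on a downstream task.

```python
import torch


def lyap_ns(z, dldz, numIters):
    # Same coupled iteration as lyap_newton_schulz above; the identity takes z's dtype/device
    batchSize, dim, _ = z.shape
    normz = z.mul(z).sum(dim=(1, 2)).sqrt().view(batchSize, 1, 1)
    a = z / normz
    q = dldz / normz
    I = torch.eye(dim, dtype=z.dtype, device=z.device).repeat(batchSize, 1, 1)
    for _ in range(numIters):
        aT = a.transpose(1, 2)
        q = 0.5 * (q.bmm(3.0 * I - a.bmm(a)) - aT.bmm(aT.bmm(q) - q.bmm(a)))
        a = 0.5 * a.bmm(3.0 * I - a.bmm(a))
    return 0.5 * q


def lyap_residual(z, dldz, dlda):
    # Hypothetical helper: worst-case relative residual ||z X + X z - C||_F / ||C||_F over the batch
    r = z.bmm(dlda) + dlda.bmm(z) - dldz
    return (torch.linalg.matrix_norm(r) / torch.linalg.matrix_norm(dldz)).max().item()


torch.manual_seed(0)
dim = 8
B = torch.randn(4, dim, dim, dtype=torch.float64)
z = B.bmm(B.transpose(1, 2)) / dim + torch.eye(dim, dtype=torch.float64)  # stands in for A^{1/2}
dldz = torch.randn(4, dim, dim, dtype=torch.float64)
for numIters in (1, 5, 10, 20, 30):
    print(numIters, lyap_residual(z, dldz, lyap_ns(z, dldz, numIters)))
```

How many iterations are needed depends mainly on the smallest eigenvalue of $Z/\|Z\|_F$. Small eigenvalues grow only by a factor of about $1.5$ per iteration before the quadratic convergence kicks in.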

Take a look at the source code in MATLAB and Python at https://github.com/msubhransu/matrix-sqrt for a detailed comparison of the methods. The implementation is fairly straightforward using PyTorch's autograd support.

References:

• Ionescu, Catalin, et al. "Matrix backpropagation for deep networks with structured layers." ICCV 2015.
• Lin, Tsung-Yu, and Subhransu Maji. "Improved Bilinear Pooling with CNNs." BMVC 2017.
• Higham, Nicholas J. "Functions of Matrices: Theory and Computation." Society for Industrial and Applied Mathematics, 2008.
• NIPS 2017 Autodiff Workshop: https://autodiff-workshop.github.io