5. Cholesky 分解中下三角矩阵的导数#
创建时间:2021-03-17
这份文档简单地学习 Cholesky 分解中下三角矩阵的导数。这是在实现 RI-SCF/MP2 的解析导数时所遇到的问题。
这篇文档的学习与参考文献与复现对象是 Murray [1]。
import numpy as np
import scipy
5.1. Cholesky 分解回顾#
对于任意实对称正定矩阵 \(\mathbf{S} \in \mathbb{R}^{n \times n}\),其 Cholesky 分解可以通过下式给出:
其中,\(\mathbf{L}\) 是下三角矩阵。
n = 10
S = np.cov(np.random.randn(n, 2 * n))
L = np.linalg.cholesky(S)
np.allclose(L @ L.T, S)
True
我们不讨论 Cholesky 分解的实现方式,只需要知道结论即可。
5.2. Cholesky 分解矩阵 \(\mathbf{L}\) 的数值导数#
现在假定 \(\mathbf{S}\) 是关于外部参量 \(x\) 的函数矩阵,并且 \(\partial_x \mathbf{S}\) 是已知且对称的。在这种情况下,我们希望求取 \(\partial_x \mathbf{L}\)。我们令 \(\partial_x \mathbf{S}\) 的变量名是 dS
。
dS = np.cov(np.random.randn(n, n))
数值导数很容易地通过数值查分方法给出:当数值导数间隔 \(h\) 很小时 (譬如 1e-6),那么下述近似关系成立:
我们令通过上述方法求得的 \(\partial_x \mathbf{L}\) 的变量名是 ndL
。
h = 1e-7
ndL = (np.linalg.cholesky(S + h * dS) - L) / h
5.3. Cholesky 分解矩阵 \(\mathbf{L}\) 的解析导数#
通过链式法则,可以知道
对等式两边同时左乘 \(\mathbf{L}^{-1}\) 并右乘 \(\mathbf{L}^{-\dagger}\),得到
我们能知道 \(\mathbf{L}^{-1}\) 与 \(\partial_x \mathbf{L}\) 都是下三角矩阵,因此它们的乘积也是下三角矩阵。同理,\(\partial_x \mathbf{L}^\dagger \mathbf{L}^{-\dagger}\) 是上三角矩阵。这两个矩阵相互呈转置关系,因此对角线上的值是相等的。
因此,我们构造下述作用关系 (或者等价地,矩阵)
F = np.zeros((n, n))
for i in range(n):
F[i, :i] = 1
F[i, i] = 1/2
那么利用上下三角的对称性,下式的下三角部分成立:
对上式左乘 \(\mathbf{L}\),立即得到
为了程序书写方便,额外定义 Linv
\(\mathbf{L}^{-1}\)。注意到点乘 \(\odot\) 的运算优先级比矩阵乘法高,但在 numpy 中点乘与矩阵乘法的运算优先级相同,因此要多加一层括号。
Linv = np.linalg.inv(L)
dL = L @ (F * (Linv @ dS @ Linv.T))
在适当的阈值下,数值与解析导数的误差相近。
np.allclose(dL, ndL, rtol=1e-5, atol=1e-6)
True
5.4. \(\mathbf{L}\) 的解析导数的快速实现#
由于矩阵求逆是 \(O(n^3)\) 运算量,计算耗时相当大;因此较为廉价的方法是利用求解线性问题,避免直接求逆。
from scipy.linalg import solve_triangular
from functools import partial
st = partial(solve_triangular, lower=True)
np.allclose(L @ (F * st(L, st(L, dS.T).T)), dL, rtol=1e-5, atol=1e-6)
True
我们现在考虑较大的矩阵 (1000 维度):
n = 1000
S = np.cov(np.random.randn(n, 2 * n))
L = np.linalg.cholesky(S)
dS = np.cov(np.random.randn(n, n))
F = np.zeros((n, n))
for i in range(n):
F[i, :i] = 1
F[i, i] = 1/2
其计算耗时可以估计如下:
%%timeit -n 10
# with inverse
Linv = np.linalg.inv(L)
dL = L @ (F * (Linv @ dS @ Linv.T))
67.6 ms ± 456 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit -n 10
# without inverse
dL = L @ (F * st(L, st(L, dS.T).T))
36.5 ms ± 1.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)