Matrix Whitening

A white dataset is one whose random variables are independent, identically distributed, and of unit variance. The Central Limit Theorem tells us that the standardized sum of such variables converges to a normal distribution, so matrix whitening is a useful conditioning step for models that expect normally distributed input.

The eigen-decomposition of the data's covariance matrix (R) provides the eigenvectors (V) and eigenvalues (D) such that RV=VD. V and D are used to build the whitening matrix (W), which transforms the zero-mean data (X) into a dataset (Y) whose variables are decorrelated and have unit variance.

\[\begin{split}Y &= WX \\ COV(Y) &= E[YY^T] \approx \frac{YY^T}{N} = \frac{WXX^TW^T}{N} \approx W COV(X) W^T\end{split}\]
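As a quick numerical sanity check of the covariance estimate above (the data here are synthetic, generated only for illustration), the outer-product estimate YY^T/N agrees with the sample covariance when the data are zero-mean:

```python
import numpy as np

rng = np.random.default_rng(0)
N = 100_000
Y = rng.standard_normal((3, N))  # zero-mean data: 3 variables, N samples

est = Y @ Y.T / N                # E[YY^T] estimated by YY^T / N
cov = np.cov(Y, bias=True)       # sample covariance (divide by N)

# The two estimates differ only by the (small) sample mean term
print(np.max(np.abs(est - cov)))
```

For non-zero-mean data the mean must be subtracted first, which is why the derivation assumes zero-mean X.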

Eigen-decomposition of the covariance of X:

\[\begin{split}COV(X) V = VD \\ D = V^T COV(X) V\end{split}\]
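Both identities are easy to verify numerically. A minimal sketch, reusing the covariance matrix R from the full listing at the end of this page and `scipy.linalg.eigh`:

```python
import numpy as np
from scipy.linalg import eigh

# Symmetric, positive-definite covariance matrix (same values as R below)
R = np.array([[ 4.1, -2.9, -2.1],
              [-2.9,  2.7,  1.2],
              [-2.1,  1.2,  1.8]])

D, V = eigh(R)  # eigenvalues (ascending) and orthonormal eigenvectors

print(np.allclose(R @ V, V @ np.diag(D)))    # COV(X) V = V D  -> True
print(np.allclose(V.T @ R @ V, np.diag(D)))  # D = V^T COV(X) V -> True
```

The second identity follows from the first because V is orthonormal, so V^{-1} = V^T.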

Dividing the covariance of Y by the eigenvalues of the covariance of X:

\[\frac{COV(Y)}{D} = \frac{W COV(X) W^T}{V^T COV(X) V}\]

Solving for W, recognizing that COV(Y) = I:

\[\begin{split}\sqrt{\frac{I}{D}} &= \sqrt{\frac{W COV(X) W^T}{V^T COV(X) V}} \\ D^{-1/2} &= \frac{W}{V^T} \\ W &= D^{-1/2} V^T\end{split}\]
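A short sketch (again reusing the covariance matrix R from the listing below) confirms that W = D^{-1/2} V^T satisfies W COV(X) W^T = I:

```python
import numpy as np
from scipy.linalg import eigh

R = np.array([[ 4.1, -2.9, -2.1],   # covariance matrix from the full example
              [-2.9,  2.7,  1.2],
              [-2.1,  1.2,  1.8]])

D, V = eigh(R)
W = np.diag(D**-0.5) @ V.T          # W = D^(-1/2) V^T

print(np.allclose(W @ R @ W.T, np.eye(3)))  # -> True
```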

Multiplying by V^T decorrelates the data, and the inverse square root of D (a diagonal matrix) normalizes the variance. Figure 1 shows the correlation between three variables, before and after whitening.

[Figure 1: pairwise scatter plots of the three variables, before (X) and after (Y) whitening]
# Python 3.x
import numpy as np
from scipy.linalg import cholesky, eigh
from scipy.stats import norm
import matplotlib.pyplot as plt

def plotCorrelation(dat, i1, i2, name, c, ax):
    ax.scatter(dat[i1], dat[i2], c=c)
    ax.set_xlabel('%s[%d]' % (name, i1))
    ax.set_ylabel('%s[%d]' % (name, i2))
    ax.set_aspect('equal')

# Create data in need of whitening (X)
Ns = 512  # number of samples per variable
R = np.array([
        [ 4.1, -2.9, -2.1],
        [-2.9,  2.7,  1.2],
        [-2.1,  1.2,  1.8]])  # covariance matrix for 3 variables
C = cholesky(R, lower=True)  # CC^T = R
X = C @ norm.rvs(size=(R.shape[0], Ns))  # correlated, zero-mean samples

# Y = WX, where W is the whitening matrix
D, V = eigh(np.cov(X))       # eigenvalues and eigenvectors of COV(X)
W = np.diag(D**-0.5) @ V.T   # W = D^(-1/2) V^T
Y = W @ X

fig = plt.figure(1); plt.clf()
plotCorrelation(X, 0, 1, 'X', 'teal',   fig.add_subplot(321))
plotCorrelation(Y, 0, 1, 'Y', 'orange', fig.add_subplot(322))
plotCorrelation(X, 0, 2, 'X', 'teal',   fig.add_subplot(323))
plotCorrelation(Y, 0, 2, 'Y', 'orange', fig.add_subplot(324))
plotCorrelation(X, 1, 2, 'X', 'teal',   fig.add_subplot(325))
plotCorrelation(Y, 1, 2, 'Y', 'orange', fig.add_subplot(326))
plt.tight_layout()

# COV(Y) should be close to the identity matrix
print('COV(Y) =\n', np.cov(Y))