Training set $\mathcal{A} = \{(x_i, y_i)\}_{i\leq n}$, minimize the empirical risk
$$ \min_a \frac{1}{n}\sum_i (y_i - ax_i)^2$$Closed form solution: $$\hat{a} = \frac{\sum_i x_i y_i}{\sum_i x_i^2}$$
Training set $\mathcal{A} = \{(\mathbf{x}_i, y_i)\}_{i\leq n}$, minimize the empirical risk
$$ \min_{\mathbf{a}} \frac{1}{n}\sum_i (y_i - \mathbf{a}^\top\mathbf{x}_i)^2$$Closed form solution
Pseudo-inverse: with $\mathbf{X} = [\mathbf{x}_1, \dots, \mathbf{x}_n]$ (one sample per column),
$$ \hat{\mathbf{a}} = (\mathbf{X}^\top)^{+}\mathbf{y} = (\mathbf{XX}^\top)^{-1}\mathbf{X}\mathbf{y}$$The solution is easy to compute by projecting into the eigenspace of $\mathbf{X}$: with the SVD $\mathbf{X} = \mathbf{USV}^\top$, $\hat{\mathbf{a}} = \mathbf{US}^{-1}\mathbf{V}^\top\mathbf{y}$, i.e. $\hat{\mathbf{a}}$ lies in the eigenspace (the span of the left singular vectors) of $\mathbf{X}$.
Let $\mathbf{x}$ be a stochastic process with covariance matrix $\Sigma_\mathbf{x}$, then
$$ \mathbf{x}_i = \sum_{k=1}^p z_{k,i}\mathbf{e}_k $$with $\mathbf{e}_k$ the eigenvectors of $\Sigma_\mathbf{x}$.
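As a quick illustration of this expansion, here is a minimal NumPy sketch (not from the original notebook; the synthetic data and names are illustrative, with samples stored as rows):

import numpy as np

rng = np.random.default_rng(0)
xs = rng.normal(size=(500, 2)) @ np.array([[2.0, 0.5], [0.5, 1.0]])   # correlated 2-D samples (rows)
cov = np.cov(xs, rowvar=False)                 # empirical covariance Sigma_x
eigvals, E = np.linalg.eigh(cov)               # columns of E are the eigenvectors e_k
z = xs @ E                                     # coordinates z_{k,i} in the eigenbasis
print(np.allclose(xs, z @ E.T))                # x_i = sum_k z_{k,i} e_k  -> True
print(np.round(np.cov(z, rowvar=False), 3))    # ~ diag(eigvals): decorrelated coordinates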
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

key = jax.random.PRNGKey(0)
key, skey = jax.random.split(key)
x = jax.random.uniform(skey, (50, 1))
X = jnp.concatenate((x, -5*x), axis=1)       # two perfectly collinear features -> X is rank 1
a = jnp.array([2, 0])
y = jnp.matmul(X, a)
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)   # note: V is returned already transposed
print('eigenvalues: {}'.format(S))
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)    # dividing by the near-zero singular value is unstable
a_hat = jnp.matmul(U, SinvVty)
print('a_hat: {} a: {}'.format(a_hat, a))
eigenvalues: [2.0413006e+01 3.0459864e-07]
a_hat: [-2.8013153 -0.96026266] a: [2 0]
Adding a constant (bias) term $b$ to the model is handled like the vector case:
$$ \min_{\mathbf{a}, b} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{a} - b\mathbf{1} \|^2$$
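In practice the bias is usually absorbed by appending a coordinate fixed to $1$ to every $\mathbf{x}_i$; a minimal sketch (synthetic data, the names are illustrative and not from the notebook):

import numpy as np

rng = np.random.default_rng(0)
xb = rng.uniform(size=50)
yb = 3 * xb + 2                                       # ground truth: slope 3, intercept 2
Xb = np.stack([xb, np.ones_like(xb)], axis=0)         # shape (d+1, n), last row fixed to 1
a_b = np.linalg.pinv(Xb.T) @ yb                       # solves min ||y - Xb^T [a, b]||^2
print(a_b)                                            # ~ [3. 2.]: slope and intercept recovered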
Training set $\mathcal{A} = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i\leq n}$, minimize the empirical risk
$$ \min_{\mathbf{A}} \frac{1}{n}\sum_i \|\mathbf{y}_i - \mathbf{A}^\top\mathbf{x}_i\|^2$$Equivalent to $p$ scalar output cases stacked together (one column of $\mathbf{A}$ per output dimension), as sketched below.
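A minimal sketch of the stacked closed-form solution (synthetic, noise-free data; the names are illustrative and not from the notebook):

import numpy as np

rng = np.random.default_rng(0)
A_true = rng.normal(size=(3, 2))              # d = 3 inputs, p = 2 outputs
Xm = rng.normal(size=(3, 100))                # columns are the samples x_i
Ym = A_true.T @ Xm                            # stacked targets, one column per sample
A_hat = np.linalg.pinv(Xm.T) @ Ym.T           # each output dimension is a scalar regression
print(np.allclose(A_hat, A_true))             # -> True on noise-free, full-rank data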
Regress 0 and 1
data = np.load('mnist.npz')
X = data['X_train_bin']
y = data['y_train_bin']
plt.imshow(X[0,:].reshape(28,28))
print(y[0])
0
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [36.45615 17.209862 12.030047 11.947479 9.36883 7.8752723 6.7997932 6.3316774 5.729243 5.4004116 5.1866603 4.988159 4.8420706 4.1217637 3.9173453 3.5896347 3.2801106 3.1127822 3.0160408 2.7964358 2.6677098 2.543037 2.2888196 2.1948047 2.14488 1.8337901 1.0252038]
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
plt.subplot(1,2,1)
plt.imshow(jnp.maximum(a_hat, 0).reshape((28, 28)))
plt.subplot(1,2,2)
plt.imshow(jnp.maximum(-a_hat, 0).reshape((28, 28)))
X_val = data['X_val_bin']
y_val = data['y_val_bin']
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
X = data['X_train']
y = data['y_train']
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S[0:10]))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [61.84758 22.601597 19.816174 18.98699 17.21556 15.531815 14.318835 13.584824 12.348789 11.739741]
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
X_val = data['X_val'][0:30,...]
y_val = data['y_val'][0:30]
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
The output space is not adapted: regressing the digit label directly imposes an artificial ordering (topology) on the classes.
y = jax.nn.one_hot(y, 10)
a = []
for k in range(10):
    # one scalar regression per class, reusing U, S, V from the SVD of X.T above
    Vty_k = jnp.matmul(V, y[:, k])
    a_k = jnp.matmul(U, jnp.matmul(jnp.diag(1./S), Vty_k))
    a.append(a_k)
a = jnp.array(a)
y_hat = jnp.argmax(jnp.matmul(X_val, a.T), axis=1)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
a = [-0.2, 0.7, 0.83, -1.5, 5.23]       # polynomial coefficients, highest degree first (np.poly1d convention)
p = np.poly1d(a)
x = np.random.rand(24)*4-2
y = p(x) + 0.2*np.random.randn(24)
X = jnp.stack([jnp.ones((len(x))), x, x**2, x**3, x**4], axis=1)    # polynomial feature map, lowest degree first
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
print(a_hat)
pp = np.poly1d(a_hat[::-1])
t = np.linspace(-2, 2, 50)
plt.plot(x, y, 'x')
plt.plot(t, pp(t))
plt.plot(t, p(t))
[ 5.2697163 -1.5135039 0.7879974 0.70156187 -0.19815296]
Map $\phi: x \mapsto [\sin(f_0 x), \sin(2f_0x), \dots, \sin(pf_0x)]$
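The cells below fit against a new ground truth $\mathbf{a} = [0.7, 0.83, -1.5]$ in this sine basis (see the printed outputs), but the cell that regenerates $x$ and $y$ does not appear above. A plausible sketch of that missing step, where the sampling range and the noise level are assumptions:

# assumed data-generation cell (not shown above); the range of x and the noise level are guesses
a = jnp.array([0.7, 0.83, -1.5])
x = np.random.rand(24)*16-8                   # assumed range, matching the later plots over [-8, 8]
y = jnp.matmul(a, jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])) + 0.3*np.random.randn(24)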
X = jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])
ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(X, X.transpose())), X), y)
Yp = jnp.matmul(ap, X)
t = np.linspace(-8, 8, 200)
T = np.array([np.sin(t), np.sin(2*t), np.sin(3*t)])
plt.plot(x, y, 'x')
plt.plot(t, np.matmul(ap, T))
plt.plot(t, np.matmul(a, T))
What if $p < \hat{p}$, i.e. the data was generated with fewer sine components than the model fits (the model has greater capacity than the data requires)?
def sin_approx(x, y, p):
    # feature matrix with rows sin(kx), k = 1..p
    Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
    # least-squares fit via the normal equations
    ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(Xp, Xp.transpose())), Xp), y)
    Yp = jnp.matmul(ap, Xp)
    return ap, Xp, Yp
p = 3
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.59392786 0.83312976 -1.5430541 ]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.09781882911920547
p = 5
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.59637094 0.8385289 -1.5672264 0.06469631 -0.01800793]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.0961289182305336
p = 10
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.6073679 0.81426585 -1.5253923 0.03802376 -0.00418478 -0.01558349 0.06264809 0.05295399 -0.09149553 0.11939779]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.08592929691076279
p = 15
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.35529602 1.1142317 -1.7437214 0.15871115 0.00858292 -0.10777945 0.15600105 -0.08507273 0.08278456 -0.04437378 0.08144233 -0.01490425 -0.21221659 0.28036714 -0.02712816]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.06189465522766113
p = 20
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [-0.11323678 1.7993791 -2.3644838 0.49984902 -0.01194978 -0.32451046 0.5129152 -0.50142884 0.45345783 -0.24447411 0.10667838 0.20195109 -0.6100314 0.6498178 -0.25845143 -0.03708184 0.16033 -0.29417005 0.03206168 -0.0902726 ]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.04660849645733833
p = 25
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 2.483035 -1.9207842 0.57731056 -0.44159245 -0.9491972 1.4420291 -1.1824136 0.8650079 -0.61695516 0.45492303 0.19773307 -1.0832504 1.5003977 -1.1892037 0.3630464 0.81794167 -1.6101568 1.2754178 -0.51291 -0.5356204 0.87538815 -0.7491144 0.10084951 0.21558303 -0.11464696]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.0363682359457016
p = 35
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [-34.420845 53.54271 -47.404167 24.04839 2.5520477 -22.580738 33.575417 -30.581444 22.487469 -8.539528 -5.8411484 17.944695 -26.375664 26.720001 -15.869003 -2.5119934 19.008682 -22.40055 14.070158 -1.2852402 -7.173237 8.027071 -5.567108 0.89237213 2.7684917 -2.0045357 -2.4827309 5.055525 -2.6962426 -2.1028605 4.513747 -2.6814532 -0.19490051 0.98603344 -0.61041975]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 2.7143747806549072
x_val = np.random.rand(48)*16-8
X_val = jnp.array([ jnp.sin(x_val), jnp.sin(2*x_val), jnp.sin(3*x_val)])
y_val = jnp.matmul(a, X_val) + 0.3*np.random.randn(48)
mse = []
mse_val = []
for p in range(25):
    ap, Xp, Yp = sin_approx(x, y, p)
    _, Xp_val, _ = sin_approx(x_val, y_val, p)     # only used to build the validation features
    Yp_val = jnp.matmul(ap, Xp_val)
    mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Noisy observation: $$ y = \mathbf{a}^\top\mathbf{x} + \varepsilon, \varepsilon \sim \mathcal{N}(0, \sigma)$$
Assume $\|\mathbf{a}\|_0 < d$ (not all input dimensions are used); can we force $\hat{\mathbf{a}}$ to be sparse as well? Penalize the empirical risk,
$$ \min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda\,\Omega(\mathbf{a})$$with $\Omega(\mathbf{a})$ a regularizer that increases the cost of more complex $\mathbf{a}$.
Least Absolute Shrinkage and Selection Operator
$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda\|\mathbf{a}\|_1 $$Optimize using gradient descent
$$ \mathbf{a} \leftarrow \mathbf{a} - \eta \left[ \frac{-2}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})\,\mathbf{x}_i + \lambda\,\text{sign}(\mathbf{a}) \right] $$def sin_pred(a, X):
    # linear prediction in the sine feature space
    return jnp.matmul(a, X)
def sin_lasso(a, X, y, lam):
    # squared error plus l1 penalty on the coefficients
    yp = sin_pred(a, X)
    return ((y - yp)**2).mean() + lam*jnp.abs(a).sum()

@jax.jit
def update(a, X, y, lam):
    # one gradient-descent step on the lasso objective
    da = jax.grad(sin_lasso, argnums=0)(a, X, y, lam)
    return a - 0.05*da
p=25
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
ap = jnp.zeros(p)
for i in range(100):
    ap = update(ap, Xp, y, 0.1)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 6.5387887e-01 4.9198857e-01 -1.2269275e+00 1.6643824e-03 -5.4124887e-03 2.7261022e-03 -4.6199327e-03 1.8823075e-03 -1.3611598e-03 1.0498903e-01 6.7312829e-04 -6.6337660e-02 1.5109220e-03 6.7850342e-03 6.5390445e-02 -2.7749939e-03 -8.5329272e-02 3.3036065e-03 -6.3027009e-02 -5.3070008e-04 -4.4770658e-04 -5.5702450e-03 -4.1388847e-02 4.2398345e-02 5.0350791e-03]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
Yp = jnp.matmul(ap, Xp)    # prediction from the lasso coefficients
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.03701911121606827
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(a, T[:len(a), :]), label='gt')
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    ap = jnp.zeros(p)
    for i in range(50):
        ap = update(ap, Xp, y, lam)
    plt.plot(t, jnp.matmul(ap, T[:len(ap), :]), label=r'$\lambda={}$'.format(lam))
plt.legend()
def lasso_approx(Xp, y, lam):
    # gradient-descent lasso fit for a given feature matrix and penalty lam
    ap = jnp.zeros(len(Xp[:,0]))
    for i in range(100):
        ap = update(ap, Xp, y, lam)
    Yp = jnp.matmul(ap, Xp)
    return ap, Xp, Yp
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
Xp_val = jnp.sin(jnp.matmul(x_val.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
mse = []
mse_val = []
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    ap, Xp, Yp = lasso_approx(Xp, y, lam)
    Yp_val = jnp.matmul(ap, Xp_val)
    mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Project $\mathbf{X}$ into its eigenspace:
$$ \min_\mathbf{a} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{Ua}\|^2 + \lambda\|\mathbf{a}\|_1 $$$$ = \frac{1}{n} \|\mathbf{y} - \mathbf{VSa}\|^2 + \lambda\|\mathbf{a}\|_1$$Stationary condition:
$$\frac{\partial}{\partial \mathbf{a}} = 0 = -2\mathbf{SV}^\top\mathbf{y} + 2\mathbf{S}^2\mathbf{a} + \lambda\,\text{sign}(\mathbf{a})$$(absorbing the $\frac{1}{n}$ factor into $\lambda$)$$\mathbf{a} = \mathbf{S}^{-1}\mathbf{V}^\top\mathbf{y} - \mathbf{S}^{-2}\frac{\lambda\,\text{sign}(\mathbf{a})}{2}$$Let $\tilde{\mathbf{a}} = \mathbf{S}^{-1}\mathbf{V}^\top\mathbf{y}$
Note that $\text{sign}(\mathbf{a}) = \text{sign}(\tilde{\mathbf{a}}) = \frac{\tilde{\mathbf{a}}}{\vert\tilde{\mathbf{a}}\vert}$ $$\mathbf{a} = \tilde{\mathbf{a}}\left(1 - \frac{\lambda\mathbf{S}^{-2}}{2\vert\tilde{\mathbf{a}}\vert}\right)$$
Case $\tilde{\mathbf{a}}_i>0$: assuming $\text{sign}(\mathbf{a}_i) = \text{sign}(\tilde{\mathbf{a}}_i) = 1$ requires
$$ \mathbf{a}_i = \color{blue}{\underbrace{\tilde{\mathbf{a}}_i}_{>0}} \left(1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) >0$$If the factor in parentheses is negative, the sign assumption is violated and the only consistent solution is $\mathbf{a}_i = 0$, hence
$$ \mathbf{a}_i = \tilde{\mathbf{a}}_i \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) $$Case $\tilde{\mathbf{a}}_i<0$: assuming $\text{sign}(\mathbf{a}_i) = \text{sign}(\tilde{\mathbf{a}}_i) = -1$ requires
$$ \mathbf{a}_i = \color{blue}{\underbrace{\tilde{\mathbf{a}}_i}_{<0}} \left(1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) <0$$and the same argument gives
$$ \mathbf{a}_i = \tilde{\mathbf{a}}_i \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) $$Soft thresholding:
$$ \mathbf{a} = \tilde{\mathbf{a}} \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}}{2\vert\tilde{\mathbf{a}}\vert}\right) $$$\lambda$ zeroes out the components for which the penalty would flip the sign of the solution $\rightarrow$ sparse solution
Remember: this analysis is only valid in the eigenspace of $\mathbf{X}$; multiply by $\mathbf{U}$ to map the thresholded solution back to the original coordinates.
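A minimal sketch of this closed-form soft-thresholding solution, reusing the feature matrix `Xp` and targets `y` from the cells above (this is not a cell from the original notebook):

# closed-form lasso in the eigenspace of Xp: soft-threshold the pseudo-inverse solution,
# then map back to the original coordinates with U
def lasso_closed_form(Xp, y, lam):
    U, S, Vh = jnp.linalg.svd(Xp, full_matrices=False)
    a_tilde = jnp.matmul(jnp.diag(1./S), jnp.matmul(Vh, y))          # S^{-1} V^T y
    shrink = jnp.maximum(0., 1. - lam / (2. * S**2 * jnp.abs(a_tilde)))
    return jnp.matmul(U, a_tilde * shrink)

# e.g. a_cf = lasso_closed_form(Xp, y, 0.1)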
In eigenspace, pseudo-inverse solution: $$ \mathbf{a} = \mathbf{US}^{-1}\mathbf{V}^\top\mathbf{y}$$
What if $\mathbf{S}$ has very small eigenvalues ($\dim\text{span}(\mathbf{X}) < d$, or nearly so)? How do we prevent the solution from fitting the noise?
Avoid large values in the solution:
$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda\|\mathbf{a}\|^2 $$Tikhonov regularization (ridge regression)
Stationary condition:
$$\frac{\partial}{\partial \mathbf{a}} = 0 =\frac{2}{n}(\mathbf{XX}^\top\mathbf{a} - \mathbf{X}\mathbf{y}) +2\lambda\mathbf{a} \quad\Rightarrow\quad \hat{\mathbf{a}} = (\mathbf{XX}^\top + n\lambda\mathbf{I})^{-1}\mathbf{X}\mathbf{y}$$This offsets all the eigenvalues of the covariance matrix $\mathbf{XX}^\top$ by $n\lambda$, which keeps the inverse well conditioned.
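Equivalently, in the SVD basis ridge regression replaces the factor $1/s_k$ of the pseudo-inverse by the damped factor $s_k/(s_k^2 + n\lambda)$; a minimal sketch (not a cell from the original notebook; $\mathbf{X}$ is assumed to be $d\times n$):

# ridge regression via the SVD: recovers a_hat = U S^{-1} V^T y as lam -> 0
def ridge_svd(X, y, lam):
    n = X.shape[1]
    U, S, Vh = jnp.linalg.svd(X, full_matrices=False)
    damp = S / (S**2 + n * lam)                  # bounded even when a singular value is ~0
    return jnp.matmul(U, damp * jnp.matmul(Vh, y))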
Adding both regularizers (the elastic net):
$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda_1\|\mathbf{a}\|_1 +\lambda_2\|\mathbf{a}\|_2^2$$Optimize using gradient descent, as sketched below.
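A minimal gradient-descent sketch for this combined objective, in the same style as the lasso cells above (the feature matrix `Xp` and targets `y` are assumed to be in scope; the step size and iteration count are arbitrary choices):

# elastic-net style objective: squared error + l1 + squared l2 penalties
def elastic_loss(a, X, y, lam1, lam2):
    r = y - jnp.matmul(a, X)
    return (r**2).mean() + lam1*jnp.abs(a).sum() + lam2*(a**2).sum()

@jax.jit
def elastic_update(a, X, y, lam1, lam2):
    return a - 0.05*jax.grad(elastic_loss)(a, X, y, lam1, lam2)

# e.g.: a_en = jnp.zeros(Xp.shape[0])
#       for i in range(100): a_en = elastic_update(a_en, Xp, y, 0.01, 0.01)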
The $\ell_2$ (squared) loss is sensitive to outliers
x = np.random.rand(24)*16-8
y = x + 0.1*np.random.rand(24)
y[0] += 20                                   # inject a single large outlier
a = jnp.dot(x, y)/jnp.dot(x, x)              # closed-form scalar least squares
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
1.0905238
Mean absolute error (or $\ell_1$ error)
$$\min_\mathbf{a} \mathbb{E}[ \vert y - \mathbf{a}^\top\mathbf{x}\vert ]$$Vector case
$$\min_\mathbf{A} \mathbb{E}[ \| \mathbf{y} - \mathbf{A}^\top\mathbf{x}\|_1 ]$$No closed-form solution; use gradient descent, taking the subderivative of $\vert\cdot\vert$ at $0$ to be $0$.
This is known as robust regression.
def l1(a, x, y):
    # mean absolute error of the scalar model y ~ a x
    return jnp.abs(y - a*x).mean()

@jax.jit
def update(a, x, y):
    # one gradient-descent step on the MAE
    da = jax.grad(l1, argnums=0)(a, x, y)
    return a - 0.02*da
a = 0.
for i in range(100):
    a = update(a, x, y)
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
0.974505
t = jnp.linspace(-0.5, 0.5, 25)
plt.plot(t, jnp.abs(t), label='MAE')
plt.plot(t, t**2, label='MSE')
plt.legend()
Unsupervised learning: the target space is a new representation of the input space,
$$ \min_{\mathbf{D},\mathbf{A}} \|\mathbf{X} - \mathbf{DA}\|_F^2$$If $p$ (the number of atoms, i.e. columns of $\mathbf{D}$) is smaller than $d$, then $\mathbf{D}$ contains the $p$ leading left singular vectors of $\mathbf{X}$, and the factors $\mathbf{A}$ combine the corresponding singular values with the right singular vectors, as sketched below.
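A minimal sketch of that statement, using synthetic data (the names and sizes are illustrative, with samples stored as columns):

import numpy as np

rng = np.random.default_rng(0)
Xs = rng.normal(size=(20, 200))                        # hypothetical data matrix, d = 20, n = 200
p = 5
U, S, Vh = np.linalg.svd(Xs, full_matrices=False)
D = U[:, :p]                                           # atoms = p leading left singular vectors
A = S[:p, None] * Vh[:p, :]                            # factors = singular values times right singular vectors
print(np.linalg.norm(Xs - D @ A)**2, np.sum(S[p:]**2)) # residual energy = discarded singular values squared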
If $p > d$, we have an overcomplete dictionary, which means we can afford to use only a few entries of the code to reconstruct a sample
$$ \min_{\mathbf{a}} \|\mathbf{x} - \mathbf{Da}\|^2 +\lambda \|\mathbf{a}\|_0$$Sparse coding
Alternate update:
Fix $\mathbf{D}$, update $\mathbf{A}$
Fix $\mathbf{A}$, update $\mathbf{D}$
Update one atom at a time
Iterative (K-SVD style) rank-one updates: $$\min_{\mathbf{d}_k, \mathbf{a}^k} \|\mathbf{E}_k - \mathbf{d}_k\mathbf{a}^k\|_F^2$$with $\mathbf{E}_k = \mathbf{X} - \sum_{j\neq k}\mathbf{d}_j\mathbf{a}^j$ the reconstruction residual when atom $k$ is left out.
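The best rank-one approximation of $\mathbf{E}_k$ is given by its leading singular vectors, which suggests an atom update like the sketch below (a K-SVD-style step, not the simpler alternation used in the cell that follows; it also ignores K-SVD's restriction to the samples that actually use atom $k$):

# rank-one update of atom k from the leading singular pair of the residual E_k
def update_atom(X, D, A, k):
    Ek = X - jnp.matmul(D, A) + jnp.outer(D[:, k], A[k, :])   # residual with atom k removed
    U, S, Vh = jnp.linalg.svd(Ek, full_matrices=False)
    d_k = U[:, 0]                                             # new (unit-norm) atom
    a_k = S[0] * Vh[0, :]                                     # new row of codes
    return D.at[:, k].set(d_k), A.at[k, :].set(a_k)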
X = jnp.transpose(data['X_train'])
y = data['y_train']
D = np.random.rand(784, 64)                  # random initial dictionary: 64 atoms
A = np.random.rand(64, 100)                  # random initial codes: one column per sample
for e in range(50):
    # alternating least-squares updates of D and A ...
    D = jnp.matmul(jnp.matmul(X, A.T), jnp.linalg.inv(jnp.matmul(A, A.T)))
    A = jnp.matmul(jnp.linalg.inv(jnp.matmul(D.T, D)), jnp.matmul(D.T, X))
    # ... followed by soft-thresholding each column of A at its 33rd largest magnitude,
    # keeping (at most) 32 nonzero coefficients per sample
    S = jnp.sign(A)
    I = jnp.argsort(jnp.abs(A), axis=0)[-33, :]
    A = S * jnp.clip(jnp.abs(A) - jnp.abs(A[I, jnp.arange(100)]), a_min=0)
plt.subplot(1,5,1)
plt.imshow(D[:,0].reshape(28, 28))
plt.subplot(1,5,2)
plt.imshow(D[:,1].reshape(28, 28))
plt.subplot(1,5,3)
plt.imshow(D[:,2].reshape(28, 28))
plt.subplot(1,5,4)
plt.imshow(D[:,3].reshape(28, 28))
plt.subplot(1,5,5)
plt.imshow(D[:,4].reshape(28, 28))
plt.imshow(A)
Exercise: Try a linear regression (vector output) using $\mathbf{A}$ instead of $\mathbf{X}$
Relation to k-means
$$ \min_{\mathbf{D},\mathbf{A}} \|\mathbf{X} - \mathbf{DA}\|_F^2\quad \text{ s.t. } \forall i, \|\mathbf{a}_i\|_0 = 1$$Dictionary learning with 1-sparse codes: each sample is reconstructed from a single atom, which then plays the role of a cluster centroid.
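If, in addition, the single nonzero coefficient is fixed to $1$, alternating between the coding step and the dictionary update is exactly Lloyd's k-means; a minimal sketch of one such alternation (illustrative, assuming `X` is $d\times n$ as in the cell above and that no cluster ends up empty):

# one alternation under the 1-sparse (coefficient = 1) constraint, i.e. one k-means iteration
def kmeans_step(X, D):
    d2 = ((X[:, None, :] - D[:, :, None])**2).sum(axis=0)   # (p, n) squared distances to the atoms
    assign = jnp.argmin(d2, axis=0)                          # coding step: one atom per sample
    D_new = jnp.stack([X[:, assign == k].mean(axis=1)        # dictionary step: average assigned samples
                       for k in range(D.shape[1])], axis=1)
    return assign, D_new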