Closed form solution:
- Vectorize: $x = [x_i]$, $y = [y_i]$
$$\min_a \frac{1}{n}\|y - ax\|^2$$
- Stationary condition:
$$\frac{\partial}{\partial a} \frac{1}{n}\|y - ax\|^2 = 0 = 2a\|x\|^2 - 2\langle y, x\rangle \quad\Rightarrow\quad a = \frac{y^\top x}{\|x\|^2}$$
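A minimal numeric check of this closed form (a sketch on synthetic data; assumes jax and jax.numpy are available, as in the cells below):
import jax
import jax.numpy as jnp

# Synthetic 1D data: y is roughly 3*x plus a little noise
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, (100,))
y = 3.0 * x + 0.05 * jax.random.normal(k2, (100,))
# Closed form a = y^T x / ||x||^2
a_hat = jnp.dot(y, x) / jnp.dot(x, x)
print(a_hat)   # expected to be close to 3.0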
Linear regression - Vector input, scalar output¶
- Input space: $x \in \mathbb{R}^d$
- Output space: $y \in \mathbb{R}$
- Linear model: $f(x) = a^\top x$
$$\min_a \mathbb{E}_x\left[(y - a^\top x)^2\right]$$
Training set $\mathcal{A} = \{(x_i, y_i)\}_{i \leq n}$, minimize the empirical risk
$$\min_a \frac{1}{n}\sum_i (y_i - a^\top x_i)^2$$
Closed form solution
- Matrix form: $X = [x_i]$, $y = [y_i]$
$$\min_a \frac{1}{n}\|y - X^\top a\|^2$$
- Stationary condition
$$\frac{\partial}{\partial a} \frac{1}{n}\|y - X^\top a\|^2 = 0 = -2Xy + 2XX^\top a \quad\Rightarrow\quad a = (XX^\top)^{-1}Xy$$
$(XX^\top)^{-1}X$ is the pseudo-inverse of $X^\top$.
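A quick numeric check of this closed form (a sketch on synthetic data; shapes follow the slides, with $X$ of size $d \times n$):
# d = 3 features, n = 200 samples; the columns of X are the x_i
key = jax.random.PRNGKey(1)
k1, k2 = jax.random.split(key)
X = jax.random.normal(k1, (3, 200))
a_true = jnp.array([1.0, -2.0, 0.5])
y = jnp.matmul(X.T, a_true) + 0.01 * jax.random.normal(k2, (200,))
# a = (X X^T)^{-1} X y, solved without forming the inverse explicitly
a_hat = jnp.linalg.solve(jnp.matmul(X, X.T), jnp.matmul(X, y))
print(a_hat)   # expected to be close to a_true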
Vector input, bis¶
- SVD: $X = USV^\top$
$$y = X^\top a = VSU^\top a \quad\Rightarrow\quad a = US^{-1}V^\top y$$
Easy solution obtained by projecting into the eigenspace of $X$: $a$ lies in the span of the left singular vectors of $X$.
Karhunen-Loève theorem¶
Let $x$ be a stochastic process with covariance matrix $\Sigma_x$, then
$$x_i = \sum_k^p z_{k,i}\, e_k$$
with $e_k$ the eigenvectors of $\Sigma_x$.
- Samples $x_i$ live in the space spanned by the eigenvectors of the covariance matrix (hence PCA)
- If $\mathrm{span}(x) < d$, some dimensions are useless (noise only)
- Strong influence on the pseudo-inverse solution ($S^{-1}$) ⇒ remove directions with small eigenvalues
key = jax.random.PRNGKey(0)
key, skey = jax.random.split(key)
x = jax.random.uniform(skey, (50, 1))
X = jnp.concatenate((x, -5*x), axis=1)
a = jnp.array([2, 0])
y = jnp.matmul(X, a)
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S))
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
print('a_hat: {} a: {}'.format(a_hat, a))
eigenvalues: [2.0413006e+01 3.0459864e-07] a_hat: [-2.8013153 -0.96026266] a: [2 0]
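The tiny second singular value blows up in $S^{-1}$ and the estimate is far from stable. One common remedy, sketched here with the U, S, V and y from the cell above: drop the directions whose singular value falls below a tolerance, which gives the minimum-norm least-squares solution on the remaining directions (tol is an arbitrary choice).
# Keep only well-conditioned directions of X
tol = 1e-3
keep = S > tol * S[0]
S_inv = jnp.where(keep, 1.0 / S, 0.0)        # zero out the inverse of tiny singular values
a_hat_trunc = jnp.matmul(U, S_inv * jnp.matmul(V, y))
print('truncated a_hat: {}'.format(a_hat_trunc))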
Linear regression, bias case¶
Adding a constant (bias) to the model reduces to the vector case:
$$\min_{a,b} \frac{1}{n}\left\|y - X^\top a - b\mathbf{1}\right\|^2$$
- Concatenate $b$ to $a$ and $\mathbf{1}$ to $X$:
$$\min_{a,b} \frac{1}{n}\left\|y - [X; \mathbf{1}]^\top [a; b]\right\|^2$$
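A minimal sketch of the augmentation trick on synthetic data (X is $d \times n$ as in the formulas; names and values here are illustrative):
key = jax.random.PRNGKey(2)
k1, k2 = jax.random.split(key)
X = jax.random.normal(k1, (2, 100))
y = jnp.matmul(X.T, jnp.array([1.0, -1.0])) + 3.0 + 0.01 * jax.random.normal(k2, (100,))
Xb = jnp.concatenate([X, jnp.ones((1, X.shape[1]))], axis=0)   # append a row of ones
ab = jnp.linalg.solve(jnp.matmul(Xb, Xb.T), jnp.matmul(Xb, y))
print(ab)   # last entry is the estimated bias, expected to be close to 3.0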
Linear Regression, Vector input, vector output¶
- Input space: $x \in \mathbb{R}^d$
- Output space: $y \in \mathbb{R}^p$
- Linear model: $f(x) = A^\top x$
$$\min_A \mathbb{E}_x\left[\|y - A^\top x\|^2\right]$$
Training set $\mathcal{A} = \{(x_i, y_i)\}_{i \leq n}$, minimize the empirical risk
$$\min_A \frac{1}{n}\sum_i \|y_i - A^\top x_i\|^2$$
Equivalent to $p$ scalar-output regressions stacked together
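As a sketch, the stacked closed form (assuming X is $d \times n$ and Y is $p \times n$, one row per output dimension; each column of the result solves one scalar-output regression):
# A = (X X^T)^{-1} X Y^T
def multi_output_fit(X, Y):
    return jnp.linalg.solve(jnp.matmul(X, X.T), jnp.matmul(X, Y.T))   # shape (d, p)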
Let's try with MNIST¶
Regress 0 and 1
data = np.load('mnist.npz')
X = data['X_train_bin']
y = data['y_train_bin']
plt.imshow(X[0,:].reshape(28,28))
print(y[0])
0
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [36.45615 17.209862 12.030047 11.947479 9.36883 7.8752723 6.7997932 6.3316774 5.729243 5.4004116 5.1866603 4.988159 4.8420706 4.1217637 3.9173453 3.5896347 3.2801106 3.1127822 3.0160408 2.7964358 2.6677098 2.543037 2.2888196 2.1948047 2.14488 1.8337901 1.0252038]
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
plt.subplot(1,2,1)
plt.imshow(jnp.maximum(a_hat, 0).reshape((28, 28)))
plt.subplot(1,2,2)
plt.imshow(jnp.maximum(-a_hat, 0).reshape((28, 28)))
X_val = data['X_val_bin']
y_val = data['y_val_bin']
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
MNIST, regress 0-9¶
X = data['X_train']
y = data['y_train']
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S[0:10]))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [61.84758 22.601597 19.816174 18.98699 17.21556 15.531815 14.318835 13.584824 12.348789 11.739741]
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
X_val = data['X_val'][0:30,...]
y_val = data['y_val'][0:30]
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
The output space is not adapted: regressing the labels 0-9 as a single real number imposes an artificial topology (ordering) on the classes.
MNIST regress 0-9 as one-hot¶
- Exercise: train a linear regression for each class (one versus all); one possible solution is sketched after the plot below
y = jax.nn.one_hot(y, 10)
a = []
for k in range(10):
#
#
a_k = jnp.zeros(784)
a.append(a_k)
a = jnp.array(a)
y_hat = jnp.argmax(jnp.matmul(X_val, a.T), axis=1)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
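For reference, one possible way to complete the exercise (a sketch reusing the U, S, V computed from X above; many variants work):
# One pseudo-inverse regression per one-hot column
# (y is already one-hot from the cell above; U, S, V come from the SVD of X.T)
a = []
for k in range(10):
    a_k = jnp.matmul(U, jnp.matmul(jnp.diag(1./S), jnp.matmul(V, y[:, k])))
    a.append(a_k)
a = jnp.array(a)
y_hat = jnp.argmax(jnp.matmul(X_val, a.T), axis=1)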
Polynomial features¶
Explicit non-linear mapping: fit a degree-4 polynomial by linear regression on the features $[1, x, x^2, x^3, x^4]$
a = [-0.2, 0.7, 0.83, -1.5, 5.23]
p = np.poly1d(a)
x = np.random.rand(24)*4-2
y = p(x) + 0.2*np.random.randn(24)
X = jnp.stack([jnp.ones((len(x))), x, x**2, x**3, x**4], axis=1)
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
print(a_hat)
pp = np.poly1d(a_hat[::-1])
t = np.linspace(-2, 2, 50)
plt.plot(x, y, 'x')
plt.plot(t, pp(t))
plt.plot(t, p(t))
[ 5.266477 -1.538329 0.6909166 0.7270644 -0.15955621]
Periodic signals¶
Map $\phi: x \mapsto [\sin(f_0 x), \sin(2 f_0 x), \ldots, \sin(p f_0 x)]$
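This example needs a periodic training set; a minimal data-generation sketch consistent with the ground truth printed below (the sample count, range and noise level are assumptions):
# Ground-truth periodic signal with three harmonics (f0 = 1)
a = jnp.array([0.7, 0.83, -1.5])
x = np.random.rand(24)*16 - 8        # 24 samples over [-8, 8] (assumed)
y = jnp.matmul(a, jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])) + 0.3*np.random.randn(24)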
X = jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])
ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(X, X.transpose())), X), y)
Yp = jnp.matmul(ap, X)
t = np.linspace(-8, 8, 200)
T = np.array([np.sin(t), np.sin(2*t), np.sin(3*t)])
plt.plot(x, y, 'x')
plt.plot(t, np.matmul(ap, T))
plt.plot(t, np.matmul(a, T))
Overcomplete models¶
What if $p < \hat{p}$ (the model has greater capacity than the data)?
def sin_approx(x, y, p):
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(Xp, Xp.transpose())), Xp), y)
Yp = jnp.matmul(ap, Xp)
return ap, Xp, Yp
p = 3
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.66503567 0.8555399 -1.5098379 ]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.13037486374378204
p = 5
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.66392684 0.8552013 -1.4939932 -0.07257754 0.05675733]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.1274968534708023
p = 10
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.64930737 0.86065483 -1.5159405 -0.0419542 0.10985056 -0.04247902 -0.1989938 0.04065548 0.08720873 -0.07313146]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.1047743558883667
p = 15
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.64816654 0.85237145 -1.4848707 -0.06554112 0.17367777 -0.08139829 -0.15539262 0.0064044 0.08781601 -0.0583418 0.01320645 0.03894603 -0.13762946 0.01072791 -0.09005098]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.09392114728689194
p = 20
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.63273454 0.862865 -1.5327747 -0.04082306 0.12400526 -0.07723609 -0.18338399 0.05933774 0.09869707 -0.05381919 0.01258338 0.04874885 -0.09009814 -0.00779995 0.00236045 -0.04848589 -0.05025677 -0.11085568 0.04835366 -0.12392354]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.08044012635946274
p = 25
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.5857791 0.9295168 -1.5959327 -0.06027614 0.15495446 -0.07315889 -0.18192858 -0.00906502 0.12587665 0.02122447 -0.05915882 0.0195908 0.00481635 -0.02463253 -0.0524892 -0.01131299 -0.07204011 -0.1264672 0.09676984 -0.20841293 0.02010447 0.1596179 -0.20179838 0.03811657 0.21566837]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.054160039871931076
p = 35
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 0.60171056 2.1561666 -3.7234256 1.9323449 -1.881022 1.7465707 -1.9733595 2.061804 -1.430026 0.31073868 0.800243 -1.7518697 2.4190798 -2.4179287 1.7261914 -0.81368876 0.62834394 -1.0664698 0.6718048 -0.6353989 -0.20793903 1.5484711 -1.6906972 0.984274 -0.906407 0.850715 -0.3213622 0.83473617 -0.9671983 0.1878281 0.6671101 -1.3488506 1.6334653 -0.616805 -0.2670184 ]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.02646658569574356
Train/validation¶
x_val = np.random.rand(48)*16-8
X_val = jnp.array([ jnp.sin(x_val), jnp.sin(2*x_val), jnp.sin(3*x_val)])
y_val = jnp.matmul(a, X_val) + 0.3*np.random.randn(48)
mse = []
mse_val = []
for p in range(25):
ap, Xp, Yp = sin_approx(x, y, p)
_, Xp_val, _ = sin_approx(x_val, y_val, p)
Yp_val = jnp.matmul(ap, Xp_val)
mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Regularization¶
Noisy observation: $y = a^\top x + \varepsilon$, $\varepsilon \sim \mathcal{N}(0, \sigma)$
Assume $\|a\|_0 < d$ (not all input dimensions are used): can we force $\hat{a}$ to be sparse as well?
$$\min_a \mathbb{E}_x\left[(y - x^\top a)^2\right] + \Omega(a)$$
with $\Omega(a)$ a regularizer that increases the cost of more complex $a$
LASSO¶
Least Absolute Shrinkage and Selection Operator (Tibshirani, 1996)
$$\min_a \frac{1}{n}\sum_i (y_i - x_i^\top a)^2 + \lambda\|a\|_1$$
Optimize using gradient descent:
$$a \leftarrow a - \eta\left[-\frac{2}{n}\sum_i (y_i - x_i^\top a)\,x_i + \lambda\,\mathrm{sign}(a)\right]$$
def sin_pred(a, X):
return jnp.matmul(a, X)
def sin_lasso(a, X, y, lam):
yp = sin_pred(a, X)
return ((y - yp)**2).mean() + lam*jnp.abs(a).sum()
@jax.jit
def update(a, X, y, lam):
da = jax.grad(sin_lasso, argnums=0)(a, X, y, lam)
return a - 0.05*da
p=25
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
ap = jnp.zeros(p)
for i in range(100):
ap = update(ap, Xp, y, 0.1)
print(a, ap)
[ 0.7 0.83 -1.5 ] [ 5.0244969e-01 6.7708075e-01 -1.3217005e+00 -5.5618468e-03 2.1797190e-03 -2.4922101e-03 -7.4531697e-02 -6.1854455e-03 1.1975669e-03 -3.2022649e-03 -3.3327036e-03 3.0571646e-03 -1.7762976e-02 -5.9072860e-03 -4.4773789e-03 -6.0488996e-03 -6.4997919e-02 -6.8982323e-03 1.4052609e-03 -7.1133333e-03 3.2059646e-03 -7.8694215e-03 -4.8849126e-03 1.2485289e-02 6.9161236e-02]
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.06530888378620148
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(a, T[:len(a), :]), label='gt')
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
ap = jnp.zeros(p)
for i in range(50):
ap = update(ap, Xp, y, lam)
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]), label='$\lambda={}$'.format(lam))
plt.legend()
def lasso_approx(Xp, y, lam):
ap = jnp.zeros(len(Xp[:,0]))
for i in range(100):
ap = update(ap, Xp, y, lam)
Yp = jnp.matmul(ap, Xp)
return ap, Xp, Yp
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
Xp_val = jnp.sin(jnp.matmul(x_val.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
mse = []
mse_val = []
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
ap, Xp, Yp = lasso_approx(Xp, y, lam)
Yp_val = jnp.matmul(ap, Xp_val)
mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Analysis¶
Project $X$ into its eigenspace ($a \to Ua$):
$$\min_a \frac{1}{n}\|y - X^\top U a\|^2 + \lambda\|a\|_1 = \frac{1}{n}\|y - VSa\|^2 + \lambda\|a\|_1$$
Stationary condition:
$$\frac{\partial}{\partial a} = 0 = -2SV^\top y + 2S^2 a + \lambda\,\mathrm{sign}(a) \quad\Rightarrow\quad a = S^{-1}V^\top y - \frac{\lambda S^{-2}\,\mathrm{sign}(a)}{2}$$
Let $\tilde{a} = S^{-1}V^\top y$.
Note that $\mathrm{sign}(a) = \mathrm{sign}(\tilde{a}) = \frac{\tilde{a}}{|\tilde{a}|}$, so
$$a = \tilde{a}\left(1 - \frac{\lambda S^{-2}}{2|\tilde{a}|}\right)$$
Case $> 0$: $\mathrm{sign}(a) = \mathrm{sign}(\tilde{a}) = 1$
$$a_i = \underbrace{\tilde{a}_i}_{>0}\left(1 - \frac{\lambda S_i^{-2}}{2|\tilde{a}_i|}\right) > 0 \quad\Rightarrow\quad a_i = \tilde{a}_i \max\left(0,\, 1 - \frac{\lambda S_i^{-2}}{2|\tilde{a}_i|}\right)$$
Case $< 0$: $\mathrm{sign}(a) = \mathrm{sign}(\tilde{a}) = -1$
$$a_i = \underbrace{\tilde{a}_i}_{<0}\left(1 - \frac{\lambda S_i^{-2}}{2|\tilde{a}_i|}\right) < 0 \quad\Rightarrow\quad a_i = \tilde{a}_i \max\left(0,\, 1 - \frac{\lambda S_i^{-2}}{2|\tilde{a}_i|}\right)$$
Soft thresholding:
$$a = \tilde{a}\max\left(0,\, 1 - \frac{\lambda S^{-2}}{2|\tilde{a}|}\right)$$
$\lambda$ removes components that would otherwise change the sign of the solution → sparse solution
Remember: this analysis is only valid in the eigenspace
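A minimal numeric sketch of this soft-thresholding formula on the sine features (reuses the x, y of the periodic example; p and lam are arbitrary choices, the 1/n factor is ignored as above, and the penalty here acts on the eigenspace coefficients, as in the analysis):
p, lam = 5, 0.1
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1 + jnp.arange(p).reshape(1, p))).transpose()
U, S, Vt = jnp.linalg.svd(Xp, full_matrices=False)
a_tilde = (1. / S) * jnp.matmul(Vt, y)                                 # least-squares solution in the eigenspace
shrink = jnp.maximum(0., 1. - lam / (2. * S**2 * jnp.abs(a_tilde)))    # soft thresholding
a_soft = jnp.matmul(U, a_tilde * shrink)                               # back to the original coordinates
print(a_soft)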
Conditioning¶
In the eigenspace, the pseudo-inverse solution is $a = US^{-1}V^\top y$
What if $S$ has small eigenvalues ($\mathrm{span}(X) < d$)? How do we prevent the solution from fitting the noise?
Avoid large values in the solution:
$$\min_a \frac{1}{n}\sum_i (y_i - x_i^\top a)^2 + \lambda\|a\|^2$$
Tikhonov regularization (ridge regression)
Analysis¶
$$\frac{1}{n}\|y - X^\top a\|^2 + \lambda\|a\|^2$$
Stationary condition:
$$\frac{\partial}{\partial a} = 0 = \frac{2}{n}\left(-Xy + XX^\top a\right) + 2\lambda a \quad\Rightarrow\quad a = (XX^\top + n\lambda I)^{-1}Xy$$
Offsetting all eigenvalues of the covariance matrix by $\lambda$
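A minimal sketch of the ridge solution on the sine features (reuses the Xp and y built above; lam is an arbitrary choice):
lam = 0.1
n = Xp.shape[1]
# a = (X X^T + n*lam*I)^{-1} X y, solved without forming the inverse
a_ridge = jnp.linalg.solve(jnp.matmul(Xp, Xp.T) + n * lam * jnp.eye(Xp.shape[0]),
                           jnp.matmul(Xp, y))
print(a_ridge)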
Elastic net¶
Add both regularizers:
$$\min_a \frac{1}{n}\sum_i (y_i - x_i^\top a)^2 + \lambda_1\|a\|_1 + \lambda_2\|a\|_2^2$$
- $\lambda_1$ controls sparsity
- $\lambda_2$ controls sensitivity to noisy components
Optimize using gradient descent
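A sketch of the corresponding loss and update, following the LASSO code above (lam1 and lam2 are arbitrary weights):
def sin_elastic(a, X, y, lam1, lam2):
    yp = jnp.matmul(a, X)
    return ((y - yp)**2).mean() + lam1*jnp.abs(a).sum() + lam2*(a**2).sum()

@jax.jit
def update_elastic(a, X, y, lam1, lam2):
    da = jax.grad(sin_elastic, argnums=0)(a, X, y, lam1, lam2)
    return a - 0.05*da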
Other loss functions¶
ℓ2 is sensitive to outliers
x = np.random.rand(24)*16-8
y = x + 0.1*np.random.rand(24)
y[0] += 20
a = jnp.dot(x, y)/jnp.dot(x, x)
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
0.8166666
No closed-form solution: use gradient descent, taking the subgradient of $\|a\|_1$ at $0$ to be $0$. This is known as robust regression.
def l1(a, x, y):
return jnp.abs(y - a*x).mean()
@jax.jit
def update(a, x, y):
da = jax.grad(l1, argnums=0)(a, x, y)
return a - 0.02*da
a = 0.
for i in range(100):
a = update(a, x, y)
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
1.0179636
Sensitivity to small errors¶
- ℓ2: derivative falls quickly to zero
- ℓ1: constant derivative
t = jnp.linspace(-0.5, 0.5, 25)
plt.plot(t, jnp.abs(t), label='MAE')
plt.plot(t, t**2, label='MSE')
plt.legend()
Do both?¶
- Penalize large errors: $\ell_2$
- Penalize small errors (assuming no outliers): $\ell_1$
$$\min_A \mathbb{E}\left[\|y - A^\top x\|^2 + \lambda\|y - A^\top x\|_1\right]$$
Optimize using gradient descent
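A sketch of the combined data term for the scalar model above (lam is an arbitrary weight):
def l2l1(a, x, y, lam):
    r = y - a*x
    return (r**2).mean() + lam*jnp.abs(r).mean()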
Full model¶
- Ridge regularization (noisy components)
- Sparsity regularization (overcomplete model)
- Large errors penalization
- Small errors penalization
$$\min_A \mathbb{E}\left[\|y - A^\top x\|^2 + \lambda\|y - A^\top x\|_1\right] + \lambda_2\|A\|_F^2 + \lambda_1\|A\|_1$$
- Optimize using gradient descent (surprise!)
- 3 hyper-parameters: use proper cross validation ("With four parameters I can fit an elephant, and with five I can make him wiggle his trunk", J. von Neumann)
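A sketch of the full objective for the matrix-output case (assumes X is d x n, Y is p x n, A is d x p; all hyper-parameters arbitrary):
def full_loss(A, X, Y, lam, lam1, lam2):
    R = Y - jnp.matmul(A.T, X)                   # residuals, shape p x n
    return ((R**2).sum(axis=0).mean()            # squared error term
            + lam*jnp.abs(R).sum(axis=0).mean()  # absolute error term
            + lam2*(A**2).sum()                  # ridge on the weights
            + lam1*jnp.abs(A).sum())             # sparsity on the weights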
Dictionary learning¶
Unsupervised learning: the target space is a new representation of the input space
- Input: $X \in \mathbb{R}^{d\times n}$
- Model: dictionary $D \in \mathbb{R}^{d\times p}$
- Output: factors $A \in \mathbb{R}^{p\times n}$
$$\min_{D,A}\|X - DA\|_F^2$$
If $p < d$, then $D$ is given by the $p$ leading left singular vectors of $X$ and the factors $A$ by the corresponding singular values combined with the right singular vectors
If $p > d$, we have an overcomplete dictionary, which means we can afford not to use all entries to reconstruct a sample:
$$\min_a \|x - Da\|^2 + \lambda\|a\|_0 \qquad \text{(sparse coding)}$$
Alternate updates:
- Fix $D$, update $A$: difficult problem; relax to $\|a\|_1$ or use iterative thresholding methods
- Fix $A$, update $D$:
$$D = XA^\top(AA^\top)^{-1}$$
K-SVD¶
Update one atom at a time (see (Aharon et al., 2006))
- $d_k$: atom $k$ of the dictionary
- $a_k \in \mathbb{R}^n$: factors (row of $A$) corresponding to atom $k$
- $\bar{D}_k = [d_i]_{i\neq k} \in \mathbb{R}^{d\times(p-1)}$: reduced dictionary without atom $d_k$
- $\bar{A}_k = [a_i]_{i\neq k} \in \mathbb{R}^{(p-1)\times n}$: factors corresponding to the reduced dictionary
$$\min_{D,A}\|X - DA\|_F^2 = \|X - \bar{D}_k\bar{A}_k - d_k a_k\|_F^2$$
$$E_k = X - \bar{D}_k\bar{A}_k$$
Iterative updates: $\min_{d_k, a_k}\|E_k - d_k a_k\|_F^2$
- SVD of $E_k = USV^\top$
- Rank-1 approximation: $E_k \approx u_1 s_1 v_1^\top$
- Get a hard-thresholding selection matrix $\Omega_k \in \{0,1\}^{n\times n'}$ that selects the $n'$ samples coded by atom $k$ (e.g. the highest absolute values of $v_1$)
- Compute the reduced problem for the selected samples: $\min_{d_k, a_k}\|E_k\Omega_k - d_k a_k\Omega_k\|_F^2$
- Update $d_k$ and $a_k$ using the rank-1 approximation of $E_k\Omega_k \approx usv^\top$
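A rough sketch of a single atom update in this spirit (selecting the samples whose coefficient for atom k is non-zero, one common choice; assumes X is d x n, D is d x p, A is p x n, and that at least one sample uses atom k):
def ksvd_update_atom(X, D, A, k):
    # residual of the reconstruction without atom k's contribution
    E_k = X - jnp.matmul(D, A) + jnp.outer(D[:, k], A[k, :])
    # restrict to the samples that are actually coded by atom k
    idx = jnp.nonzero(jnp.abs(A[k, :]) > 0)[0]
    U, S, Vt = jnp.linalg.svd(E_k[:, idx], full_matrices=False)
    d_k = U[:, 0]                   # new atom: leading left singular vector
    a_k = S[0] * Vt[0, :]           # new coefficients for the selected samples
    return D.at[:, k].set(d_k), A.at[k, idx].set(a_k)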
X = jnp.transpose(data['X_train'])
y = data['y_train']
D = np.random.rand(784, 64)
A = np.random.rand(64, 100)
for e in range(50):
D = jnp.matmul(jnp.matmul(X, A.T), jnp.linalg.inv(jnp.matmul(A, A.T)))
A = jnp.matmul(jnp.linalg.inv(jnp.matmul(D.T, D)), jnp.matmul(D.T, X))
S = jnp.sign(A)
I = jnp.argsort(jnp.abs(A), axis=0)[-33, :]
A = S * jnp.clip(jnp.abs(A) - jnp.abs(A[I, jnp.arange(100)]), a_min=0)
plt.subplot(1,5,1)
plt.imshow(D[:,0].reshape(28, 28))
plt.subplot(1,5,2)
plt.imshow(D[:,1].reshape(28, 28))
plt.subplot(1,5,3)
plt.imshow(D[:,2].reshape(28, 28))
plt.subplot(1,5,4)
plt.imshow(D[:,3].reshape(28, 28))
plt.subplot(1,5,5)
plt.imshow(D[:,4].reshape(28, 28))
plt.imshow(A)
MNIST¶
Exercise: try a linear regression (vector output) using $A$ instead of $X$
Why?¶
- $A$ may provide a better representation than $X$ for learning a predictor
- $D$ may provide insights (the modes of $X$)
Relation to k-means
$$\min_{D,A}\|X - DA\|_F^2 \quad \text{s.t.} \quad \forall i,\ \|a_i\|_0 = 1$$
- Only a single atom selected per sample
- Alternate optimization:
$$d_k = \frac{X a_k^\top}{\|a_k\|_1} \qquad a_i = [\mathbb{1}_{m=n}]_m, \quad n = \arg\min_k \|d_k - x_i\|$$
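A minimal sketch of one such alternation (assumes X is d x n and a dictionary D is d x p, e.g. the D learned above):
def kmeans_step(X, D):
    # assignment: nearest atom (column of D) for each sample
    dists = ((X[:, None, :] - D[:, :, None])**2).sum(axis=0)     # shape p x n
    A = jax.nn.one_hot(jnp.argmin(dists, axis=0), D.shape[1]).T  # one-hot codes, p x n
    # update: each atom becomes the mean of the samples assigned to it
    counts = jnp.maximum(A.sum(axis=1), 1.)                      # avoid division by zero
    D = jnp.matmul(X, A.T) / counts
    return D, A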
Linear Model (regression), take home¶
- MSE often leads to a closed-form solution
- MSE penalizes large errors, MAE penalizes small errors
- MAE is robust to outliers
- Sensitivity to the condition number: $\ell_2$ regularization
- Sparse model: $\ell_1$ regularization
Dictionary learning
- Find a better representation with a linear model
- Non-linear relations: explicit non-linear mapping + linear model