Machine Learning and Applications - Linear Models¶

David Picard¶

École Nationale des Ponts et Chaussées¶

david.picard@enpc.fr¶

Linear Regression¶

Scalar input, scalar output¶

  • Input space: $x \in \mathbb{R}$
  • Output space: $y \in \mathbb{R}$
  • Linear model: $f(x) = ax$

$$ \min_{a} \mathbb{E}_x [ (y - ax)^2 ]$$

Training set $\mathcal{A} = \{(x_i, y_i)\}_{i\leq n}$, minimize the empirical risk

$$ \min_a \frac{1}{n}\sum_i (y_i - ax_i)^2$$

Closed form solution:

  • vectorize: $\mathbf{x} = [x_i]$, $\mathbf{y} = [y_i]$ $$ \min_a \frac{1}{n} \| \mathbf{y} - a\mathbf{x} \|^2 $$
  • Stationary condition $$ \frac{\partial}{\partial a} \frac{1}{n}\| \mathbf{y} - a\mathbf{x} \|^2 = \frac{2}{n}\left(a \|\mathbf{x}\|^2 - \langle \mathbf{y}, \mathbf{x} \rangle\right) = 0$$ $$ a = \frac{\mathbf{y}^\top\mathbf{x}}{\|\mathbf{x}\|^2}$$
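As a quick numerical check, the closed form is just a ratio of dot products. A minimal, self-contained sketch (the variable names here are illustrative, not from the notebook):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
x = jnp.linspace(0., 1., 20)
y = 3. * x + 0.05 * jax.random.normal(key, (20,))   # noisy samples of y = 3x
a = jnp.dot(y, x) / jnp.dot(x, x)                   # closed-form estimate, close to 3
print(a)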

Linear regression - Vector input, scalar output¶

  • Input space: $\mathbf{x} \in \mathbb{R}^d$
  • Output space: $y \in \mathbb{R}$
  • Linear model: $f(\mathbf{x}) = \mathbf{a}^\top\mathbf{x}$

$$ \min_{\mathbf{a}} \mathbb{E}_x [ (y - \mathbf{a}^\top\mathbf{x})^2 ]$$

Training set $\mathcal{A} = \{(\mathbf{x}_i, y_i)\}_{i\leq n}$, minimize the empirical risk

$$ \min_{\mathbf{a}} \frac{1}{n}\sum_i (y_i - \mathbf{a}^\top\mathbf{x}_i)^2$$

Closed form solution

  • Matrix form: $\mathbf{X} = [\mathbf{x}_i]$, $\mathbf{y} = [y_i]$

$$ \min_\mathbf{a} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{a} \|^2$$

  • Stationary condition

$$\frac{\partial}{\partial \mathbf{a}} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{a} \|^2 = \frac{2}{n}\left(\mathbf{XX}^\top\mathbf{a} - \mathbf{X}\mathbf{y}\right) = 0 $$

$$ \mathbf{a} = (\mathbf{XX}^\top)^{-1} \mathbf{X}\mathbf{y} $$

Pseudo-inverse
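In practice this solve is available directly. A minimal sketch (assuming the conventions above, with samples as columns of $\mathbf{X}$), using jnp.linalg.pinv, which computes the Moore-Penrose pseudo-inverse through the SVD:

import jax.numpy as jnp

X = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [0.5, 0.1, 0.8, 0.3]])          # d x n, one sample per column
a_true = jnp.array([2.0, -1.0])
y = jnp.matmul(X.T, a_true)
a_hat = jnp.matmul(jnp.linalg.pinv(X.T), y)    # least-squares solution (X X^T)^{-1} X y
print(a_hat)                                   # recovers [2, -1]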

Vector input, bis¶

  • SVD: $\mathbf{X} = \mathbf{USV}^\top$

$$ \mathbf{y} = \mathbf{VSU}^\top\mathbf{a} $$ $$ \mathbf{a} = \mathbf{US}^{-1}\mathbf{V}^\top\mathbf{y}$$

The solution is obtained easily by projecting into the eigenspace of $\mathbf{X}$; note that $\mathbf{a}$ lies in that eigenspace (the span of the left singular vectors)

Karhunen-Loève theorem¶

Let $\mathbf{x}$ be a stochastic process with covariance matrix $\Sigma_\mathbf{x}$, then

$$ \mathbf{x}_i = \sum_{k=1}^p z_{k,i}\mathbf{e}_k $$

with $\mathbf{e}_k$ the eigenvectors of $\Sigma_\mathbf{x}$.

  • Samples $\mathbf{x}_i$ live in the space spanned by the eigenvectors of the covariance matrix (hence PCA)
  • If the samples span fewer than $d$ dimensions, some directions are useless (pure noise)
  • Strong influence on the pseudo-inverse solution (through $\mathbf{S}^{-1}$) $\Rightarrow$ remove directions with small eigenvalues
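The code cells below assume the usual imports from the notebook's first cell (not shown here), presumably along these lines:

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt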
In [2]:
key = jax.random.PRNGKey(0)
key, skey = jax.random.split(key)
x = jax.random.uniform(skey, (50, 1))
X = jnp.concatenate((x, -5*x), axis=1)   # two perfectly correlated columns: rank-deficient design
a = jnp.array([2, 0])
y = jnp.matmul(X, a)

U, S, V = jnp.linalg.svd(X.T, full_matrices=False)   # note: V here holds V^T, as returned by svd
print('eigenvalues: {}'.format(S))
Vty = jnp.matmul(V, y)                        # project y onto the right singular vectors
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)                # pseudo-inverse solution a = U S^{-1} V^T y
print('a_hat: {} a: {}'.format(a_hat, a))
eigenvalues: [2.0413006e+01 3.0459864e-07]
a_hat: [-2.8013153  -0.96026266] a: [2 0]

Linear regression, bias case¶

Adding a constant (bias) term reduces to the vector case

$$ \min_{\mathbf{a}, b} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{a} - b\mathbf{1} \|^2$$

  • append $b$ to $\mathbf{a}$ and a row of ones to $\mathbf{X}$

$$ \min_{\mathbf{a}, b} \frac{1}{n} \|\mathbf{y} - [\mathbf{X}; \mathbf{1}]^\top[\mathbf{a}; b] \|^2$$
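A minimal sketch of this augmentation trick (illustrative names, assuming the imports above): append a row of ones to $\mathbf{X}$, solve the usual least-squares problem, and read the bias off the last coefficient.

X = jnp.array([[0.0, 1.0, 2.0, 3.0]])                             # d x n with d = 1
y = 2.0 * X[0] + 1.5                                              # true model: a = 2, b = 1.5
X_aug = jnp.concatenate([X, jnp.ones((1, X.shape[1]))], axis=0)   # append a row of ones
ab = jnp.matmul(jnp.linalg.pinv(X_aug.T), y)                      # solves for [a; b]
print(ab)                                                         # close to [2.0, 1.5]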

Linear Regression, Vector input, vector output¶

  • Input space: $\mathbf{x} \in \mathbb{R}^d$
  • Output space: $\mathbf{y} \in \mathbb{R}^p$
  • Linear model: $f(\mathbf{x}) = \mathbf{A}^\top\mathbf{x}$

$$ \min_{\mathbf{A}} \mathbb{E}_x [ \|\mathbf{y} - \mathbf{A}^\top\mathbf{x}\|^2 ]$$

Training set $\mathcal{A} = \{(\mathbf{x}_i, \mathbf{y}_i)\}_{i\leq n}$, minimize the empirical risk

$$ \min_{\mathbf{A}} \frac{1}{n}\sum_i \|\mathbf{y}_i - \mathbf{A}^\top\mathbf{x}_i\|^2$$

Equivalent to $p$ scalar output cases stacked together

Let's try with MNIST¶

Regress 0 and 1

In [3]:
data = np.load('mnist.npz')
X = data['X_train_bin']
y = data['y_train_bin']
plt.imshow(X[0,:].reshape(28,28))
print(y[0])
0
In [4]:
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [36.45615   17.209862  12.030047  11.947479   9.36883    7.8752723
  6.7997932  6.3316774  5.729243   5.4004116  5.1866603  4.988159
  4.8420706  4.1217637  3.9173453  3.5896347  3.2801106  3.1127822
  3.0160408  2.7964358  2.6677098  2.543037   2.2888196  2.1948047
  2.14488    1.8337901  1.0252038]
Out[4]:
<matplotlib.image.AxesImage at 0x71008d42c040>
In [5]:
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)

plt.subplot(1,2,1)
plt.imshow(jnp.maximum(a_hat, 0).reshape((28, 28)))
plt.subplot(1,2,2)
plt.imshow(jnp.maximum(-a_hat, 0).reshape((28, 28)))
Out[5]:
<matplotlib.image.AxesImage at 0x71008d3b8160>
In [6]:
X_val = data['X_val_bin']
y_val = data['y_val_bin']
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
Out[6]:
[<matplotlib.lines.Line2D at 0x71008c7ea170>]

MNIST, regress 0-9¶

In [7]:
X = data['X_train']
y = data['y_train']

U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
print('eigenvalues: {}'.format(S[0:10]))
plt.subplot(1,2,1)
plt.imshow(U[:,0].reshape((28,28)))
plt.subplot(1,2,2)
plt.imshow(U[:,1].reshape((28,28)))
eigenvalues: [61.84758  22.601597 19.816174 18.98699  17.21556  15.531815 14.318835
 13.584824 12.348789 11.739741]
Out[7]:
<matplotlib.image.AxesImage at 0x71008c432140>
In [8]:
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)

X_val = data['X_val'][0:30,...]
y_val = data['y_val'][0:30]
y_hat = jnp.matmul(X_val, a_hat)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
Out[8]:
[<matplotlib.lines.Line2D at 0x71008c4e6950>]

The output space is not well suited: regressing the label as a real number imposes an artificial ordering (topology) on the classes

MNIST regress 0-9 as one-hot¶

  • Exercise: train a linear regression for each class (one versus all)
In [9]:
y = jax.nn.one_hot(y, 10)

a = []
for k in range(10):
    #
    #
    a_k = jnp.zeros(784)
    a.append(a_k)
a = jnp.array(a)

y_hat = jnp.argmax(jnp.matmul(X_val, a.T), axis=1)
plt.stem(range(len(y_val)), y_val, '-b')
plt.plot(range(len(y_val)), y_hat, '-r')
Out[9]:
[<matplotlib.lines.Line2D at 0x71008d5725f0>]
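One possible solution sketch for the exercise, solving all ten classes at once rather than looping (it reuses U, S, V from In [7] and the one-hot y from the cell above; illustrative, not the only way):

A_hat = jnp.matmul(U, jnp.matmul(jnp.diag(1./S), jnp.matmul(V, y)))   # (784, 10), one regressor per class
a = A_hat.T                                                           # (10, 784), same layout as the loop above
y_hat = jnp.argmax(jnp.matmul(X_val, a.T), axis=1)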

Non-linear case¶

Polynomial regression¶

What if the relation between $x$ and $y$ is not linear?

  • Map $\phi: x \mapsto [x, x^2, x^3, \dots, x^p]$

$$ \min_\mathbf{a} \mathbb{E}_x[(y - \phi(x)^\top\mathbf{a})^2]$$

In [10]:
a = [-0.2, 0.7, 0.83, -1.5, 5.23]
p = np.poly1d(a)
x = np.random.rand(24)*4-2
y = p(x) + 0.2*np.random.randn(24)
In [11]:
X = jnp.stack([jnp.ones((len(x))), x, x**2, x**3, x**4], axis=1)
U, S, V = jnp.linalg.svd(X.T, full_matrices=False)
Vty = jnp.matmul(V, y)
SinvVty = jnp.matmul(jnp.diag(1./S), Vty)
a_hat = jnp.matmul(U, SinvVty)
print(a_hat)
pp = np.poly1d(a_hat[::-1])

t = np.linspace(-2, 2, 50)
plt.plot(x, y, 'x')
plt.plot(t, pp(t))
plt.plot(t, p(t))
[ 5.266477   -1.538329    0.6909166   0.7270644  -0.15955621]
Out[11]:
[<matplotlib.lines.Line2D at 0x7100685f9570>]

Periodic signals¶

Map $\phi: x \mapsto [\sin(f_0 x), \sin(2f_0x), \dots, \sin(pf_0x)]$
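The cell that generates the data for this section (In [12]) is not shown; judging from the later cells (the printed ground truth a = [0.7, 0.83, -1.5], the sample count of 24, and the validation set in In [30]), it presumably looked something like this (sample range and noise level are assumptions):

a = jnp.array([0.7, 0.83, -1.5])          # true coefficients of the three harmonics (from later printouts)
x = np.random.rand(24) * 16 - 8           # 24 samples in [-8, 8] (range assumed from the plots)
y = jnp.matmul(a, jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])) + 0.3 * np.random.randn(24)   # noise level assumed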

In [13]:
X = jnp.array([jnp.sin(x), jnp.sin(2*x), jnp.sin(3*x)])
ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(X, X.transpose())), X), y)
Yp = jnp.matmul(ap, X)
In [14]:
t = np.linspace(-8, 8, 200)
T = np.array([np.sin(t), np.sin(2*t), np.sin(3*t)])
plt.plot(x, y, 'x')
plt.plot(t, np.matmul(ap, T))
plt.plot(t, np.matmul(a, T))
Out[14]:
[<matplotlib.lines.Line2D at 0x710069721420>]

Overcomplete models¶

What if $p < \hat{p}$, i.e. the model order $\hat{p}$ exceeds the true order $p$ (the model has greater capacity than the data requires)?

In [15]:
def sin_approx(x, y, p):
    # design matrix with rows sin(k x), k = 1..p, then least squares via the normal equations
    Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
    ap = jnp.matmul(jnp.matmul(jnp.linalg.inv(jnp.matmul(Xp, Xp.transpose())), Xp), y)
    Yp = jnp.matmul(ap, Xp)
    return ap, Xp, Yp
In [16]:
p = 3
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.66503567  0.8555399  -1.5098379 ]
In [17]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.13037486374378204
In [18]:
p = 5
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.66392684  0.8552013  -1.4939932  -0.07257754  0.05675733]
In [19]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.1274968534708023
In [20]:
p = 10
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.64930737  0.86065483 -1.5159405  -0.0419542   0.10985056 -0.04247902
 -0.1989938   0.04065548  0.08720873 -0.07313146]
In [21]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.1047743558883667
In [22]:
p = 15
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.64816654  0.85237145 -1.4848707  -0.06554112  0.17367777 -0.08139829
 -0.15539262  0.0064044   0.08781601 -0.0583418   0.01320645  0.03894603
 -0.13762946  0.01072791 -0.09005098]
In [23]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.09392114728689194
In [24]:
p = 20
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.63273454  0.862865   -1.5327747  -0.04082306  0.12400526 -0.07723609
 -0.18338399  0.05933774  0.09869707 -0.05381919  0.01258338  0.04874885
 -0.09009814 -0.00779995  0.00236045 -0.04848589 -0.05025677 -0.11085568
  0.04835366 -0.12392354]
In [25]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.08044012635946274
In [26]:
p = 25
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.5857791   0.9295168  -1.5959327  -0.06027614  0.15495446 -0.07315889
 -0.18192858 -0.00906502  0.12587665  0.02122447 -0.05915882  0.0195908
  0.00481635 -0.02463253 -0.0524892  -0.01131299 -0.07204011 -0.1264672
  0.09676984 -0.20841293  0.02010447  0.1596179  -0.20179838  0.03811657
  0.21566837]
In [27]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.054160039871931076
In [28]:
p = 35
ap, Xp, Yp = sin_approx(x, y, p)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 0.60171056  2.1561666  -3.7234256   1.9323449  -1.881022    1.7465707
 -1.9733595   2.061804   -1.430026    0.31073868  0.800243   -1.7518697
  2.4190798  -2.4179287   1.7261914  -0.81368876  0.62834394 -1.0664698
  0.6718048  -0.6353989  -0.20793903  1.5484711  -1.6906972   0.984274
 -0.906407    0.850715   -0.3213622   0.83473617 -0.9671983   0.1878281
  0.6671101  -1.3488506   1.6334653  -0.616805   -0.2670184 ]
In [29]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.02646658569574356

Train/validation¶

In [30]:
x_val = np.random.rand(48)*16-8
X_val = jnp.array([ jnp.sin(x_val), jnp.sin(2*x_val), jnp.sin(3*x_val)])
y_val = jnp.matmul(a, X_val) + 0.3*np.random.randn(48)
In [32]:
mse = []
mse_val = []
for p in range(25):
    ap, Xp, Yp = sin_approx(x, y, p)
    _, Xp_val, _ = sin_approx(x_val, y_val, p)
    Yp_val = jnp.matmul(ap, Xp_val)
    mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Out[32]:
<matplotlib.legend.Legend at 0x710055d96e30>

Regularization¶

Noisy observation: $$ y = \mathbf{a}^\top\mathbf{x} + \varepsilon, \varepsilon \sim \mathcal{N}(0, \sigma)$$

Assume $\|\mathbf{a}\|_0 < d$ (not all input dimensions are used): can we force the estimate $\hat{\mathbf{a}}$ to be sparse as well?

$$ \min_\mathbf{a} \mathbb{E}_x[(y - \mathbf{x}^\top\mathbf{a})^2] + \Omega(\mathbf{a}) $$

With $\Omega(\mathbf{a})$ a regularizer that increases the cost of more complex $\mathbf{a}$

LASSO¶

Least Absolute Shrinkage and Selection Operator (Tibshirani, 1996)

$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda\|\mathbf{a}\|_1 $$

Optimize using gradient descent

$$ \mathbf{a} \leftarrow \mathbf{a} - \eta \left[ \frac{-2}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})\mathbf{x}_i + \lambda\,\text{sign}(\mathbf{a}) \right] $$

In [33]:
def sin_pred(a, X):
    return jnp.matmul(a, X)

def sin_lasso(a, X, y, lam):
    # MSE data term + l1 penalty on the coefficients
    yp = sin_pred(a, X)
    return ((y - yp)**2).mean() + lam*jnp.abs(a).sum()

@jax.jit
def update(a, X, y, lam):
    # one (sub)gradient descent step with fixed step size 0.05
    da = jax.grad(sin_lasso, argnums=0)(a, X, y, lam)
    return a - 0.05*da
In [34]:
p=25
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
ap = jnp.zeros(p)
for i in range(100):
    ap = update(ap, Xp, y, 0.1)
print(a, ap)
[ 0.7   0.83 -1.5 ] [ 5.0244969e-01  6.7708075e-01 -1.3217005e+00 -5.5618468e-03
  2.1797190e-03 -2.4922101e-03 -7.4531697e-02 -6.1854455e-03
  1.1975669e-03 -3.2022649e-03 -3.3327036e-03  3.0571646e-03
 -1.7762976e-02 -5.9072860e-03 -4.4773789e-03 -6.0488996e-03
 -6.4997919e-02 -6.8982323e-03  1.4052609e-03 -7.1133333e-03
  3.2059646e-03 -7.8694215e-03 -4.8849126e-03  1.2485289e-02
  6.9161236e-02]
In [35]:
t = jnp.linspace(-8, 8, 200)
T = jnp.sin(np.matmul(t.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(ap, T[:len(ap), :]))
plt.plot(t, jnp.matmul(a, T[:len(a), :]))
print('MSE: {}'.format(((y - Yp)**2).mean()))
MSE: 0.06530888378620148
In [36]:
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
plt.plot(x, y, 'x')
plt.plot(t, jnp.matmul(a, T[:len(a), :]), label='gt')
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    ap = jnp.zeros(p)
    for i in range(50):
        ap = update(ap, Xp, y, lam)
    plt.plot(t, jnp.matmul(ap, T[:len(ap), :]), label=r'$\lambda={}$'.format(lam))
plt.legend()
Out[36]:
<matplotlib.legend.Legend at 0x710055c4b220>
In [37]:
def lasso_approx(Xp, y, lam):
    ap = jnp.zeros(len(Xp[:,0]))
    for i in range(100):
        ap = update(ap, Xp, y, lam)
    Yp = jnp.matmul(ap, Xp)
    return ap, Xp, Yp 
In [38]:
Xp = jnp.sin(jnp.matmul(x.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
Xp_val = jnp.sin(jnp.matmul(x_val.reshape(-1, 1), 1+jnp.arange(p).reshape(1,p))).transpose()
mse = []
mse_val = []
for lam in [0, 0.0001, 0.001, 0.01, 0.1, 1.0]:
    ap, Xp, Yp = lasso_approx(Xp, y, lam)
    Yp_val = jnp.matmul(ap, Xp_val)
    mse.append(((y - Yp)**2).mean()), mse_val.append(((y_val - Yp_val)**2).mean())
plt.plot(mse, label='mse'); plt.plot(mse_val, label='mse_val')
plt.legend()
Out[38]:
<matplotlib.legend.Legend at 0x710055c5d7b0>

Analysis¶

Project $\mathbf{X}$ into its eigenspace:

$$ \min_\mathbf{a} \frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{Ua}\|^2 + \lambda\|\mathbf{a}\|_1 $$ $$ = \frac{1}{n} \|\mathbf{y} - \mathbf{VSa}\|^2 + \lambda\|\mathbf{a}\|_1$$

Stationary condition:

$$\frac{\partial}{\partial \mathbf{a}} = 0 = -2\mathbf{SV}^\top\mathbf{y} + 2\mathbf{S}^2\mathbf{a} + \lambda\text{sign}(\mathbf{a})$$ $$\mathbf{a} = \mathbf{S}^{-1}\mathbf{V}^\top\mathbf{y} - \mathbf{S}^{-2}\frac{\lambda\text{sign}(\mathbf{a})}{2}$$

Let $\tilde{\mathbf{a}} = \mathbf{S}^{-1}\mathbf{V}^\top\mathbf{y}$

Note that $\text{sign}(\mathbf{a}) = \text{sign}(\tilde{\mathbf{a}}) = \frac{\tilde{\mathbf{a}}}{\vert\tilde{\mathbf{a}}\vert}$ $$\mathbf{a} = \tilde{\mathbf{a}}\left(1 - \frac{\lambda\mathbf{S}^{-2}}{2\vert\tilde{\mathbf{a}}\vert}\right)$$

Case $>0$, $\text{sign}(\mathbf{a}) = \text{sign}(\tilde{\mathbf{a}}) = 1$

$$ \mathbf{a}_i = \color{blue}{\underbrace{\tilde{\mathbf{a}}_i}_{>0}} \left(1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) >0$$

$$ \mathbf{a}_i = \tilde{\mathbf{a}}_i \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) $$

Case $<0$, $\text{sign}(\mathbf{a}) = \text{sign}(\tilde{\mathbf{a}}) = -1$

$$ \mathbf{a}_i = \color{blue}{\underbrace{\tilde{\mathbf{a}}_i}_{<0}} \left(1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) <0$$

$$ \mathbf{a}_i = \tilde{\mathbf{a}}_i \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}_i}{2\vert\tilde{\mathbf{a}}_i\vert}\right) $$

Soft thresholding:

$$ \mathbf{a} = \tilde{\mathbf{a}} \max\left(0, 1 - \frac{\lambda\mathbf{S}^{-2}}{2\vert\tilde{\mathbf{a}}\vert}\right) $$

$\lambda$ removes components that would change the sign of the solution $\rightarrow$ Sparse solution

Remember: analysis only valid in eigenspace
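A minimal sketch of the resulting soft-thresholding operator in the eigenspace (an illustrative helper, assuming jax.numpy as jnp; not from the notebook):

def soft_threshold_eigenspace(a_tilde, S, lam):
    # a_tilde = S^{-1} V^T y; each component is shrunk towards zero and clipped at zero
    shrink = 1. - lam / (2. * S**2 * jnp.abs(a_tilde))
    return a_tilde * jnp.maximum(0., shrink)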

Conditioning¶

In eigenspace, pseudo-inverse solution: $$ \mathbf{a} = \mathbf{US}^{-1}\mathbf{V}^\top\mathbf{y}$$

What if $\mathbf{S}$ contains very small singular values ($\text{span}(\mathbf{X}) < d$)? How can we prevent the solution from fitting the noise?

Avoid large values in the solution:

$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda\|\mathbf{a}\|^2 $$

Tikhonov regularization (ridge regression)

Analysis¶

$$\frac{1}{n} \|\mathbf{y} - \mathbf{X}^\top\mathbf{a}\|^2 + \lambda\|\mathbf{a}\|^2 $$

Stationary condition:

$$\frac{\partial}{\partial \mathbf{a}} = 0 =\frac{2}{n}(-\mathbf{X}\mathbf{y} + \mathbf{XX}^\top\mathbf{a}) +2\lambda\mathbf{a} $$

$$\mathbf{a} = (\mathbf{XX}^\top +n\lambda\mathbf{I})^{-1}\mathbf{X}\mathbf{y} $$

This offsets all eigenvalues of the covariance matrix $\mathbf{XX}^\top$ by $n\lambda$, which improves its conditioning
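A minimal sketch of the ridge solution (samples as columns of X, as in the analysis above), revisiting the rank-deficient toy data of In [2] where the raw pseudo-inverse picked up the noisy direction:

key = jax.random.PRNGKey(0)
x1 = jax.random.uniform(key, (50,))
X = jnp.stack([x1, -5. * x1], axis=0)     # two perfectly correlated features (rank 1), d x n
y = 2. * x1                               # generated by a = [2, 0]
lam = 1e-3
n = X.shape[1]
a_ridge = jnp.linalg.solve(jnp.matmul(X, X.T) + n * lam * jnp.eye(2), jnp.matmul(X, y))
print(a_ridge)                            # small-norm, stable solution; the near-zero singular direction is damped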

Elastic net¶

Add both regularizers

$$\min_\mathbf{a} \frac{1}{n}\sum_i (y_i - \mathbf{x}_i^\top\mathbf{a})^2 + \lambda_1\|\mathbf{a}\|_1 +\lambda_2\|\mathbf{a}\|_2^2$$

  • $\lambda_1$ controls sparsity
  • $\lambda_2$ controls sensitivity to noisy components

Optimize using gradient descent
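A minimal sketch of one elastic-net gradient step, in the same style as the LASSO update above (illustrative function names):

def elastic_net_loss(a, X, y, lam1, lam2):
    yp = jnp.matmul(a, X)                 # X: p x n design, a: p coefficients
    return ((y - yp)**2).mean() + lam1 * jnp.abs(a).sum() + lam2 * (a**2).sum()

@jax.jit
def elastic_net_update(a, X, y, lam1, lam2):
    da = jax.grad(elastic_net_loss, argnums=0)(a, X, y, lam1, lam2)
    return a - 0.05 * da                  # fixed step size, as in the LASSO cell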

Other loss functions¶

$\ell_2$ is sensitive to outliers

In [39]:
x = np.random.rand(24)*16-8
y = x + 0.1*np.random.rand(24)
y[0] += 20

a = jnp.dot(x, y)/jnp.dot(x, x)
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
0.8166666
Out[39]:
[<matplotlib.lines.Line2D at 0x710055a97f40>]

MAE¶

Mean absolute error (or $\ell_1$ error)

$$\min_\mathbf{a} \mathbb{E}[ \vert y - \mathbf{a}^\top\mathbf{x}\vert ]$$

Vector case

$$\min_\mathbf{A} \mathbb{E}[ \| \mathbf{y} - \mathbf{A}^\top\mathbf{x}\|_1 ]$$

No closed-form solution: optimize by gradient descent (with the subgradient convention $\partial\vert 0\vert = 0$)

This yields a robust regression

In [40]:
def l1(a, x, y):
    return jnp.abs(y - a*x).mean()

@jax.jit
def update(a, x, y):
    da = jax.grad(l1, argnums=0)(a, x, y)
    return a - 0.02*da

a = 0.
for i in range(100):
    a = update(a, x, y)
print(a)
t = jnp.linspace(-8, 8, 10)
plt.plot(x, y, 'x')
plt.plot(t, a*t, '-r')
1.0179636
Out[40]:
[<matplotlib.lines.Line2D at 0x710055934760>]

Sensitivity to small errors¶

  • $\ell_2$: the gradient vanishes as the error goes to zero, so small errors are barely corrected
  • $\ell_1$: constant-magnitude gradient, so small errors keep being penalized
In [41]:
t = jnp.linspace(-0.5, 0.5, 25)
plt.plot(t, jnp.abs(t), label='MAE')
plt.plot(t, t**2, label='MSE')
plt.legend()
Out[41]:
<matplotlib.legend.Legend at 0x710055993c70>

Do both?¶

  • Penalize large errors: $\ell_2$
  • Penalize small errors (assuming no outliers): $\ell_1$

$$ \min_\mathbf{A} \mathbb{E}[ \|\mathbf{y} - \mathbf{A}^\top\mathbf{x}\|^2 + \lambda \|\mathbf{y} - \mathbf{A}^\top\mathbf{x}\|_1 ]$$

Optimize using gradient descent

Full model¶

  • Ridge regularization (noisy components)
  • Sparsity regularization (overcomplete model)
  • Large errors penalization
  • Small errors penalization

$$ \min_\mathbf{A} \mathbb{E}[ \|\mathbf{y} - \mathbf{A}^\top\mathbf{x}\|^2 + \lambda \|\mathbf{y} - \mathbf{A}^\top\mathbf{x}\|_1 ] + \lambda_2 \|\mathbf{A}\|_F^2 + \lambda_1 \|\mathbf{A}\|_1$$

  • Optimize using gradient descent (surprise!)
  • 3 hyper-parameters: use proper cross validation ("With four parameters I can fit an elephant, and with five I can make him wiggle his trunk", J. Von Neumann)
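A minimal sketch of the full objective and a gradient step (illustrative; $\mathbf{A}$ is $d \times p$, samples as columns of X, targets as columns of Y):

def full_loss(A, X, Y, lam, lam1, lam2):
    R = Y - jnp.matmul(A.T, X)            # residuals, p x n
    data = (R**2).sum(axis=0).mean() + lam * jnp.abs(R).sum(axis=0).mean()
    return data + lam2 * (A**2).sum() + lam1 * jnp.abs(A).sum()

@jax.jit
def full_update(A, X, Y, lam, lam1, lam2):
    dA = jax.grad(full_loss, argnums=0)(A, X, Y, lam, lam1, lam2)
    return A - 0.05 * dA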

Dictionary learning¶

Unsupervised learning: target space is a new representation of the input space

  • Input: $\mathbf{X} \in \mathbb{R}^{d\times n}$
  • Model: Dictionary $\mathbf{D} \in \mathbb{R}^{d\times p}$
  • Output: Factors $\mathbf{A} \in \mathbb{R}^{p\times n}$

$$ \min_{\mathbf{D},\mathbf{A}} \|\mathbf{X} - \mathbf{DA}\|_F^2 $$

If $p < d$, then the optimal $\mathbf{D}$ is given by the $p$ leading left singular vectors of $\mathbf{X}$, and the factors $\mathbf{A}$ are the corresponding singular values combined with the right singular vectors

If $p > d$, the dictionary is overcomplete, which means we can afford not to use all atoms to reconstruct a given sample

$$ \min_{\mathbf{a}} \|\mathbf{x} - \mathbf{Da}\|^2 +\lambda \|\mathbf{a}\|_0$$ Sparse coding

Alternate update:

  • Fix $\mathbf{D}$, update $\mathbf{A}$

    • Difficult problem, relax to $\|\mathbf{a}\|_1$ or use iterative thresholding methods
  • Fix $\mathbf{A}$, update $\mathbf{D}$

$$\mathbf{D} = \mathbf{X}\mathbf{A}^\top(\mathbf{A}\mathbf{A}^\top)^{-1}$$

K-SVD¶

Update one atom at a time (see (Aharon et al, 2006))

  • $\mathbf{d}_k$: atom $k$ of the dictionary
  • $\mathbf{a}^k \in \mathbb{R}^n$: factors corresponding to atom $k$
  • $\bar{\mathbf{D}}_k = [\mathbf{d}_i]_{i\neq k} \in \mathbb{R}^{d\times (p-1)}$: reduced dictionary without atom $\mathbf{d}_k$
  • $\bar{\mathbf{A}}^k = [\mathbf{a}^i]_{i\neq k} \in \mathbb{R}^{(p-1)\times n}$: factors corresponding to the reduced dictionary

$$ \min_{\mathbf{D},\mathbf{A}} \|\mathbf{X} - \mathbf{DA}\|_F^2 = \|\mathbf{X} - \bar{\mathbf{D}}_k\bar{\mathbf{A}}^k - \mathbf{d}_k\mathbf{a}^k\|_F^2$$

$$ \mathbf{E}_k = \mathbf{X} - \bar{\mathbf{D}}_k\bar{\mathbf{A}}^k$$

Iterative updates: $$\min_{\mathbf{d}_k, \mathbf{a}^k} \|\mathbf{E}_k - \mathbf{d}_k\mathbf{a}^k\|_F^2$$

  • SVD of $\mathbf{E}_k = \mathbf{USV}^\top$
  • Rank 1 approximation: $\mathbf{E}_k \approx \mathbf{u}_1s_1\mathbf{v}_1^\top$
  • Build a hard selection matrix $\Omega_k \in \{0, 1\}^{n\times n'}$ that selects the $n'$ samples coded by atom $k$ (e.g. those with the highest absolute values in $\mathbf{v}_1$)
  • Solve the reduced problem on the selected samples: $$\min_{\mathbf{d}_k, \mathbf{a}^k} \|\mathbf{E}_k\Omega_k - \mathbf{d}_k\mathbf{a}^k\Omega_k\|_F^2$$
  • Update $\mathbf{d}_k$ and $\mathbf{a}^k$ using rank-1 approximation of $\mathbf{E}_k\Omega_k \approx \mathbf{u}s\mathbf{v}^\top$
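A sketch of a single atom update following these steps (illustrative helper; here $\Omega_k$ is taken as the set of samples whose current coefficient on atom $k$ is non-zero, assumed non-empty):

def ksvd_atom_update(X, D, A, k):
    # residual with atom k removed: E_k = X - D A + d_k a^k
    E_k = X - jnp.matmul(D, A) + jnp.outer(D[:, k], A[k, :])
    idx = jnp.nonzero(jnp.abs(A[k, :]) > 0)[0]         # samples coded by atom k
    U, S, Vh = jnp.linalg.svd(E_k[:, idx], full_matrices=False)
    D = D.at[:, k].set(U[:, 0])                        # rank-1 approximation: new atom
    A = A.at[k, idx].set(S[0] * Vh[0, :])              # and its coefficients on the selected samples
    return D, A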
In [42]:
X = jnp.transpose(data['X_train'])
y = data['y_train']

D = np.random.rand(784, 64)    # random initial dictionary (64 atoms)
A = np.random.rand(64, 100)    # random initial factors (100 samples)

for e in range(50):
    # dictionary update: D = X A^T (A A^T)^{-1}
    D = jnp.matmul(jnp.matmul(X, A.T), jnp.linalg.inv(jnp.matmul(A, A.T)))
    # factor update: least squares A = (D^T D)^{-1} D^T X ...
    A = jnp.matmul(jnp.linalg.inv(jnp.matmul(D.T, D)), jnp.matmul(D.T, X))
    # ... then soft-threshold each column at its 33rd-largest magnitude,
    # keeping (a shrunk version of) the 32 largest entries per sample
    S = jnp.sign(A)
    I = jnp.argsort(jnp.abs(A), axis=0)[-33, :]
    A = S * jnp.clip(jnp.abs(A) - jnp.abs(A[I, jnp.arange(100)]), a_min=0)
In [43]:
plt.subplot(1,5,1)
plt.imshow(D[:,0].reshape(28, 28))
plt.subplot(1,5,2)
plt.imshow(D[:,1].reshape(28, 28))
plt.subplot(1,5,3)
plt.imshow(D[:,2].reshape(28, 28))
plt.subplot(1,5,4)
plt.imshow(D[:,3].reshape(28, 28))
plt.subplot(1,5,5)
plt.imshow(D[:,4].reshape(28, 28))
Out[43]:
<matplotlib.image.AxesImage at 0x7100554fcfd0>
In [44]:
plt.imshow(A)
Out[44]:
<matplotlib.image.AxesImage at 0x7100553ede10>

MNIST¶

Exercise: Try a linear regression (vector output) using $\mathbf{A}$ instead of $\mathbf{X}$
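One possible sketch (it reuses $\mathbf{A}$ and y from the cells above and adds an assumed small ridge term to cope with the small sample count; predictions are shown on the training factors only, since coding validation samples would require running the sparse coding step on them):

Y = jax.nn.one_hot(y, 10)                                # (n, 10) targets
lam = 1e-2
W = jnp.linalg.solve(jnp.matmul(A, A.T) + A.shape[1] * lam * jnp.eye(A.shape[0]),
                     jnp.matmul(A, Y))                   # (64, 10) regressor on the factors
y_hat = jnp.argmax(jnp.matmul(A.T, W), axis=1)
print((y_hat == y).mean())                               # training accuracy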

In [ ]:
 

Why?¶

  • $\mathbf{A}$ may provide a better representation than $\mathbf{X}$ for learning a predictor
  • $\mathbf{D}$ may provide insights (modes of $\mathbf{X}$)

Relation to k-means

$$ \min_{\mathbf{D},\mathbf{A}} \|\mathbf{X} - \mathbf{DA}\|_F^2\quad \text{ s.t. } \forall i, \|\mathbf{a}_i\|_0 = 1$$

  • Only a single atom selected per sample
  • Alternate optimization:

$$\mathbf{d}_k = \frac{\mathbf{X}\mathbf{a}^{k\top}}{\|\mathbf{a}^k\|_1}$$ $$\mathbf{a}_i = [\mathbf{1}_{m=k^\star}]_m,\quad k^\star = \text{argmin}_k\|\mathbf{d}_k - \mathbf{x}_i\|$$

Linear Model (regression), take home¶

  • MSE often leads to a closed-form solution
  • MSE to penalize large errors, MAE to penalize small errors
  • MAE robust to outliers
  • Sensitivity to condition number: $\ell_2$ regularization
  • Sparse model: $\ell_1$ regularization
  • Dictionary learning

    • Find better representation with a linear model
  • Non-linear relations: explicit non-linear mapping + linear model