Books:
Shai Shalev-Shwartz, Shai Ben-David, Understanding Machine Learning: From Theory to Algorithms
Kevin P. Murphy, Probabilistic Machine Learning: An Introduction
In French: Chloé-Agathe Azencott, Introduction au Machine Learning
Other lectures at ENPC:
This lecture uses JAX because I want to keep things low-level and look at how the algorithms work under the hood. In practice there are many high-level libraries. Do not reinvent the wheel, but beware that some sell square wheels...
We want to find a function $f$ that approximates $y$ from $X$
Solving problem #1, $P(X,y)$ is unknown:
If $P$ was known, we would use
$$f(x) = \arg\max_y P(y|x)$$
which is our best guess, and it would lead to the following error
$$P_e = \int \left(1 - \max_y P(y|x)\right)p(x)\,dx$$
(the Bayes error)
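For a two-class problem this is simply the probability of the less likely class at each point:
$$P_e = \int \min\bigl(P(1|x),\, P(2|x)\bigr)\,p(x)\,dx$$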
Estimate the error instead:
Training set of examples $\mathcal{A} = \{(X_i, y_i)\}_{i\leq n}$ sampled from $P(X,y)$
Approximate the expected error by the empirical risk
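i.e. the average loss of $f$ on the training examples (here with the 0-1 loss):
$$E(f) = \frac{1}{n}\sum_{i=1}^{n} \mathbf{1}\!\left[f(X_i) \neq y_i\right]$$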
Consider the function $$ f(X) = \begin{cases}y_i\text{ if }\exists X_i \in \mathcal{A}\text{ such that }X_i = X \\ 0\text{ else}\end{cases}$$
Obviously $$E(f) = 0$$
However, $f$ is pretty useless at predicting anything outside of $\mathcal{A}$
Points inside a random circle
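The cells below assume the usual imports and a ground-truth labelling function gt, neither of which is shown in this excerpt. A minimal stand-in (the circle parameters below are assumptions) could be:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.special
from celluloid import Camera       # used for the animated plots below
from IPython.display import HTML

def gt(X):
    # hypothetical ground truth: label 1 inside a fixed circle of radius 0.35
    # centred at (0.5, 0.5), label 0 outside
    return 1 * (((X - 0.5)**2).sum(axis=1) < 0.35**2)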
key = jax.random.PRNGKey(0)
key, skey = jax.random.split(key)
X = jax.random.uniform(skey, (50, 2))
y = gt(X)
plt.scatter(X[:,0], X[:,1], c=y)
class CirclePredictor:
    def __init__(self, key):
        # random center in [0,1]^2 and random radius in [0,1]
        key, skey = jax.random.split(key)
        self.c = jax.random.uniform(key, (2,))
        self.r = jax.random.uniform(skey)
    def __call__(self, X):
        # predict 1 inside the circle, 0 outside
        return jnp.sign(1*((X[:,0] - self.c[0])**2 + (X[:,1] - self.c[1])**2 < self.r**2))

def loss(y_pred, y_true):
    # 0-1 loss: fraction of misclassified points
    return (1-(y_pred==y_true)).mean()
key, skey = jax.random.split(key)
pred = CirclePredictor(skey)
t = 50; tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = pred(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels); plt.scatter(X[:,0], X[:,1], c=y)
fig = plt.figure()
camera = Camera(fig)
key = jax.random.PRNGKey(7)
l_min = 20; f_best = None
le = []
for i in range(100):
    key, skey = jax.random.split(key)
    f = CirclePredictor(skey)
    l = loss(f(X), y)
    if l < l_min:
        l_min = l; f_best = f
    le.append(l_min)
    plt.plot(le, '-k'); camera.snap()
animation = camera.animate()
HTML(animation.to_html5_video())
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = f_best(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y)
The difference between the expected risk and the empirical risk is known as the generalization gap
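In symbols, for the 0-1 loss:
$$\text{gap}(f) = \underbrace{E_{(X,y)\sim P}\bigl[\mathbf{1}[f(X) \neq y]\bigr]}_{\text{expected risk}} - \underbrace{\frac{1}{n}\sum_{i=1}^{n}\mathbf{1}[f(X_i) \neq y_i]}_{\text{empirical risk}}$$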
How do we know if $f$ is overfitting?
key = jax.random.PRNGKey(6)
Xt = jax.random.uniform(key, (50, 2))
yt = gt(Xt)
fig = plt.figure()
camera = Camera(fig)
key = jax.random.PRNGKey(1)
l_min = 20; f_best = None
le = []
lt = []
for i in range(100):
    key, skey = jax.random.split(key)
    f = CirclePredictor(skey)
    l = loss(f(X), y)
    if l < l_min:
        l_min = l; f_best = f
    le.append(l)
    lt.append(loss(f(Xt), yt))
    plt.plot(le, '-k'); plt.plot(lt, '-r'); camera.snap()
animation = camera.animate()
HTML(animation.to_html5_video())
$k$ nearest neighbors: the prediction is a vote among the $k$ nearest elements of the training set
Example, $1$-NN:
$$f(x) = y_i \text{ with } i = \arg\min_{j} \| x - X_j \|^2$$

class FirstNearestNeighbor:
    def __init__(self, X, y):
        self.X = X
        self.y = y
    def __call__(self, x):
        dist = ((self.X[None,:,:] - x[:,None,:])**2).sum(axis=2) # broadcast to B x n x dim
        index = jnp.argmin(dist, axis=1)
        return self.y[index]
nn = FirstNearestNeighbor(X, y)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
#plt.scatter(X[:,0], X[:,1], c=y)
plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
What is the expected error of 1-NN in the limit?
Theorem (Cover & Hart, 1967): Let $X$ be a metric space and let $p_1$ and $p_2$ be such that, with probability 1, $x$ is either 1) a continuity point of $p_1$ and $p_2$, or 2) a point of non-zero probability measure. Then the NN risk $R$ (probability of error) satisfies the bounds
$$R^\star \leq R \leq 2R^\star(1-R^\star)$$
where $R^\star$ is the Bayes risk (irreducible error), $R^\star = E\left[\min_j \sum_i p_i(x)\,L(i,j)\right]$.
T. Cover and P. Hart, "Nearest neighbor pattern classification," in IEEE Transactions on Information Theory, vol. 13, no. 1, pp. 21-27, January 1967, doi: 10.1109/TIT.1967.1053964.
Lemma: Let $x_n'$ denote the nearest neighbor of $x$ in the set $\{x_0, \dots, x_n\}$, then $x_n' \rightarrow x$ with probability one (continuity + point measure).
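A sketch of the two-class argument under the 0-1 loss (not the full proof): by the lemma, the label of the nearest neighbor of $x$ is asymptotically drawn from the posterior at $x$ itself, so the limiting conditional NN risk is
$$r(x) = P(1|x)\,P(2|x) + P(2|x)\,P(1|x) = 2\,r^\star(x)\bigl(1 - r^\star(x)\bigr), \qquad r^\star(x) = \min\bigl(P(1|x), P(2|x)\bigr),$$
and taking expectations, $R = E\bigl[2\,r^\star(x)(1 - r^\star(x))\bigr] \leq 2R^\star(1 - R^\star)$ since $E[(r^\star)^2] \geq (E[r^\star])^2$.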
class KNearestNeighbor:
    def __init__(self, X, y, k=1):
        self.X = X
        self.y = y
        self.k = k
    def __call__(self, x):
        dist = ((self.X[None,:,:] - x[:,None,:])**2).sum(axis=2) # broadcast to B x n x dim
        indices = jnp.argsort(dist, axis=1)
        # majority vote among the k nearest labels (ties, possible for even k, go to class 0)
        yp = 1*((self.y[indices[:,0:self.k]]).sum(axis=1) > self.k//2)
        return yp
nn = KNearestNeighbor(X, y, k=1)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
nn = KNearestNeighbor(X, y, k=2)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
nn = KNearestNeighbor(X, y, k=3)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
nn = KNearestNeighbor(X, y, k=5)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y)#, plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
How do we select k?
The training error cannot tell them apart: $1$-NN achieves zero error on $\mathcal{A}$ by construction, so selecting on training error always favors the smallest $k$
We can split $\mathcal{A}$ in two:
One for training each k-NN: training set
One for evaluating each k-NN: validation set
Since the validation set is used to select a model, it cannot also be used to estimate the expected risk
Standard 3-split procedure: train, validation, test
Train on train
Perform model selection on validation
Evaluate on test
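A minimal sketch of this protocol (the helper threeWaySplit and the 60/20/20 proportions are assumptions, not part of the original code):

def threeWaySplit(key, X, y, parts=(0.6, 0.2)):
    # shuffle the indices, then cut into 60% train / 20% validation / 20% test
    n = X.shape[0]
    p = jax.random.permutation(key, n)
    n_train = int(parts[0] * n); n_val = int(parts[1] * n)
    tr, va, te = p[:n_train], p[n_train:n_train + n_val], p[n_train + n_val:]
    return (X[tr], y[tr]), (X[va], y[va]), (X[te], y[te])

(X_tr, y_tr), (X_va, y_va), (X_te, y_te) = threeWaySplit(jax.random.PRNGKey(0), X, y)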
key = jax.random.PRNGKey(33)
Xv = jax.random.uniform(key, (50, 2))
yv = gt(Xv)
lv = []; lt = []; lr = []
for k in range(1,30):
    nn = KNearestNeighbor(X, y, k)
    lv.append(loss(nn(Xv), yv))
    lt.append(loss(nn(Xt), yt))
    lr.append(loss(nn(X), y))
plt.plot(lv, '-k'), plt.plot(lt, '-r'), plt.plot(lr, '-g')
nn = KNearestNeighbor(X, y, k=17)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
Knowing that $f$ has an expected error of $\epsilon$, what is the probability that $f$ has an empirical error of $\eta$ or less on a dataset of size $n$?
The number of errors follows a binomial distribution, so the probability of observing an error rate of at most $\eta$ is
$$ \sum_{k=0}^{\lfloor \eta n \rfloor}{n \choose k} \epsilon^k(1-\epsilon)^{n-k}$$
def Pn_of_eta_given_eps(n, eta, eps):
    # probability of observing at most eta*n errors out of n draws with error rate eps
    p = 0
    for k in range(int(eta*n) + 1):
        p += scipy.special.comb(n, k) * eps**k * (1-eps)**(n-k)
    return p
x = range(100, 1100, 100)
p = [Pn_of_eta_given_eps(i, 0.01, 0.02) for i in x]
plt.loglog(x, p)
x = 0.1*jnp.arange(0, 6, 1)
p = [Pn_of_eta_given_eps(100, 0.1, i) for i in x]
plt.semilogy(x, p)
x = 0.01*jnp.arange(0, 21, 1)
p = [Pn_of_eta_given_eps(100, i, 0.2) for i in x]
plt.semilogy(x, p)
Split the data into several training-validation sets and average the error
Random split: perform $r$ random splits into $x\%$ training / $(100-x)\%$ validation (typically 80/20)
K-fold: split into $k$ subsets and rotate: each subset serves once as the validation set while the remaining $k-1$ are used for training (see the sketch after this list)
Select the model with the lowest average validation error and evaluate it on the test set
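A sketch of the K-fold splits (the helper kFoldSplits is hypothetical, not part of the original code):

def kFoldSplits(key, X, y, n_folds=5):
    # shuffle once, cut the indices into n_folds chunks, and use each chunk
    # in turn as the validation set, the remaining chunks as the training set
    p = jax.random.permutation(key, X.shape[0])
    folds = jnp.array_split(p, n_folds)
    for i in range(n_folds):
        val = folds[i]
        train = jnp.concatenate([folds[j] for j in range(n_folds) if j != i])
        yield X[train], y[train], X[val], y[val]

# e.g. average the validation loss of a k-NN over the folds:
# losses = [loss(KNearestNeighbor(Xtr, ytr, k=3)(Xva), yva)
#           for Xtr, ytr, Xva, yva in kFoldSplits(jax.random.PRNGKey(0), X, y)]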
key = jax.random.PRNGKey(4) # chosen by a fair dice roll
X = jax.random.uniform(key, (100, 2))
y = gt(X)
def randomSplit(key, X, y, train_part=0.8):
    n = X.shape[0]
    n_train = int(train_part*n); n_test = n - n_train
    p = jax.random.permutation(key, n)
    X_train = X[p[0:n_train], :]; y_train = y[p[0:n_train]]
    X_val = X[p[n_train:],:] ; y_val = y[p[n_train:]]
    return X_train, y_train, X_val, y_val
key = jax.random.PRNGKey(32)
l = []
for k in range(1, 30):
    lk = []
    for s in range(10):
        key, skey = jax.random.split(key)
        X_train, y_train, X_val, y_val = randomSplit(skey, X, y)
        nn = KNearestNeighbor(X_train, y_train, k=k)
        lk.append(loss(nn(X_val), y_val))
    l.append(lk)
l = jnp.asarray(l)
plt.errorbar(range(1,30), l.mean(axis=1), l.std(axis=1), fmt='-k')
Once hyperparameters are selected, train on full training set, eval on test
nn = KNearestNeighbor(X, y, k=4)
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = nn(xx).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
Solving problem #2, finding a good $f$ is hard: the 0-1 loss is piecewise constant, so it cannot be minimized by gradient descent; instead we minimize a smooth surrogate loss. Common surrogates, plotted below as a function of the score $t$ for a positive example:
t = jnp.arange(-1.5, 2, 0.01)
plt.plot(t, 1-(jnp.sign(t)==1), '-k')     # 0-1 loss
plt.plot(t, jnp.maximum(0, 1 - t), '-r')  # hinge loss
plt.plot(t, jnp.log(1+jnp.exp(-t)), '-g') # logistic loss
plt.plot(t, jnp.exp(-t), '-b')            # exponential loss
Ellipse classifier with parameters $c_1, c_2, a, b$: $$ f(X) = 1 - \left(a(X_1 - c_1)^2 + b(X_2 - c_2)^2\right)$$ In matrix form, with $A = \operatorname{diag}(a, b)$ and $C = (c_1, c_2)^T$: $$f(X) = 1 - (X-C)^T A (X-C)$$
Using the MSE: $$\min_{A, C} \sum_x \left(y - 1 + (x-C)^T A (x-C)\right)^2$$
def mse(y_hat, y):
    return ((y-y_hat)**2).mean()

def circle(x, a, c):
    xc = x - c[None, :] # broadcast to n x 2
    return 1 - (a*xc**2).sum(1) # sum on axis=1

def loss(a, c, x, y):
    y_hat = circle(x, a, c)
    return mse(y_hat, y)

@jax.jit
def update(a, c, x, y):
    da, dc = jax.grad(loss, argnums=(0,1))(a, c, x, y)
    return a - 0.1 * da, c - 0.1 * dc
key = jax.random.PRNGKey(32)
key, skey = jax.random.split(key)
c = jax.random.uniform(key, (2,))
a = jnp.ones(2)
l = []
for t in range(5000):
    a, c = update(a, c, X, y)
    l.append(loss(a, c, X, y))
plt.plot(l, '-k')
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = jnp.sign(circle(xx, a, c)-0.5).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
Rectangle $\rightarrow$ switch from $\ell_2$ to $\ell_\infty$ norm
$$ f(X) = 1 - \max\left(a(X_1 - c_1)^2,\, b(X_2 - c_2)^2\right)$$

def square(x, a, c):
    xc = x - c[None, :] # broadcast to n x 2
    return 1 - (a*xc**2).max(1) # inf norm

def loss(a, c, x, y):
    y_hat = square(x, a, c)
    return mse(y_hat, y)

@jax.jit
def update(a, c, x, y):
    da, dc = jax.grad(loss, argnums=(0,1))(a, c, x, y)
    return a - 0.1 * da, c - 0.1 * dc
key = jax.random.PRNGKey(32)
key, skey = jax.random.split(key)
c = jax.random.uniform(key, (2,))
a = jnp.ones(2)
l = []
for t in range(5000):
    a, c = update(a, c, X, y)
    l.append(loss(a, c, X, y))
plt.plot(l, '-k')
t = 50
tx = jnp.linspace(0, 1, t); ty = jnp.linspace(0, 1, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels=jnp.linspace(-1.5, 1.5, 10)
y_pred = jnp.sign(square(xx, a, c)-0.5).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels)
plt.scatter(X[:,0], X[:,1], c=y), plt.scatter(Xt[:,0], Xt[:,1], marker='v', c=yt)
def RandomSplitCV(key, X, y, cls_func, max_steps=10000):
    # get a random 80% split of X,y
    # optimize a,c using cls_func for max_steps
    # keep track of training loss and validation loss
    return l_train, l_val
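One possible way to fill in this skeleton, reusing randomSplit, mse, and the step size and initialisation from the earlier cells (a sketch of the intent described by the comments, not necessarily the intended solution):

def RandomSplitCV(key, X, y, cls_func, max_steps=10000):
    # random 80% train / 20% validation split
    key, skey = jax.random.split(key)
    X_train, y_train, X_val, y_val = randomSplit(skey, X, y)
    def loss_fn(a, c, x, yy):
        return mse(cls_func(x, a, c), yy)
    @jax.jit
    def update_fn(a, c, x, yy):
        da, dc = jax.grad(loss_fn, argnums=(0, 1))(a, c, x, yy)
        return a - 0.1 * da, c - 0.1 * dc
    # same initialisation as above: random centre, unit scale
    c = jax.random.uniform(key, (2,))
    a = jnp.ones(2)
    l_train, l_val = [], []
    for step in range(max_steps):
        a, c = update_fn(a, c, X_train, y_train)
        l_train.append(loss_fn(a, c, X_train, y_train))
        l_val.append(loss_fn(a, c, X_val, y_val))
    return jnp.asarray(l_train), jnp.asarray(l_val)

With such an implementation, stacking the outputs of 10 runs along axis=1 would give the (max_steps, 10) arrays l_train and l_val that the plotting cell below expects.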
# perform 10 random split CV
key = jax.random.PRNGKey(67)
# l_train = jax.random.uniform(key, (10000, 10)); l_val = 0.2*l_train
# plot train and val loss
x = jnp.arange(10000)
l_mean = l_train.mean(axis=1); l_std = l_train.std(1)
plt.plot(x, l_mean, '-b'); plt.fill_between(x, l_mean - l_std, l_mean + l_std, color='b', alpha=0.5)
l_mean = l_val.mean(axis=1); l_std = l_val.std(1)
plt.plot(x, l_mean, '-r'); plt.fill_between(x, l_mean - l_std, l_mean + l_std, color='r', alpha=0.5)
Supervised vs Unsupervised
Online vs Batch: