Machine Learning and Applications - SVM and kernel machines¶

David Picard¶

École des Ponts ParisTech¶

david.picard@enpc.fr¶

Binary Linear Classification¶

  • Input $\mathbf{x} \in \mathbb{R}^d$
  • Output $y \in \{-1; 1\}$

Linear prediction function $$f(\mathbf{x}) = \text{sign}(\langle \mathbf{w}, \mathbf{x}\rangle +b) $$

Defines a hyperplane

  • Normal vector $\mathbf{w}$
  • Bias (offset) $b$
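
For example, with the values used in the next cell ($\mathbf{w} = (1, 1)$, $b = -0.5$), a minimal sketch of the prediction for two arbitrary points (not taken from the lecture):

In [ ]:
import jax.numpy as jnp

w = jnp.ones(2)           # normal vector
b = -0.5                  # bias (offset)

x1 = jnp.array([0., 0.])  # <w, x1> + b = -0.5, negative side of the hyperplane
x2 = jnp.array([1., 1.])  # <w, x2> + b =  1.5, positive side
print(jnp.sign(jnp.dot(w, x1) + b), jnp.sign(jnp.dot(w, x2) + b))  # -1.0 1.0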
In [3]:
w = jnp.ones(2)   # normal vector of the hyperplane
b = -0.5          # bias (offset)

# evaluate the classifier on a regular grid covering [-1, 2] x [-1, 2]
t = 40; tx = jnp.linspace(-1, 2, t); ty = jnp.linspace(-1, 2, t)
xv, yv = jnp.meshgrid(tx, ty, sparse=True); xv = xv.squeeze(); yv = yv.squeeze()
xx = jnp.array([[xx, yy] for yy in yv for xx in xv])
levels = jnp.linspace(-1.5, 1.5, 10)
# predicted class (0 or 1) at each grid point, reshaped back onto the grid
y_pred = (1.*(jnp.matmul(xx, w)+b > 0)).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=levels);

ERM¶

Hinge loss:

$$ l(y, f(\mathbf{x})) = \max(0, 1 - yf(\mathbf{x})) $$
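
Only examples with margin $y f(\mathbf{x}) < 1$ contribute to the loss. A quick sketch (not part of the original notebook) plotting the loss as a function of the margin:

In [ ]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

m = jnp.linspace(-2, 3, 100)            # margin y * f(x)
plt.plot(m, jnp.maximum(0., 1. - m))    # linear penalty below 1, zero above
plt.xlabel('y f(x)'); plt.ylabel('hinge loss');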

Given training set $\mathcal{A} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{n}$, minimize the empirical risk: $$\min_{\mathbf{w}, b} \frac{1}{n}\sum_i \max(0, 1 - y_i (\langle\mathbf{w}, \mathbf{x}_i\rangle + b)) $$

Convex problem (a sum of convex terms), so it is easily optimized by gradient descent

For large training sets, stochastic gradient descent works great
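
A minimal sketch of such a stochastic variant, on hypothetical toy data (batch size, step size and number of steps are arbitrary choices, not the lecture's settings):

In [ ]:
import numpy as np
import jax
import jax.numpy as jnp

def hinge(w, b, x, y):
    # average hinge loss of the linear classifier on a batch
    return jax.nn.relu(1 - y * (jnp.matmul(x, w) + b)).mean()

@jax.jit
def sgd_update(w, b, x, y):
    # one gradient step on a random mini-batch, step size 0.01
    dw, db = jax.grad(hinge, argnums=(0, 1))(w, b, x, y)
    return w - 0.01*dw, b - 0.01*db

# toy linearly separable data, only to make the sketch runnable
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = np.sign(X.sum(axis=1) + 0.1)

w, b = jnp.zeros(2), 0.
for step in range(200):
    idx = rng.integers(0, X.shape[0], size=32)   # sample a mini-batch
    w, b = sgd_update(w, b, X[idx], y[idx])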

MNIST¶

In [9]:
X = data['X_train_bin']
y = data['y_train_bin']*2 - 1      # map labels from {0, 1} to {-1, 1}

def func(w, b, x):
    # linear score <w, x> + b
    return jnp.matmul(x, w) + b

def hinge(w, b, x, y):
    # average hinge loss over the batch
    return jax.nn.relu(1 - y * func(w, b, x)).mean()

@jax.jit
def update(w, b, x, y):
    # one gradient descent step on the hinge loss, step size 0.01
    dw, db = jax.grad(hinge, argnums=(0,1))(w, b, x, y)
    return w - 0.01*dw, b - 0.01*db
In [10]:
w = np.random.randn(784)   # random initialization, one weight per pixel
b = 0.

# full-batch gradient descent on the hinge loss
loss = []
for t in range(500):
    loss.append(hinge(w, b, X, y))
    w, b = update(w, b, X, y)
plt.plot(loss)
Out[10]:
In [11]:
def accuracy(y_pred, y_true):
    # mean of sign(y_true * y_pred): equals 1.0 when every prediction has the correct sign
    return jnp.sign(y_true*y_pred).mean()

y_pred = func(w, b, X)
print('accuracy: {}'.format(accuracy(y_pred, y)))
accuracy: 1.0
In [12]:
X_val = data['X_val_bin']
y_val = data['y_val_bin']*2-1

y_pred = func(w, b, X_val)
print('validation accuracy: {}'.format(accuracy(y_pred, y_val)))
validation accuracy: 0.9047619104385376

Equivalent solutions¶

Many different hyperplanes separate the same training points with zero hinge loss: the empirical risk alone does not single out one solution.
In [13]:
# five slightly perturbed separating lines for the same two points
w = jnp.array([-1, 1]) + 0.1*np.random.randn(5, 2)

plt.scatter([0, 1], [0, 1], c=[0, 1])              # the two training points
for i in range(5):
    plt.plot([0, 1], [w[i,1], w[i,0]+w[i,1]])      # line y = w[i,0]*x + w[i,1]

Complexity impacts generalization¶

In [16]:
Image('complexity.pdf.png', width=400)
Out[16]:

Structural Risk Minimization¶

The Structural Risk Minimization principle defines a trade-off between the quality of the approximation of the given data and the complexity of the approximating function (Vladimir N. Vapnik)
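
For the linear classifier above, a standard way to implement this trade-off is to add a penalty on $\|\mathbf{w}\|^2$ to the empirical risk, $$\min_{\mathbf{w}, b}\ \lambda \|\mathbf{w}\|^2 + \frac{1}{n}\sum_i \max(0, 1 - y_i (\langle\mathbf{w}, \mathbf{x}_i\rangle + b)),$$ the soft-margin linear SVM objective. A minimal sketch of the corresponding loss (lam is a hypothetical hyperparameter, not fixed by the slides):

In [ ]:
import jax
import jax.numpy as jnp

def regularized_hinge(w, b, x, y, lam):
    # data fit (average hinge loss) + lam * complexity penalty on w
    fit = jax.nn.relu(1 - y * (jnp.matmul(x, w) + b)).mean()
    return fit + lam * jnp.dot(w, w)

Its gradient can be plugged into the same update loop as in the MNIST cells above.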

In [18]:
Image('vapnik.jpg', width=400)
Out[18]: