A linear prediction function
$$f(\mathbf{x}) = \text{sign}(\langle \mathbf{w}, \mathbf{x}\rangle + b)$$
defines a hyperplane: the decision boundary is the set $\{\mathbf{x} : \langle \mathbf{w}, \mathbf{x}\rangle + b = 0\}$, with $\mathbf{w}$ normal to it and $b$ setting its offset from the origin.
import jax.numpy as jnp
import matplotlib.pyplot as plt

# A toy linear classifier in 2D
w = jnp.ones(2)
b = -0.5

# Evaluate the classifier on a 40x40 grid over [-1, 2]^2
t = 40
tx = ty = jnp.linspace(-1, 2, t)
xv, yv = jnp.meshgrid(tx, ty)
xx = jnp.stack([xv.ravel(), yv.ravel()], axis=1)

# Predicted class (0 or 1) at each grid point, shown as filled contours
y_pred = (1. * (jnp.matmul(xx, w) + b > 0)).reshape(t, t)
plt.contourf(xv, yv, -y_pred, levels=jnp.linspace(-1.5, 1.5, 10));
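As a quick sanity check (the two test points below are arbitrary), this classifier places $(1, 1)$ on the positive side of the boundary $x_1 + x_2 = 0.5$ and $(0, 0)$ on the negative side:

pts = jnp.array([[1., 1.], [0., 0.]])
jnp.sign(jnp.matmul(pts, w) + b)   # -> [ 1., -1.]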
Hinge loss:
$$ l(y, f(\mathbf{x})) = \max(0, 1 - y f(\mathbf{x})) $$

Given a training set $\mathcal{A} = \{(\mathbf{x}_i, y_i)\}_{i=1}^{n}$, minimize the empirical risk:
$$ \min_{\mathbf{w}, b} \frac{1}{n}\sum_{i=1}^{n} \max(0, 1 - y_i (\langle\mathbf{w}, \mathbf{x}_i\rangle + b)) $$
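For intuition (the margin values below are arbitrary examples), take $y = +1$: a confident correct prediction $f(\mathbf{x}) = 2$ incurs no loss, a correct but low-margin prediction $f(\mathbf{x}) = 0.5$ incurs loss $0.5$, and a wrong prediction $f(\mathbf{x}) = -1$ incurs loss $2$:

margins = jnp.array([2.0, 0.5, -1.0])   # y * f(x) for the three cases above
jnp.maximum(0, 1 - margins)             # -> [0. , 0.5, 2. ]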
The problem is convex (a sum of convex functions of an affine map), so it is easy to optimize by gradient descent.
For large training sets, stochastic gradient descent works great (see the minibatch sketch after the training loop below).
import jax
import numpy as np

X = data['X_train_bin']
y = data['y_train_bin'] * 2 - 1   # map {0, 1} labels to {-1, +1}

def func(w, b, x):
    return jnp.matmul(x, w) + b

def hinge(w, b, x, y):
    # max(0, 1 - y f(x)), averaged over the batch
    return jax.nn.relu(1 - y * func(w, b, x)).mean()

@jax.jit
def update(w, b, x, y):
    # one full-batch gradient step with learning rate 0.01
    dw, db = jax.grad(hinge, argnums=(0, 1))(w, b, x, y)
    return w - 0.01 * dw, b - 0.01 * db

w = np.random.randn(784)   # random init; inputs have 784 features
b = 0.
loss = []
for t in range(500):
    loss.append(hinge(w, b, X, y))
    w, b = update(w, b, X, y)
plt.plot(loss)
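The loop above takes full-batch gradient steps. A stochastic variant reuses the same update on random minibatches; a minimal sketch, in which the batch size 128, the step count, and the names w_sgd, b_sgd are arbitrary choices:

rng = np.random.default_rng(0)
w_sgd, b_sgd = np.random.randn(784), 0.
for step in range(500):
    idx = rng.choice(X.shape[0], size=128, replace=False)   # sample a minibatch
    w_sgd, b_sgd = update(w_sgd, b_sgd, X[idx], y[idx])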
def accuracy(y_pred, y_true):
    # mean sign agreement: +1 per correct prediction, -1 per mistake
    return jnp.sign(y_true * y_pred).mean()

y_pred = func(w, b, X)
print('accuracy: {}'.format(accuracy(y_pred, y)))
accuracy: 1.0
X_val = data['X_val_bin']
y_val = data['y_val_bin'] * 2 - 1   # same {0, 1} -> {-1, +1} label mapping
y_pred = func(w, b, X_val)
print('validation accuracy: {}'.format(accuracy(y_pred, y_val)))
validation accuracy: 0.9047619104385376
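Training accuracy is perfect while validation accuracy is noticeably lower: the learned hyperplane fits the training set exactly but generalizes imperfectly. One reason is that many different hyperplanes separate the same training data equally well: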
# Five random (slope, intercept) perturbations of the line y = 1 - x,
# which separates the two points (0, 0) and (1, 1)
w = jnp.array([-1, 1]) + 0.1 * np.random.randn(5, 2)
plt.scatter([0, 1], [0, 1], c=[0, 1])
for i in range(5):
    plt.plot([0, 1], [w[i, 1], w[i, 0] + w[i, 1]])
from IPython.display import Image
Image('complexity.pdf.png', width=400)
"The Structural Risk Minimization principle defines a trade-off between the quality of the approximation of the given data and the complexity of the approximating function." (Vladimir N. Vapnik)
Image('vapnik.jpg', width=400)
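One concrete instance of this trade-off, sketched under the assumption that we keep the hinge objective above: add an L2 penalty on $\mathbf{w}$, trading training fit against the norm (complexity) of the hyperplane. The name regularized_hinge and the penalty weight 0.01 are arbitrary choices, not part of the code above.

def regularized_hinge(w, b, x, y, lam=0.01):
    # data-fit term (hinge) plus complexity penalty (squared norm of w)
    return hinge(w, b, x, y) + lam * jnp.dot(w, w)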