Machine Learning Lab: A Simple Neural Network on MNIST with JAX#
In this lab, we will:
Download and load MNIST into NumPy arrays (if the files do not already exist locally).
Build a simple multi-layer perceptron (MLP) in JAX.
Train the network on MNIST.
Evaluate its performance on the test data.
Provide a custom inference function for your own handwriting images.
This lab is based on a JAX example. Note that MNIST is the “hello world” of machine learning, and there are many examples available online, including simpler ones that use higher-level libraries: JAX with pre-built optimizers, Flax, PyTorch, and Keras.
MNIST Data Loader#
We start by downloading the MNIST data set and storing it locally. Our data loader then parses, reshapes, and normalizes the images and labels, and returns them as NumPy arrays.
from os.path import isfile
from urllib.request import urlretrieve
base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"
# File names
files = {
"train_images": "train-images-idx3-ubyte.gz",
"train_labels": "train-labels-idx1-ubyte.gz",
"test_images": "t10k-images-idx3-ubyte.gz",
"test_labels": "t10k-labels-idx1-ubyte.gz",
}
for key, file in files.items():
if not isfile(file):
url = base_url + file
print(f"Downloading {url} to {file}...")
urlretrieve(url, file)
else:
print(f"{file} exists; skip download")
Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to train-images-idx3-ubyte.gz...
Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to train-labels-idx1-ubyte.gz...
Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to t10k-images-idx3-ubyte.gz...
Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to t10k-labels-idx1-ubyte.gz...
import gzip
import struct
import array
from jax import numpy as np
# Parsing functions
def parse_labels(file):
with gzip.open(file, "rb") as fh:
_magic, num_data = struct.unpack(">II", fh.read(8))
# Read the label data as 1-byte unsigned integers
return np.array(array.array("B", fh.read()), dtype=np.uint8)
def parse_images(file):
with gzip.open(file, "rb") as fh:
_magic, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
# Read the image data as 1-byte unsigned integers
images = np.array(array.array("B", fh.read()), dtype=np.uint8)
# Reshape to (num_data, 28, 28)
images = images.reshape(num_data, rows, cols)
return images
# Parse raw data
train_images_raw = parse_images(files["train_images"])
train_labels_raw = parse_labels(files["train_labels"])
test_images_raw = parse_images(files["test_images"])
test_labels_raw = parse_labels(files["test_labels"])
# Standardize the images, i.e., flatten and normalize images to [0, 1]
def standardize(images):
return images.reshape(-1, 28*28).astype(np.float32) / 255
train_images = standardize(train_images_raw)
test_images = standardize(test_images_raw)
# One-hot encode labels
def one_hot(labels, num_classes=10):
return np.eye(num_classes)[labels]
train_labels = one_hot(train_labels_raw, 10).astype(np.float32)
test_labels = one_hot(test_labels_raw, 10).astype(np.float32)
Visualize Some Training and Testing Data#
from matplotlib import pyplot as plt
plt.imshow(train_images_raw[0,:,:], cmap='gray')
<matplotlib.image.AxesImage at 0x7f59d868adb0>

plt.imshow(test_images_raw[0,:,:], cmap='gray')
<matplotlib.image.AxesImage at 0x7f59d81941a0>

Define a Simple Neural Network in JAX#
In this subsection, we introduce the core function needed to initialize the parameters of a multi-layer network.
Our network will have multiple layers, each characterized by a weight matrix W and a bias vector b.
We will use random initialization scaled by a small factor to ensure stable starting values for training.
import numpy.random as npr
def init_params(scale, layer_sizes, rng=npr.RandomState(0)):
"""
Initialize the parameters (weights and biases) for each layer in the network.
Parameters
----------
scale : float
A scaling factor to control the initial range of the weights.
layer_sizes : list of int
The sizes of each layer in the network.
e.g., [784, 1024, 1024, 10] means:
- Input layer: 784 units
- Hidden layer 1: 1024 units
- Hidden layer 2: 1024 units
- Output layer: 10 units
rng : numpy.random.RandomState
Random state for reproducibility.
Returns
-------
params : list of tuples (W, b)
Each tuple contains (W, b) for a layer.
- W is a (input_dim, output_dim) array of weights
- b is a (output_dim,) array of biases
"""
return [
(scale * rng.randn(m, n), scale * rng.randn(n))
for m, n in zip(layer_sizes[:-1], layer_sizes[1:])
]
How it works:
We specify a list of layer sizes, for example [784, 1024, 1024, 10].
For each pair of consecutive sizes (m, n), we create a weight matrix W of shape (m, n) and a bias vector b of shape (n,).
Multiplying by scale ensures that initial values are not too large, which helps prevent numerical issues early in training.
We store all (W, b) pairs in a list, one pair per layer, to be used throughout training and inference.
By calling init_params(scale, layer_sizes), you obtain an easy-to-manipulate structure that keeps all the parameters needed for your neural network. A quick shape check is shown below.
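As a quick sanity check (our own illustration, using a made-up tiny architecture rather than the one used in this lab), we can print the shapes that init_params returns:
tiny_params = init_params(0.1, [4, 3, 2])  # toy layer sizes, for illustration only
for i, (W, b) in enumerate(tiny_params):
    print(f"Layer {i}: W shape {W.shape}, b shape {b.shape}")
# Layer 0: W shape (4, 3), b shape (3,)
# Layer 1: W shape (3, 2), b shape (2,)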
# Define network architecture and hyperparameters
layer_sizes = [784, 1024, 1024, 10] # 2 hidden layers
param_scale = 0.1
# Initialize parameters
params = init_params(param_scale, layer_sizes)
Forward Pass: The predict Function#
Once the network parameters are initialized, we need a function to perform the forward pass, producing an output for each batch of inputs.
Below, we define predict to process data through multiple layers, applying a tanh activation on the hidden layers and computing a log-softmax on the final output layer for numerical stability.
from jax import numpy as np
from jax.scipy.special import logsumexp
def predict(params, inputs):
"""
Compute the network's output logits for a batch of inputs, then subtract
log-sum-exp for numerical stability (log-softmax).
Network architecture:
- Hidden layers use tanh activation
- Output layer is linear (we'll do log-softmax here)
Parameters
----------
params : list of (W, b) tuples
Network's parameters for each layer.
inputs : np.ndarray
A batch of input data of shape (batch_size, input_dim).
Returns
-------
np.ndarray
Log probabilities of shape (batch_size, 10).
"""
activations = inputs
# Hidden layers
for w, b in params[:-1]:
outputs = np.dot(activations, w) + b
activations = np.tanh(outputs)
# Final layer (logits)
final_w, final_b = params[-1]
logits = np.dot(activations, final_w) + final_b
# Log-Softmax: subtract logsumexp for numerical stability
return logits - logsumexp(logits, axis=1, keepdims=True)
Hidden layers (tanh): Each hidden layer applies a linear transformation (np.dot(activations, w) + b) followed by the hyperbolic tangent activation function (np.tanh).
Final layer (logits): The last layer’s output is not passed through tanh; instead, we use it directly as logits.
Log-Softmax: We transform logits to log probabilities by subtracting logsumexp(logits) along the class dimension. This step ensures numerical stability and can be directly used to compute losses like cross-entropy. A small numerical illustration follows below.
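To make the log-softmax step concrete, here is a small sketch with fabricated logits (not produced by the network); exponentiating the resulting log probabilities recovers probabilities that sum to 1:
toy_logits = np.array([[2.0, 0.5, -1.0]])  # made-up logits for one example
toy_log_probs = toy_logits - logsumexp(toy_logits, axis=1, keepdims=True)
print(np.exp(toy_log_probs))        # probabilities, roughly [[0.79 0.18 0.04]]
print(np.exp(toy_log_probs).sum())  # 1.0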
Defining the Loss Function#
To guide training, we need a loss function that measures how well our network’s predictions match the true labels; it plays the same role as \(\chi^2\) when we fit a curve to data. Below, we define a function that computes the negative log-likelihood (NLL) over a batch of data.
def loss(params, batch):
"""
Computes the average negative log-likelihood loss for a batch.
Parameters
----------
params : list of (W, b) tuples
The network's parameters.
batch : tuple (inputs, targets)
- inputs: np.ndarray of shape (batch_size, 784)
- targets: np.ndarray of shape (batch_size, 10) (one-hot labels)
Returns
-------
float
Mean negative log-likelihood over the batch.
"""
inputs, targets = batch
preds = predict(params, inputs)
# preds are log-probs, multiply with one-hot targets and sum -> log-likelihood
return -np.mean(np.sum(preds * targets, axis=1))
Inputs and Targets: A single batch typically consists of a set of input vectors (inputs) and corresponding one-hot encoded labels (targets).
Forward Pass: We call predict(params, inputs), which returns the log probabilities for each class.
NLL Computation: We multiply the log probabilities by the one-hot labels (so we only pick out the log probability of the correct class for each example). Summing these values (log-likelihood) and then negating yields the negative log-likelihood.
Mean Value: We take the average across the batch, yielding a scalar loss.
This loss metric drives parameter updates: minimizing it pushes the network to assign higher probabilities to the correct classes.
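As a sketch with hand-picked numbers (not network output), multiplying log probabilities by a one-hot target simply selects the log probability of the correct class:
toy_log_probs = np.log(np.array([[0.7, 0.2, 0.1]]))  # pretend log probabilities for one example
toy_target = np.array([[1.0, 0.0, 0.0]])             # the correct class is class 0
toy_nll = -np.mean(np.sum(toy_log_probs * toy_target, axis=1))
print(toy_nll)  # -log(0.7), approximately 0.357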
Evaluating Model Performance#
While the network is trained by minimizing the negative log-likelihood (NLL), we often monitor accuracy to get an intuitive sense of model performance. The function below calculates the fraction of samples in a batch that are correctly classified.
def accuracy(params, batch):
"""
Computes classification accuracy of the network on a given batch.
Parameters
----------
params : list of (W, b) tuples
The network's parameters.
batch : tuple (inputs, targets)
- inputs: np.ndarray (batch_size, 784)
- targets: np.ndarray (batch_size, 10) (one-hot labels)
Returns
-------
float
Fraction of correctly classified samples.
"""
inputs, targets = batch
target_class = np.argmax(targets, axis=1) # ground truth index
predicted_class = np.argmax(predict(params, inputs), axis=1)
return np.mean(predicted_class == target_class)
Predicted Class:
We use predict(params, inputs) to get log probabilities.
Taking the argmax across the class dimension finds the class with the highest log probability.
Compare to Ground Truth:
We similarly get the ground truth label indices from the one-hot targets by using np.argmax(targets, axis=1).
Accuracy Computation:
We compute the fraction of instances where the predicted class matches the ground-truth class.
This value ranges between 0 (no correct predictions) and 1 (perfect classification).
Monitoring accuracy alongside the loss offers a more intuitive measure of how well the model performs on a classification task.
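The sketch below, using two fabricated examples, shows the argmax comparison in isolation:
toy_preds = np.array([[-0.1, -2.3, -3.0],   # highest log probability -> class 0
                      [-2.0, -0.2, -1.5]])  # highest log probability -> class 1
toy_targets = np.array([[1.0, 0.0, 0.0],    # true class 0 (correct)
                        [0.0, 0.0, 1.0]])   # true class 2 (wrong)
print(np.mean(np.argmax(toy_preds, axis=1) == np.argmax(toy_targets, axis=1)))  # 0.5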
Gradient Descent for Training: JIT-Compiled Update Function#
To optimize our network, we can use Stochastic Gradient Descent (SGD), updating parameters in the direction that reduces the loss.
This is essentially the same algorithm we implemented in our optimization class, except that we implement only a single step for now.
We decorate our update step with @jit to compile it for efficient execution on CPU or GPU.
from jax import jit, grad
@jit
def update(params, batch, step_size):
"""
Single step of gradient-based parameter update using simple SGD.
grad(loss)(params, batch) computes the gradient of the loss function
with respect to the parameters for the given batch.
"""
grads = grad(loss)(params, batch)
return [
(w - step_size * dw, b - step_size * db)
for (w, b), (dw, db) in zip(params, grads)
]
Key Ideas:
grad(loss) automatically differentiates the loss function w.r.t. all parameters (params), yielding parameter gradients.
SGD Update:
For each weight w, we update it by w - step_size * dw.
Similarly for each bias b.
@jit decorator: Compiles the update step using XLA (Accelerated Linear Algebra).
Improves performance by running the update efficiently on CPU/GPU.
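If grad is unfamiliar, the following sketch (independent of our network) shows what it does: for f(x) = x**2, grad(f) returns a new function that evaluates the derivative 2x.
def f(x):
    return x ** 2

df = grad(f)     # df is a function computing the derivative of f
print(df(3.0))   # 6.0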
Preparing the Batching Mechanism#
batch_size = 128 # the number of samples per parameter update.
num_train = train_images.shape[0]
num_batches = num_train // batch_size
def get_batch(rng=npr.RandomState(0)):
"""
Generator function that yields shuffled batches indefinitely.
"""
while True:
# Randomly permute the indices
perm = rng.permutation(num_train)
for i in range(num_batches):
batch_idx = perm[i*batch_size : (i+1)*batch_size]
# Yield a tuple (inputs, labels) for this batch
yield (train_images[batch_idx], train_labels[batch_idx])
train_batch_generator = get_batch()
Shuffling: Each epoch, we shuffle the training indices (perm = rng.permutation(num_train)) to ensure that each mini-batch contains a random subset of the dataset.
Batch extraction: We slice the permuted indices into chunks of size batch_size. Each chunk defines which samples from train_images and train_labels go into the current batch.
By continuously yielding batches in an infinite while True: loop, we can keep calling next(train_batch_generator) without manually restarting the data pipeline each epoch. A quick shape check on one batch is shown below.
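As a quick check (not required for training), we can pull one batch from a separate generator and confirm the shapes it yields; we use a fresh RandomState so the training generator is left untouched:
check_generator = get_batch(npr.RandomState(1))   # separate generator, so training batches are unaffected
check_inputs, check_labels = next(check_generator)
print(check_inputs.shape, check_labels.shape)     # (128, 784) (128, 10)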
The Training Loop#
Now we can train our neural network by iterating over epochs and batches:
learning_rate = 0.001 # the step size for each parameter update.
num_epochs = 5 # the number of times we iterate over the entire training dataset.
from time import time
for epoch in range(num_epochs):
start_time = time()
# Go through the entire training set
for _ in range(num_batches):
batch_data = next(train_batch_generator)
params = update(params, batch_data, step_size=learning_rate)
epoch_time = time() - start_time
# Evaluate training and test accuracy
train_acc = accuracy(params, (train_images, train_labels))
test_acc = accuracy(params, (test_images, test_labels))
print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
print(f"Training accuracy: {train_acc:.4f}")
print(f"Test accuracy: {test_acc:.4f}")
Epoch 0 in 5.58 sec
Training accuracy: 0.7378
Test accuracy: 0.7515
Epoch 1 in 5.21 sec
Training accuracy: 0.8143
Test accuracy: 0.8279
Epoch 2 in 5.26 sec
Training accuracy: 0.8446
Test accuracy: 0.8569
Epoch 3 in 5.36 sec
Training accuracy: 0.8624
Test accuracy: 0.8707
Breaking this down:
Epoch loop: We run for epoch in range(num_epochs): to repeat the training process multiple times over the dataset.
Batch loop: For each epoch, we execute a loop for _ in range(num_batches): to process all training batches.
Parameter update: We call next(train_batch_generator) to obtain the next (inputs, labels) batch, then update the network parameters with params = update(params, batch_data, step_size=learning_rate). This performs a single SGD step, moving each parameter slightly toward reducing the loss.
Timing: We measure how long each epoch takes by recording the start time with time() and subtracting it from the end time.
Evaluation: After processing all batches for the epoch, we compute train_acc (accuracy on the entire training set) and test_acc (accuracy on the reserved test set).
Logging: We print out the epoch number, epoch duration, and both training and test accuracies. Monitoring test accuracy helps assess how well the model generalizes beyond the training data.
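Evaluating accuracy on all 60,000 training images in a single forward pass is convenient but memory-hungry. The sketch below (our own addition, not part of the original lab) computes the same number in smaller chunks; batched_accuracy and eval_batch are names we introduce here for illustration.
def batched_accuracy(params, images, labels, eval_batch=1000):
    # Count correct predictions chunk by chunk to keep memory bounded
    correct = 0
    for i in range(0, images.shape[0], eval_batch):
        preds = predict(params, images[i:i + eval_batch])
        correct += np.sum(np.argmax(preds, axis=1) == np.argmax(labels[i:i + eval_batch], axis=1))
    return correct / images.shape[0]

print(batched_accuracy(params, train_images, train_labels))  # should match accuracy(...) above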
Loading and Preprocessing a Custom Image#
To run inference on your own handwriting, we first need to load the image from disk and convert it into a format suitable for our trained network:
from PIL import Image
def load_image(file):
"""
Loads a grayscale image from `file`, resizes it to 28x28,
and converts it to a (784,) float32 array with values in [0, 1].
Parameters
----------
file : str
Path to the image file (e.g., a PNG or JPG).
Returns
-------
np.ndarray
An array of shape (1, 784) containing normalized pixel values
suitable as input to our trained model.
"""
# Convert the image to grayscale and resize to 28x28
img = Image.open(file).convert('L').resize((28, 28))
# Convert to a NumPy array and normalize pixel intensities to [0, 1]
arr = np.array(img).astype(np.float32) / 255.0
# Flatten the 28x28 image into a single 784-dimensional vector
arr = arr.flatten()
# Reshape to (1, 784) to match the model's expected input batch shape
return np.array([arr])
Predicting the Digit Class#
With a properly formatted image, we can classify it using our trained model’s parameters:
def predict_digit(params, file_path):
"""
Predict the digit class for a custom handwritten image.
Parameters
----------
params : list of (W, b) tuples
The trained network parameters.
file_path : str
Path to the custom image file.
Returns
-------
int
The predicted digit label (0 through 9).
"""
# Convert the image to a suitable NumPy array
arr_np = load_image(file_path)
# Use our 'predict' function to get log probabilities for each class
log_probs = predict(params, arr_np) # shape: (1, 10)
# Select the digit class with the highest log probability
return int(np.argmax(log_probs, axis=1)[0])
:::{exercise}
Capture or scan your handwritten digit and save it as “sample.png”.
Call the function:
label = predict_digit(params, "sample.png")
print(f"Predicted digit: {label}")
Inspect the result: see whether the predicted label matches the actual digit you wrote.
With these two functions, your MNIST-trained model can be used in real-world testing scenarios, allowing you to evaluate its performance on custom, hand-drawn images.
:::
label = predict_digit(params, "sample.png")
print(f"Predicted digit: {label}")