I've recently experimented with Jax, Google's new deep learning framework. Overall, I have a good impression of it, with one exception: data loading. Both the Jax tutorial and Google's Vision Transformer1 use tensorflow datasets for data loading. To be honest, I'm not a big fan of this approach.

Fortunately, Pytorch has great data loaders! And they are not too hard to modify for Jax. I'll borrow albumentation docs' code for downloading the Cats & Dogs dataset for illustration.

from collections import defaultdict
import copy
import random
import os
import shutil
from urllib.request import urlretrieve

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

dataset_directory = "/mnt/data/cats-vs-dogs"

class TqdmUpTo(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
   = tsize
        self.update(b * bsize - self.n)

def download_url(url, filepath):
    directory = os.path.dirname(os.path.abspath(filepath))
    os.makedirs(directory, exist_ok=True)
    if os.path.exists(filepath):

    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:
        urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None) = t.n

def extract_archive(filepath):
    extract_dir = os.path.dirname(os.path.abspath(filepath))
    shutil.unpack_archive(filepath, extract_dir)
def display_image_grid(images_filepaths, predicted_labels=(), cols=5):
    rows = len(images_filepaths) // cols
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))
    for i, image_filepath in enumerate(images_filepaths):
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
        predicted_label = predicted_labels[i] if predicted_labels else true_label
        color = "green" if true_label == predicted_label else "red"
        ax.ravel()[i].set_title(predicted_label, color=color)

filepath = os.path.join(dataset_directory, "")

root_directory = os.path.join(dataset_directory, "PetImages")

cat_directory = os.path.join(root_directory, "Cat")
dog_directory = os.path.join(root_directory, "Dog")

cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])
dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])
images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]

train_images_filepaths = [i for i in images_filepaths[:200] if cv2.imread(i) is not None]

print(f"Size of (example) training set: {len(train_images_filepaths)}")

print(f"Example file paths: ")
for i in range(5): print(train_images_filepaths[i])
Size of (example) training set: 199
Example file paths: 

The images have varying shapes; for this example I'm using opencv to load them (PIL would work just as well).


Creating a Dataset

To create a standard Pytorch dataset, we subclass it and implement a __length__ and __getitem__ method:

class Dataset():
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

We're going to use the example dataset from the albumentation docs:

class CatsVsDogsDataset(Dataset):
    def __init__(self, images_filepaths, transform=None):
        self.images_filepaths = images_filepaths
        self.transform = transform

    def __len__(self):
        return len(self.images_filepaths)

    def __getitem__(self, idx):
        image_filepath = self.images_filepaths[idx]
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if os.path.normpath(image_filepath).split(os.sep)[-2] == "Cat":
            label = 1.0
            label = 0.0
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, label

So far, there is nothing new. Our dataset is an object we can index into, and it returns numpy array along with binary label.

dataset = CatsVsDogsDataset(train_images_filepaths)

dataset[3][0].shape, dataset[0][1]
((424, 500, 3), 0.0)

Adding transforms

Since we're going along the albumentations tutorial, adding transforms comes out of the box. The only change to be aware of: don't add a ToTensor conversion as we won't use pytorch!

train_transform = A.Compose(
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
        A.RandomCrop(height=128, width=128),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
dataset = CatsVsDogsDataset(train_images_filepaths, transform=train_transform)

def visualize_augmentations(dataset, idx=0, samples=5, cols=5):
    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])
    rows = samples // cols
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))
    for i in range(samples):
        image, _ = dataset[idx]


Of course, we can use any other data augmentation library, or even build our own from scratch. Doing the latter might be worthwhile for two reasons:

  1. since Jax uses a very-numpy-like API, implementing (batched) transformations on the GPU shouldn't be too hard.
  2. albumentations uses a different random seed generator, for reproducibility it's best to use the one from Jax.

Anyway, for this example, using Albumentations will suffice.


The sampler is the first modification. We want to rely on jax's random number generator instead of the one built into Pytorch, making random number generation very explicit.

To implement a sampler, we just have to create an iterator over indices. Sampling sequentially works for a validation set, but for training, we want to shuffle the data.

import jax
import jax.numpy as jnp

from import Sampler, SequentialSampler

class RandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key
    def __len__(self):
        return len(self.data_source)
    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())


Next, we have to modify the default collation method (so we don't get tensors). We'll use classic numpy instead of Jax, as we want to do these operations on the CPU (Jax GPU is not going to work with num_workers > 0), and have to put the whole batch to the GPU after it's assembled.

import numpy as np
from import DataLoader

def stack_collate(batch):
    return [np.stack(x) for x in zip(*batch)]

class JaxDataLoader(DataLoader):
    def __init__(self, dataset, rng_key, batch_size=1, shuffle=False,
        if shuffle:
            sampler = RandomSampler(dataset, rng_key)
            sampler = SequentialSampler(dataset)
        super().__init__(dataset, batch_size, sampler=sampler,
                         collate_fn=collate_fn, **kwargs)

That's it! The main difference to pytorch is that we have to supply the random number key explicitly (which actually has quite a few benefits).

rng = jax.random.PRNGKey(0)
dl = JaxDataLoader(dataset, rng, 8, shuffle=True, collate_fn=stack_collate, num_workers=4)

We can inspect the first element of a batch the usual way:

xb, yb = next(iter(dl))

type(xb), xb.shape, type(yb), yb.shape
(numpy.ndarray, (8, 128, 128, 3), numpy.ndarray, (8,))

Plotting a batch is easy (it's just numpy!) since the spatial and channel axes are not reshuffled:

fig, axes = plt.subplots(1, 8, figsize = (15, 5))
for i, ax in enumerate(axes.flat):
    ax.set_title("Cat" if yb[i] == 1.0 else "Dog")

We can easily put the numpy batch on the GPU:

GpuDevice(id=0, process_index=0)

To speed up data-loading, I used 4 worker threads:

for xb, yb in dl:
150 ms ± 889 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The data loading speed is on-par with pytorch:

import torch
from import DataLoader

torch_dl = DataLoader(dataset, 8, shuffle=True, num_workers=4)
for xb, yb in torch_dl:
195 ms ± 2.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


1. The official ViT implementation seems to be in Jax. I like that Google is moving away from tensorflow