Goodbye NumPy

For years, NumPy has been the bedrock of numerical computing in Python. If you’ve done any data science or machine learning, you’ve used it. It’s the default, the standard, the tool you reach for without a second thought. And for good reason. It’s fast, efficient, and its API for array and matrix operations is second to none.

Diagram of NumPy array data structure
Array programming with NumPy | Nature

The secret to its speed is that the heavy lifting isn’t done in Python at all: the core operations are written in highly optimized C. A NumPy multi-dimensional array, or ndarray, stores raw values of a single data type in a contiguous block of memory, whereas a standard Python list stores pointers to separately allocated Python objects. This layout enables vectorized operations, which replace slow Python-level loops with tight compiled loops and make good use of your CPU’s cache.

Comparison of memory allocation between Python lists and NumPy arrays
Simple Comparison of Memory Allocation using Python Lists and NumPy, by the Author
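A quick way to see the difference in layout is to compare how the two containers hold the same integers. This is a small sketch using only `sys.getsizeof` and standard ndarray attributes (`nbytes`, `flags`, `strides`); the list figures depend on your Python build, so treat them as illustrative, and note that summing `getsizeof` over the elements overstates the list's footprint slightly because CPython shares cached small integers.

```python
import sys

import numpy as np

n = 1_000
python_list = list(range(n))
numpy_array = np.arange(n, dtype=np.int64)

# The list object itself holds only pointers; each int is a separate Python object
list_pointer_bytes = sys.getsizeof(python_list)
# Upper bound: CPython shares cached small ints, so this overcounts a little
list_object_bytes = sum(sys.getsizeof(v) for v in python_list)

# The ndarray holds raw 8-byte integers in one contiguous buffer
array_data_bytes = numpy_array.nbytes            # n * itemsize
is_contiguous = numpy_array.flags["C_CONTIGUOUS"]
stride = numpy_array.strides[0]                  # bytes between neighbouring elements

print(f"list: {list_pointer_bytes} bytes of pointers "
      f"+ {list_object_bytes} bytes of int objects")
print(f"ndarray: {array_data_bytes} bytes of data, "
      f"contiguous={is_contiguous}, stride={stride}")
```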

The performance difference isn’t subtle. A simple sum operation on a large list versus a NumPy array shows just how dramatic the speedup is.

import time

import numpy as np

# Create a large list and a NumPy array holding the same integers
size = 1_000_000
python_list = list(range(size))
numpy_array = np.arange(size)

# Time the Python list sum
start_time = time.time()
sum_list = sum(python_list)
list_time = time.time() - start_time
print(f"Time taken to sum all elements in Python list: {list_time:.6f} seconds")

# Time the NumPy array sum
start_time = time.time()
sum_array = np.sum(numpy_array)
array_time = time.time() - start_time
print(f"Time taken to sum all elements in NumPy array: {array_time:.6f} seconds")

The results speak for themselves: in this run, NumPy is roughly eight times faster, close to an order of magnitude.

Time taken to sum all elements in Python list: 0.011590 seconds
Time taken to sum all elements in NumPy array: 0.001382 seconds

But for all its power, NumPy has a fundamental limitation. It was built for a different era of computing.

The wall you hit with NumPy

NumPy runs only on the CPU. Its architecture was designed to make the most of your central processing unit. But the world of high-performance computing, especially in machine learning, has moved on. The real power today lies in parallel processing on specialized hardware like GPUs and TPUs.

This is where NumPy hits a wall. It has no native ability to run on these modern accelerators. While it’s incredibly efficient on a CPU, it simply can’t tap into the thousands of cores available on a modern GPU. For the massive datasets and complex algorithms that define modern AI, this isn’t just a minor inconvenience; it’s a hard bottleneck.

The demand for GPU-accelerated computing is exploding. It’s the engine driving the entire deep learning revolution. And in this new landscape, a new tool was needed.

Graph showing the growth of the GPU market
GPU Market, Graphic Processing Unit Market Size 2023–2032

JAX(imus)

JAX is the answer. It’s a library for high-performance numerical computing that takes the familiar, beloved API of NumPy and supercharges it with the ability to run on GPUs and TPUs. It’s designed from the ground up to leverage the massive parallelism of modern hardware.

A GPU isn’t just a faster CPU; it’s a completely different architecture. Where a CPU has a few powerful cores designed for sequential tasks, a GPU has thousands of smaller cores designed to run many tasks simultaneously. JAX is built to exploit this.

Architectural comparison of a CPU and a GPU
CUDA — Introduction to the GPU

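The best part? For a lot of code, you barely have to change anything. JAX ships jax.numpy, which mirrors most of the NumPy API, so porting often starts with changing import numpy as np to import jax.numpy as jnp. It is not a perfect drop-in replacement, though: JAX arrays are immutable, and random numbers are generated with explicit PRNG keys rather than a global seed. With a GPU-enabled build of JAX installed, the same code then runs on the GPU.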
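The sketch below shows both sides of that claim: the same array code running unchanged under NumPy and jax.numpy, and one place where the APIs diverge. The function `standardize` and its `xp` parameter are hypothetical names chosen for illustration; `jax.default_backend()` reports which platform JAX picked, and `.at[...].set(...)` is JAX's functional replacement for in-place item assignment.

```python
import numpy as np
import jax
import jax.numpy as jnp

def standardize(xp, v):
    # Identical code for NumPy and jax.numpy: `xp` is whichever module is passed in
    return (v - xp.mean(v)) / xp.std(v)

values = [1.0, 2.0, 3.0, 4.0]
np_result = standardize(np, np.asarray(values))
jnp_result = standardize(jnp, jnp.asarray(values))

# The platform backing JAX's default device, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()

# One difference: JAX arrays are immutable, so item assignment raises TypeError
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional equivalent returns a new array and leaves `x` untouched
y = x.at[0].set(1.0)
print(backend, np_result, jnp_result, y)
```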

Let’s look at a simple matrix multiplication benchmark on a GPU-enabled machine.

NumPy Code:

import numpy as np
import time

# Create two large matrices
a = np.random.rand(1000, 1000)
b = np.random.rand(1000, 1000)

# Measure the time taken by NumPy
start_time = time.time()
np.dot(a, b)
numpy_time = time.time() - start_time
print(f"Time taken by NumPy: {numpy_time} seconds")

Time taken by NumPy: 0.11988210678100586 seconds

JAX Code:

import time

import jax.numpy as jnp
import numpy as np
from jax import jit

# Create two large matrices on the device
# (jnp.array converts the float64 NumPy data to float32 by default)
a = jnp.array(np.random.rand(1000, 1000))
b = jnp.array(np.random.rand(1000, 1000))

# JIT compile the function for maximum speed
@jit
def jax_dot(a, b):
    return jnp.dot(a, b)

# Warm-up call: the first call triggers compilation, which should not be timed
jax_dot(a, b).block_until_ready()

# Measure the time taken by JAX
start_time = time.time()
# block_until_ready() ensures the computation is finished before stopping the timer
jax_dot(a, b).block_until_ready()
jax_time = time.time() - start_time
print(f"Time taken by JAX: {jax_time} seconds")

Time taken by JAX: 0.02443838119506836 seconds

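Again, the result is clear: in this run, JAX finishes the multiplication about five times faster by running it on the GPU. The comparison is not perfectly like-for-like, since JAX uses 32-bit floats by default while NumPy’s random matrices are 64-bit. But how does JAX actually achieve this speedup?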

The secret sauce: XLA and JIT

Two key technologies make JAX so powerful: XLA (Accelerated Linear Algebra) and JIT (Just-in-Time) compilation.

XLA is a domain-specific compiler for linear algebra that was developed at Google. Think of it as an optimization engine. It takes your high-level JAX code, analyzes the graph of mathematical operations, and fuses them into highly optimized machine code kernels tailored specifically for your hardware, whether it’s a CPU, GPU, or TPU. It performs clever optimizations like eliminating intermediate memory allocations and combining operations, which dramatically improves performance.

Diagram of the XLA compilation process
XLA Systematic Process, by the Author
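JAX’s ahead-of-time API lets you look at what is handed to XLA and what comes back. In this minimal sketch, `scaled_sine` is a hypothetical function with three elementwise operations; `jax.jit(...).lower(...)` produces the program XLA receives, and `.compile()` runs XLA’s optimization passes for the current backend. The exact text depends on your JAX version and hardware, so only its general shape is described in the comments.

```python
import jax
import jax.numpy as jnp

def scaled_sine(x):
    # Three elementwise operations that XLA may fuse into a single kernel
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.linspace(0.0, 1.0, 1024)

lowered = jax.jit(scaled_sine).lower(x)  # trace and lower to XLA's input program
compiled = lowered.compile()             # run XLA's optimizations for this backend

hlo_before = lowered.as_text()  # separate sine, multiply and add operations
hlo_after = compiled.as_text()  # optimized HLO; fused ops typically appear as a "fusion"
print(hlo_after)

# The compiled executable is callable and matches the plain Python function
result = compiled(x)
```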

JIT compilation is the mechanism that triggers XLA. By adding the @jit decorator to your Python function, you’re telling JAX: “Take this function, trace its operations, send it to XLA for optimization, and compile it down to machine code.”
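A common way to observe tracing is a Python side effect inside the function: it runs only while JAX traces, not when the cached compiled code executes. In this sketch, `f` and the `trace_count` counter are hypothetical; `jax.make_jaxpr` prints the traced program in JAX’s intermediate representation (a jaxpr), which is what gets lowered to XLA.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def f(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces f
    return jnp.tanh(x) @ x

jit_f = jax.jit(f)
x3 = jnp.ones((3, 3))
jit_f(x3)
jit_f(x3)                        # same shape and dtype: reuses the compiled code
after_same_shape = trace_count   # traced once
jit_f(jnp.ones((4, 4)))          # new shape: JAX traces and compiles again
after_new_shape = trace_count    # traced twice

# The traced program as a jaxpr (this traces f once more)
jaxpr = jax.make_jaxpr(f)(x3)
print(jaxpr)
```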

The first time you call the JIT-compiled function with a given combination of argument shapes and dtypes, JAX pays a one-time compilation cost. Every subsequent call with the same shapes and dtypes executes the cached, optimized code directly, which is where the large speedups come from.

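import time

import jax.numpy as jnp
import numpy as np
from jax import jit

def matmul(a, b):
    return jnp.dot(a, b)

# JIT compile the function
jit_matmul = jit(matmul)

a = jnp.array(np.random.rand(1000, 1000))
b = jnp.array(np.random.rand(1000, 1000))

# The first call traces and compiles the function, then runs it
start = time.perf_counter()
jit_matmul(a, b).block_until_ready()
print(f"First call (includes compilation): {time.perf_counter() - start:.4f} s")

# Subsequent calls with the same shapes reuse the compiled code
start = time.perf_counter()
jit_matmul(a, b).block_until_ready()
print(f"Second call: {time.perf_counter() - start:.4f} s")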

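However, this power comes with a constraint. Under jit, JAX traces your function with abstract values that carry only a shape and a dtype, not the actual numbers, so the structure of the computation must be fixed at compile time. Python if statements or while loops whose condition depends on the value of a traced argument therefore fail with a concretization error, and array shapes cannot depend on runtime values. Loops with a fixed, Python-level trip count are fine: they are simply unrolled during tracing.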
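Here is a minimal sketch of both cases. The function names are hypothetical; the error class is `jax.errors.ConcretizationTypeError` (recent JAX versions raise a more specific subclass, `TracerBoolConversionError`, which the `except` clause still catches).

```python
import jax
import jax.numpy as jnp

@jax.jit
def relu_with_python_if(x):
    # A Python `if` needs a concrete bool, but under jit `x` is an abstract tracer
    if x > 0:
        return x
    return jnp.zeros_like(x)

try:
    relu_with_python_if(3.0)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False

@jax.jit
def sum_first_three(x):
    total = 0.0
    # The trip count is a plain Python int, so tracing simply unrolls the loop
    for i in range(3):
        total = total + x[i]
    return total

print(traced_ok, sum_first_three(jnp.arange(5.0)))
```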

Fortunately, JAX provides JIT-compatible alternatives like jax.lax.cond and jax.lax.scan that allow you to express this logic in a way the compiler can understand and optimize.

import jax.numpy as jnp
from jax import jit, lax

@jit
def optimized_function(x):
    # Both branches must return arrays with the same shape and dtype,
    # and neither may convert the traced `x` into a Python int
    def true_fun(x):
        return x * jnp.ones((5, 5))
    def false_fun(x):
        return jnp.zeros((5, 5))
    # lax.cond evaluates the predicate at run time and picks a branch
    return lax.cond(x > 0, true_fun, false_fun, x)
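Loops have compiled counterparts too. In this sketch, `running_sum` uses `lax.scan`, which threads a carry value through the loop and stacks each step’s output, and `sum_below` uses `lax.fori_loop` (not mentioned above, but part of the same `jax.lax` module) with a trip count that is only known at run time. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_sum(xs):
    # scan calls step(carry, x) for each element and stacks the per-step outputs
    def step(carry, x):
        carry = carry + x
        return carry, carry
    total, partial_sums = lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)
    return total, partial_sums

@jax.jit
def sum_below(n):
    # The trip count `n` is a traced value; fori_loop still compiles the loop
    return lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))

total, partial_sums = running_sum(jnp.arange(1.0, 5.0))
print(total, partial_sums, sum_below(10))
```

Because `running_sum` compiles to a single loop rather than an unrolled sequence of additions, its compile time does not grow with the length of `xs`.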

The superpowers

Beyond raw speed, JAX has other features that make it a true successor to NumPy for modern computing.

Effortless Parallelism with pmap

JAX makes it simple to spread a computation across multiple devices (e.g., multiple GPUs or TPU cores) with the pmap (parallel map) function. pmap maps your function over the leading axis of its inputs and runs each slice on a separate device, so that axis must not be larger than the number of available devices. Within that constraint, a single function transformation is enough to parallelize your code.

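Diagram showing how pmap distributes data
Data partitioning and distribution across devices by the pmap function.

import jax
import jax.numpy as jnp
from jax import pmap

# A simple function to be parallelized
def square(x):
    return x ** 2

# Parallelize the function across all available devices
parallel_square = pmap(square)

# pmap maps over the leading axis, so reshape the data to one row per device
n_devices = jax.local_device_count()
data = jnp.arange(8 * n_devices).reshape(n_devices, 8)

# Each device squares its own row; the result has the same (n_devices, 8) shape
result = parallel_square(data)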

Autodiff

Perhaps the most revolutionary feature of JAX is its built-in automatic differentiation (Autodiff). It can automatically and efficiently compute gradients (derivatives) of your functions. For anyone in machine learning, this is a massive deal, as gradient computation is the core of training neural networks.

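JAX’s autodiff is exact up to floating-point rounding, unlike numerical (finite-difference) differentiation, which introduces truncation error, and it avoids the “expression swell” of symbolic differentiation. jax.grad uses reverse-mode differentiation (forward mode is also available through jax.jvp and jax.jacfwd), which is very efficient for functions with many inputs and a single scalar output, like the loss function of a neural network.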
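To make the precision claim concrete, this sketch differentiates a hypothetical scalar function f(x) = Σ sin²(xᵢ), whose exact gradient is sin(2x), and compares `jax.grad` against forward finite differences with an assumed step size `h = 1e-3`. Finite differences also need one extra evaluation of f per input element, while reverse mode gets the whole gradient in a single backward pass.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar-valued function of a vector input, like a loss
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.1, 0.5, 1.0, 2.0])

# Reverse-mode gradient from jax.grad
auto_grad = jax.grad(f)(x)

# Analytic gradient: d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
exact_grad = jnp.sin(2.0 * x)

# Forward finite differences: one extra evaluation of f per input element
h = 1e-3
fd_grad = jnp.array([(f(x + h * e) - f(x)) / h for e in jnp.eye(x.shape[0])])

autodiff_error = float(jnp.max(jnp.abs(auto_grad - exact_grad)))
finite_diff_error = float(jnp.max(jnp.abs(fd_grad - exact_grad)))
print("autodiff error:", autodiff_error)
print("finite-difference error:", finite_diff_error)
```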

When you combine grad with @jit, you get JIT-compiled, hardware-accelerated gradient calculations. This is the recipe for incredibly fast model training. A comparison of training a simple neural network in JAX versus TensorFlow highlights the astonishing performance.

JAX NN Training:

import time

from jax import grad, jit

# ... (setup code for data and model)
grad_loss = jit(grad(loss, argnums=(0, 1)))
start_time = time.time()
for i in range(num_epochs):
    dw, db = grad_loss(w, b, x, y)
    w -= learning_rate * dw
    b -= learning_rate * db
jax_time = time.time() - start_time
print(f"JAX Training Time: {jax_time} seconds")

JAX Training Time: 1.185472011566162 seconds

TensorFlow NN Training:

import time

# ... (setup code for data and model)
model.compile(optimizer='sgd', loss='mean_squared_error')
start_time = time.time()
model.fit(x, y, epochs=1000, verbose=0)
tf_time = time.time() - start_time
print(f"TensorFlow Training Time: {tf_time} seconds")

TensorFlow Training Time: 82.8932785987854 seconds

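The performance difference is staggering, but it should be read with some care. By default, model.fit splits the data into mini-batches of 32 samples and runs a separate update step for each one, with Keras’ per-step overhead, while the JAX loop takes one full-batch gradient step per epoch, so the two loops are likely not doing the same amount of work.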
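Because the setup code above is elided, here is a self-contained sketch of the same pattern, `jit(grad(loss, argnums=(0, 1)))` driving a plain gradient-descent loop. It assumes a linear model and synthetic, noise-free data; `true_w`, `true_b`, the sample count, learning rate and epoch count are all hypothetical choices, and it is not a re-run of the benchmark above.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic data: y = x @ true_w + true_b, no noise
x = jax.random.normal(jax.random.PRNGKey(0), (256, 3))
true_w = jnp.array([2.0, -1.0, 0.5])
true_b = 0.3
y = x @ true_w + true_b

def loss(w, b, x, y):
    # Mean squared error of a linear model
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

# Gradient with respect to w and b, compiled once by XLA
grad_loss = jax.jit(jax.grad(loss, argnums=(0, 1)))

w = jnp.zeros(3)
b = 0.0
learning_rate = 0.1
initial_loss = loss(w, b, x, y)
for epoch in range(500):
    dw, db = grad_loss(w, b, x, y)
    # JAX arrays are immutable, so each update creates new arrays
    w = w - learning_rate * dw
    b = b - learning_rate * db
final_loss = loss(w, b, x, y)
print(f"loss: {float(initial_loss):.4f} -> {float(final_loss):.2e}")
```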

JAX in the real world

This isn’t just an academic toy. JAX is the foundation for a rapidly growing ecosystem of powerful libraries used in cutting-edge research and production systems.

  • EvoJAX: A neuroevolution library from Google that uses JAX to evolve neural networks for complex optimization and robotics tasks.
Diagram of EvoJAX solving a packing problem
EvoJAX: Bringing the Power of Neuroevolution to Solve Your Problems
  • Myriad: A testbed for trajectory optimization and deep learning, bridging the gap between control theory and modern reinforcement learning.
Official poster for the Myriad testbed
Myriad: a real-world testbed to bridge trajectory optimization and deep learning
  • Jumanji: A library from InstaDeep for creating high-performance, parallelizable reinforcement learning environments.
Jumanji library
Welcome to the Jungle!
  • Haiku: A simple and elegant neural network library from DeepMind, built on top of JAX.
Haiku library
Haiku: Sonnet for JAX

Make the switch

JAX represents the next logical step in the evolution of numerical computing in Python. It takes the simplicity and power of the NumPy API and combines it with the performance of modern hardware accelerators. With features like JIT compilation, effortless parallelism, and built-in automatic differentiation, it’s not just a faster NumPy; it’s a fundamentally more powerful tool.

For anyone working in machine learning, scientific computing, or any field that demands high performance, the question is no longer if you should learn JAX, but when. The future of high-performance Python is here, and it’s time to start using it.
