5  Speeding up Python for Linear Algebra Tasks

Python is easy to read, but pure-Python loops can be slow if you do not leverage optimized libraries (BLAS, LAPACK). Here we explore two tools—Numba and JAX—to accelerate common linear algebra operations.

5.1 Numba: Just-In-Time Compilation

Numba uses LLVM to compile Python functions to machine code at runtime. Key points:

  • Decorators: Use @njit (nopython mode) for best speed.
  • Type inference: Numba infers types on first run, then compiles specialized code.
  • Compilation overhead: The first call incurs compilation time; subsequent calls are fast.
  • Object mode vs nopython mode: Always aim for nopython mode to avoid Python object overhead.

JIT Workflow 1. Call function → type inference → LLVM IR generation.
2. LLVM IR → machine code (cached).
3. Subsequent calls use cached machine code.

Example: Matrix–Vector Multiplication

from numba import njit
import numpy as np
import time # for timing

@njit
def matvec(A, x):
    m, n = A.shape
    y = np.zeros(m)
    for i in range(m):
        temp = 0.0
        for j in range(n):
            temp += A[i, j] * x[j]
        y[i] = temp
    return y

# Prepare data
dim = 500
A = np.random.rand(dim, dim)
x = np.random.rand(dim)

# Using NumPy's dot product
start = time.time()
y0 = A @ x
end = time.time()
print("NumPy time (ms): ", 1e3*(end - start))

# Using Numba's compiled function
y0 = matvec(A, x) # First call for compilation

start = time.time()
y1 = matvec(A, x)
end = time.time()
print("Numba time (ms): ", 1e3*(end - start))
NumPy time (ms):  0.3027915954589844
Numba time (ms):  0.29087066650390625

In practice, Numba can speed up this looped version by 10×–100× compared to pure Python, approaching the speed of NumPy’s optimized routines. The reader is encouraged to try the code without the @njit decorator to see the difference in performance.

5.2 JAX: XLA Compilation and Automatic Differentiation

JAX is a high-performance library from Google Research that extends NumPy with just-in-time compilation and automatic differentiation. It - Compiles array operations via XLA, fusing kernels and reducing Python overhead. - Supports GPU and TPU backends with minimal code changes. - Provides grad for gradients of scalar functions, enabling optimisation and machine-learning tasks. - Offers advanced transformations like vmap (vectorisation) and pmap (parallelism on multiple devices).

JAX is widely used in deep learning frameworks (e.g. Flax, Haiku), reinforcement learning, and scientific research (including physics simulations), thanks to its blend of speed and flexibility.

5.2.1 A Quick Overview of Automatic Differentiation

Automatic differentiation (AD) is a family of techniques to compute exact derivatives of functions defined by computer programs. Unlike symbolic differentiation (which can lead to expression swell) or numerical finite-difference (which suffers from truncation and round-off error), AD exploits the fact that any complex function is ultimately composed of a finite set of elementary operations (addition, multiplication, sin, exp, …) whose derivatives are known exactly.

Limitations of Finite Differences

A common finite-difference formula for a scalar function \(f(x)\) is the central difference

\[ \frac{df}{dx}(x)\approx \frac{f(x+h)-f(x-h)}{2h}, \]

with local truncation error \(\mathcal{O}(h^2)\). However, this approach has important limitations:

  1. Truncation vs. round-off: If \(h\) is too large, the \(\mathcal{O}(h^2)\) term dominates. If \(h\) is too small, floating-point cancellation makes the numerator \(f(x+h)-f(x-h)\) inaccurate.

  2. Cost with many parameters: For \(f:\mathbb{R}^n\to\mathbb{R}\), the gradient component \(i\) is

    \[ \frac{\partial f}{\partial x_i}(\mathbf{x})\approx \frac{f(\mathbf{x}+h\mathbf{e}_i)-f(\mathbf{x}-h\mathbf{e}_i)}{2h}. \]

    Computing all \(n\) components requires \(2n\) evaluations of \(f\), so the cost scales as \(\mathcal{O}(n)\) in \(f\)-calls. For large \(n\) (many parameters), this becomes prohibitive.

  3. Non-smooth or branching code: When \(f\) contains control flow or non-differentiable operations, finite differences may give misleading or undefined results.

Automatic Differentiation and the Chain Rule

Automatic differentiation (AD) applies the chain rule to each elementary operation in code (addition, multiplication, sin, exp, etc.), yielding exact derivatives up to floating-point precision. For a composition

\[ u = g(x),\quad y = f(u), \]

AD uses the chain rule:

\[ \frac{dy}{dx} = \frac{df}{du}\frac{dg}{dx}. \]

In more complex nests, e.g.

\[ v = h(u),\quad u = g(x),\quad y = f(v), \]

we get

\[ \frac{dy}{dx} = \frac{df}{dv}\frac{dh}{du}\frac{dg}{dx}. \]

AD comes in two modes:

  • Forward mode (propagate derivatives from inputs to outputs).
  • Reverse mode (propagate sensitivities from outputs back to inputs).

JAX implements both and selects the most efficient strategy automatically.

Comparing Accuracy: AD vs Finite Differences

Below is a Quarto code cell that plots the error of finite differences (varying step size \(h\)) and automatic differentiation against the true derivative of \(f(x) = e^{\sin(x)}\) at \(x=1.0\).

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt # for plotting

# Set JAX to use 64-bit floats
jax.config.update("jax_enable_x64", True)

# Define function and true derivative
def f_np(x):
    return np.exp(np.sin(x))

def df_true(x):
    return np.cos(x) * np.exp(np.sin(x))

# Point of evaluation
x0 = 1.0

# Finite-difference errors for varying h
hs = np.logspace(-8, -1, 50)
errors_fd = []
for h in hs:
    df_fd = (f_np(x0 + h) - f_np(x0 - h)) / (2 * h)
    errors_fd.append(abs(df_fd - df_true(x0)))

# Automatic differentiation error (constant)
df_ad = jax.grad(lambda x: jnp.exp(jnp.sin(x)))(x0)
error_ad = abs(np.array(df_ad) - df_true(x0))

print(f"AD error: {error_ad}")
print(f"FD minimum error: {min(errors_fd)}")

# Plot
fig, ax = plt.subplots()
ax.loglog(hs, errors_fd, marker="o")
ax.set_xlabel("Step size $h$")
ax.set_ylabel("Error of Finite Differences")

# Show in Quarto
plt.savefig("_tmp_fig.svg")
plt.close(fig)
SVG("_tmp_fig.svg")
AD error: 0.0
FD minimum error: 7.006839553014288e-12

This plot illustrates that finite differences achieve minimal error at an optimal \(h\), but degrade for too large or too small \(h\), while AD remains accurate to machine precision regardless of step size.

5.3 Why Computing Gradients Is Important in Quantum Physics

In quantum physics, many problems reduce to optimizing parameters in a model or a control protocol. Computing gradients of a cost function with respect to these parameters is essential for efficient and reliable optimization.

  1. Variational quantum algorithms: In methods like the variational quantum eigensolver (VQE)(Peruzzo et al. 2014), a parametrised quantum state \(|\psi(\boldsymbol{\theta})\rangle\) depends on parameters \(\boldsymbol{\theta}=(\theta_1,\dots,\theta_n)\). One minimises the expectation \[ E(\boldsymbol{\theta}) = \langle\psi(\boldsymbol{\theta})|\hat H|\psi(\boldsymbol{\theta})\rangle. \]

    Gradient-based methods require

    \[ \frac{\partial E}{\partial \theta_i} = \frac{\partial}{\partial \theta_i} \langle\psi(\boldsymbol{\theta})|\hat H|\psi(\boldsymbol{\theta})\rangle. \]

    AD enables exact evaluation of these derivatives through the quantum circuit parameters, improving convergence compared to gradient-free methods.

  2. Quantum optimal control(D’Alessandro 2021; Khaneja et al. 2005): One shapes control fields \(u(t)\) in the Hamiltonian \[ \hat H(t; u) = \hat H_0 + \sum_i u_i(t) \hat H_i \]

    to drive the system from an initial state \(|\psi_0\rangle\) to a target \(|\psi_T\rangle\). A typical cost function is

    \[ J[u] = 1 - |\langle\psi_T|\mathcal U_T[u]|\psi_0\rangle|^2, \]

    where \(\mathcal U_T[u]\) is the time-ordered evolution. Computing gradients \(\delta J/\delta u_i(t)\) is needed for gradient-ascent pulse engineering (GRAPE) algorithms. AD can differentiate through time-discretised propagators and ODE solvers, automating derivation of \(\delta J/\delta u_i(t)\) and providing machine-precision gradients for faster convergence.

  3. Parameter estimation and tomography(Lvovsky and Raymer 2009): Maximum-likelihood estimation for quantum states or processes often involves maximising a log-likelihood \(L(\boldsymbol{\theta})\). Gradients speed up estimation and enable standard optimisers (e.g. L-BFGS).

By providing exact, efficient gradients even through complex quantum simulations (time evolution, measurement models, noise), automatic differentiation (via JAX or similar frameworks) has become a key tool in modern quantum physics research.

5.4 Summary

  • Numba: Best for speeding up existing NumPy loops with minimal code changes. Ideal when you do not need gradients or accelerators.
  • JAX: Ideal for optimisation tasks requiring gradients, large-scale batch operations, or GPU/TPU acceleration. The XLA compiler often outperforms loop-based JIT for fused kernels.