Introduction to JAX

JAX is a high-performance numerical computing library developed by Google that brings together the familiar NumPy API with powerful features like automatic differentiation, just-in-time (JIT) compilation, and seamless GPU/TPU acceleration. Originally designed for machine learning research, JAX has become popular across scientific computing applications due to its speed and flexibility.

The official JAX documentation provides a good introduction here and here. Otherwise, the following is a short crash course.

import jax
import jax.numpy as jnp
import fdtdx

Functional Programming Paradigm

JAX is built around a functional programming style. Its transformations expect pure functions: functions whose outputs depend only on their inputs and which have no side effects. This has several important implications:

Immutable data

Arrays and other data structures are treated as immutable. Operations return new arrays instead of modifying existing ones.

This functional constraint enables JAX's transformations: jit (compilation), grad (automatic differentiation), vmap (vectorization), and pmap (parallelization). The functional style takes some adjustment if you are used to imperative programming. In return, it lets JAX automatically optimize and transform numerical code in ways that would be much harder with stateful operations.
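The sketch below composes grad, vmap, and jit on a small pure function f, a toy function chosen only for this illustration. grad turns f into its derivative. vmap then applies that derivative across a batch, and jit compiles the result.

```python
import jax
import jax.numpy as jnp

def f(x):
    # pure function: the result depends only on x, no state is read or written
    return jnp.sin(x) * x

df = jax.grad(f)                    # derivative of f for a scalar x
batched_df = jax.jit(jax.vmap(df))  # apply df to every element of a batch, compiled with XLA

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_df(xs))  # analytically, f'(x) = x * cos(x) + sin(x)
```

Because f is pure, these transformations can be stacked freely. jax.vmap(jax.grad(f)) is a common way to obtain per-element derivatives.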

JAX arrays cannot be modified in place, and functions transformed by JAX should not rely on internal mutable state. Instead of an assignment like array[0] = 5, use the functional equivalent array.at[0].set(5), which returns a new array.

# This won't work in JAX
def bad_function(x):
    x[0] = x[0] + 1  # In-place modification
    return x

# This is the JAX way
def good_function(x):
    return x.at[0].add(1)  # Returns new array

print(good_function(jnp.asarray([4.0])))
print(bad_function(jnp.asarray([4.0])))
[5.]
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
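The .at[] property supports more than set. The sketch below shows a few of its update methods (set, multiply, add) and confirms that the original array is left untouched. Note that with .add, repeated indices accumulate.

```python
import jax.numpy as jnp

x = jnp.arange(4.0)                   # [0., 1., 2., 3.]
y = x.at[1].set(10.0)                 # replace a single element
z = x.at[::2].multiply(2.0)           # scale every second element (slices work too)
w = x.at[jnp.array([0, 0])].add(1.0)  # repeated index 0: both additions are applied

print(x)  # unchanged: [0, 1, 2, 3]
print(y)  # [0, 10, 2, 3]
print(z)  # [0, 1, 4, 3]
print(w)  # [2, 1, 2, 3]
```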

No Side Effects (pure functions)

Functions should not print to the console, write to files, or modify global variables. JAX's Just-in-Time (JIT) compiler traces a function once and optimizes it under the assumption that it is deterministic and free of side effects. Python side effects such as print statements therefore run only while the function is traced, which happens on the first call for a given input shape and dtype. They do not run when the compiled function is executed afterwards.

def example_function():
    x = jnp.ones((4,))
    print(x)

jitted_fn = jax.jit(example_function)

jitted_fn()  # first call: traces the function and prints the tracer for x, not its values
jitted_fn()  # later calls run the compiled function and print nothing
JitTracer<float32[4]>
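To see values at runtime inside a jitted function, JAX provides jax.debug.print, which is staged into the compiled program. The sketch below contrasts it with a plain Python side effect. The trace_count counter is hypothetical and exists only for this illustration. It increments only when JAX traces the function: on the first call, and again whenever the input shape or dtype changes.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only used to make tracing visible

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    jax.debug.print("sum of x = {}", jnp.sum(x))  # part of the compiled program: runs on every call
    return 2.0 * jnp.sum(x)

scaled_sum(jnp.ones(3))  # traces and compiles: trace_count becomes 1
scaled_sum(jnp.ones(3))  # same shape and dtype: reuses the compiled function
scaled_sum(jnp.ones(5))  # new shape: traces again, trace_count becomes 2
print(trace_count)       # 2
```

Relying on such a counter in real code is exactly what the pure-function rule warns against. It appears here only to make tracing observable.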

Static Shapes during computation

Inside a compiled function, every JAX array must have a shape that is known at compile time. Shapes may depend on the shapes and dtypes of the inputs, since calling the function with a new input shape triggers a recompilation. They may not depend on the values inside arrays.

This leads to a distinction between static and dynamic data. Static data is fixed at compile time. Examples are array shapes, Python constants, and arguments marked with static_argnums or static_argnames. Static data can be used in Python if-clauses and to determine the shapes of JAX arrays. Dynamic data are JAX arrays whose values are only known at runtime. They cannot be used in Python if-clauses or to change the shapes of other arrays. As a rule of thumb, the computational graph of a function can only change based on static data, never based on the values of JAX arrays.

def if_clause(x):
    return 1.0 if x else 2.0  # computational graph changes depending on the value of x

print(jax.jit(if_clause)(jnp.asarray(True)))
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function if_clause for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

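jit does accept value-dependent branching when it is written with JAX operations. Two common forms are jnp.where, which computes both options and selects elementwise, and jax.lax.cond, which selects one of two branch functions at runtime. The sketch below shows both as alternatives to the if_clause function above.

```python
import jax
import jax.numpy as jnp

@jax.jit
def if_where(x):
    # both candidate values exist; the traced boolean x only selects between them
    return jnp.where(x, 1.0, 2.0)

@jax.jit
def if_cond(x):
    # lax.cond stages both branches into the compiled program and picks one at runtime
    return jax.lax.cond(x, lambda: 1.0, lambda: 2.0)

print(if_where(jnp.asarray(True)))  # 1.0
print(if_cond(jnp.asarray(False)))  # 2.0
```
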
def indexing_fn(x):
    return jnp.asarray([4.0, 2.0, 1.0, 3.0])[:x]  # depending on value of x different array shape is returned

print(jax.jit(indexing_fn)(jnp.asarray(1)))
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, JitTracer<~int32[]>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
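There are two ways around this, depending on what needs to vary:

  • If the slice length itself must vary, make it static with static_argnums. The length is then an ordinary Python integer at trace time, and each new value triggers a recompilation.

  • If only the start position varies, jax.lax.dynamic_slice keeps the output size fixed and accepts a traced start index.

The sketch below shows both, using the same four-element array as indexing_fn.

```python
from functools import partial

import jax
import jax.numpy as jnp

DATA = jnp.asarray([4.0, 2.0, 1.0, 3.0])

@partial(jax.jit, static_argnums=0)
def head(n):
    # n is a static Python int, so the output shape (n,) is known at compile time
    return DATA[:n]

@jax.jit
def window(start):
    # start is traced, but the slice size (2,) is fixed, so the output shape stays static
    return jax.lax.dynamic_slice(DATA, (start,), (2,))

print(head(1))                 # [4.]
print(window(jnp.asarray(1)))  # [2. 1.]
```

dynamic_slice clamps out-of-range start indices so that the slice stays inside the array. For example, window(jnp.asarray(3)) returns the last two elements.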

TreeClass Objects in FDTDX

FDTDX leverages JAX’s functional programming paradigm through a specialized TreeClass system that makes it easy to work with complex hierarchical data structures while maintaining JAX compatibility. The TreeClass provides a clean, object-oriented interface that automatically integrates with JAX’s pytree system, allowing for seamless use with JAX transformations.
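The sketch below illustrates what registering a class as a pytree means in plain JAX. It uses jax.tree_util.register_pytree_node_class on a hypothetical Point dataclass. This is not how FDTDX implements TreeClass. It only shows the mechanism that lets jit, grad, and tree_map see through a custom object to its array leaves.

```python
import dataclasses

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Point:
    x: jax.Array
    y: jax.Array

    def tree_flatten(self):
        # children are the leaves JAX transforms; the second value is static auxiliary data
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

p = Point(x=jnp.asarray(1.0), y=jnp.asarray(2.0))
print(jax.tree_util.tree_leaves(p))                   # the two array leaves
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)  # a new Point with doubled leaves
grads = jax.grad(lambda q: q.x ** 2 + 3 * q.y)(p)     # the gradient is again a Point
print(grads)  # x = 2 * p.x = 2.0, y = 3.0
```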

TreeClass Structure

The TreeClass system uses dataclass-like syntax with the @fdtdx.autoinit decorator to automatically generate initialization methods. Here’s how it works:

@fdtdx.autoinit
class A(fdtdx.TreeClass):
    a: float = 2
    x: int = 5

@fdtdx.autoinit
class B(fdtdx.TreeClass):
    a1: A
    z: int = 7

@fdtdx.autoinit
class C(fdtdx.TreeClass):
    b_list: list[B]
    c: float = 2

These classes can be nested arbitrarily deep and contain lists, dictionaries, or other complex data structures. The @fdtdx.autoinit decorator automatically generates init methods that handle default values and type checking.

Working with TreeClass Instances

# Create instances with default or custom values
b = B(a1=A())  # Uses defaults: A(a=2, x=5), z=7
print(b)
b = b.aset("z", 25)  # Functional update
print(b)
B(a1=A(a=2, x=5), z=7)
B(a1=A(a=2, x=5), z=25)
# Collections of TreeClass instances
b2 = B(a1=A(a=10, x=11), z=12)
c = C(b_list=[b, b2])
print(c)

# Deep nested updates using path syntax
c2 = c.aset("b_list->[0]->a1->a", 100)
print(c2)
C(b_list=[B(a1=A(a=2, x=5), z=25), B(a1=A(a=10, x=11), z=12)], c=2)
C(b_list=[B(a1=A(a=100, x=5), z=25), B(a1=A(a=10, x=11), z=12)], c=2)

The aset Method: Functional Updates Made Easy

The aset method is the cornerstone of FDTDX's functional approach. JAX's .at[].set() only updates elements of a single array. aset, in contrast, can update any attribute at any level of nesting within a TreeClass hierarchy.

Path Syntax: The method uses a string-based path syntax to navigate nested structures:

  • "attribute" - Direct attribute access

  • "a->b" - Nested attribute access (a.b)

  • "a->[0]" - List indexing

  • "a->['key']" - Dictionary key access

  • "b_list->[0]->a1->a" - Complex nested path

In the example c2 = c.aset("b_list->[0]->a1->a", 100), this path means:

  • Access the b_list attribute of c

  • Get the first element [0] of that list

  • Access the a1 attribute of that element

  • Access the a attribute of a1

  • Set that value to 100

The method returns a new instance with the updated value and leaves the original unchanged (in the example above, c still holds a=2), as JAX's functional programming model requires. FDTDX data structures can therefore be passed directly to JAX transformations such as jit, grad, and vmap, without manually reconstructing nested data structures for every update.
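The sketch below shows how such a path can be resolved functionally. It implements a hypothetical helper, path_set, on plain frozen dataclasses, lists, and dicts, and mirrors the path syntax described above. It is not FDTDX's implementation of aset. Each level along the path is copied with dataclasses.replace or a new list or dict, so the original object is never modified.

```python
import ast
import dataclasses
import re

def path_set(obj, path, value):
    """Hypothetical aset-like helper: return a copy of obj with the entry at `path` replaced."""
    return _set(obj, path.split("->"), value)

def _set(obj, keys, value):
    if not keys:
        return value
    key, rest = keys[0], keys[1:]
    match = re.fullmatch(r"\[(.+)\]", key)
    if match:  # "[0]" (list index) or "['key']" (dict key)
        index = ast.literal_eval(match.group(1))
        child = _set(obj[index], rest, value)
        if isinstance(obj, dict):
            return {**obj, index: child}
        items = list(obj)
        items[index] = child
        return type(obj)(items)
    # plain attribute: rebuild the dataclass with one field replaced
    return dataclasses.replace(obj, **{key: _set(getattr(obj, key), rest, value)})

@dataclasses.dataclass(frozen=True)
class A:
    a: float = 2
    x: int = 5

@dataclasses.dataclass(frozen=True)
class B:
    a1: A
    z: int = 7

@dataclasses.dataclass(frozen=True)
class C:
    b_list: list
    c: float = 2

c = C(b_list=[B(a1=A()), B(a1=A(a=10, x=11), z=12)])
c2 = path_set(c, "b_list->[0]->a1->a", 100)
print(c2.b_list[0].a1.a, c.b_list[0].a1.a)  # 100 2
```

Unchanged branches are shared between c and c2 rather than copied. In this example, that is the second list entry. Sharing is what keeps such functional updates cheap.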

How JAX is used in FDTDX

For a full example of how to use JAX with FDTDX, check out this example or this example. The excerpt below shows how FDTDX integrates with JAX's jit transformation. The core simulation function sim_fn takes FDTDX TreeClass structures as arguments and is JIT-compiled:

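def sim_fn(
    params: fdtdx.ParameterContainer,
    arrays: fdtdx.ArrayContainer,
    key: jax.Array,
):
    # objects and config are not arguments: they are defined outside sim_fn and captured by closure
    arrays, new_objects, info = fdtdx.apply_params(arrays, objects, params, key)
    # run the FDTD simulation on the parameterized objects
    final_state = fdtdx.run_fdtd(arrays=arrays, objects=new_objects, config=config, key=key)
    # ... more operations (use final_state and build new_info)
    return arrays, new_info

# lower for concrete params/arrays/key, then compile ahead of time;
# the buffers of the "arrays" argument are donated to the compiled function
jitted_loss = jax.jit(sim_fn, donate_argnames=["arrays"]).lower(params, arrays, key).compile()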

JIT compilation with TreeClass arguments

Key Features:

  • TreeClass Compatibility: ParameterContainer and ArrayContainer are FDTDX TreeClass structures. Because TreeClass objects integrate with JAX's pytree system, jit can flatten these nested structures into their leaves and compile the function efficiently.

  • Memory Optimization: The donate_argnames=["arrays"] parameter tells JAX that it may reuse the memory of the arrays argument for the outputs. This is crucial for the large electromagnetic field arrays in FDTD simulations. A donated argument must not be used again after the call.

  • Compilation Pipeline: The script uses .lower().compile() to compile ahead of time, separately from execution. This lets the compilation time be measured on its own for performance analysis.
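The same pattern can be tried without FDTDX. The sketch below uses a hypothetical step function as a stand-in for the simulation. It donates the state buffer via donate_argnames, compiles ahead of time with .lower().compile(), and times compilation and execution separately. On backends that cannot reuse donated buffers, JAX may emit a warning and simply ignore the donation.

```python
import time

import jax
import jax.numpy as jnp

def step(state, scale):
    # hypothetical stand-in for an expensive simulation update
    return state * scale + 1.0

state = jnp.zeros((256, 256))
scale = jnp.asarray(0.5)

t0 = time.perf_counter()
compiled = jax.jit(step, donate_argnames=["state"]).lower(state, scale).compile()
t1 = time.perf_counter()

new_state = compiled(state, scale)  # "state" is donated: do not use it after this call
new_state.block_until_ready()       # wait for the asynchronous computation to finish
t2 = time.perf_counter()

print(f"compile: {t1 - t0:.4f} s, run: {t2 - t1:.4f} s")
```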

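While this specific example focuses on forward simulation, FDTDX is designed for gradient-based optimization. The GradientConfig setup configures how gradients are computed. In this snippet, its Recorder applies a DtypeConversion module to bfloat16, which reduces the memory footprint of the recorded data: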

gradient_config = fdtdx.GradientConfig(
    recorder=fdtdx.Recorder(
        modules=[fdtdx.DtypeConversion(dtype=jnp.bfloat16)]
    )
)

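For gradient computation, you would typically differentiate a scalar loss derived from sim_fn:

# Hypothetical gradient computation: jax.grad needs a function with a scalar output
def loss_fn(params, arrays, key):
    return -sim_fn(params, arrays, key)[1]["objective"]  # "objective" is a hypothetical key
grad_fn = jax.grad(loss_fn, argnums=0)  # Gradient w.r.t. params
gradients = grad_fn(params, arrays, key)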
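The sketch below is a self-contained version of the same idea. It differentiates a toy time-stepping loop written with jax.lax.scan. The toy_sim function and its "gain" parameter are hypothetical stand-ins for an FDTD run, not FDTDX APIs. jax.value_and_grad with has_aux=True returns three things in one call: the scalar loss, auxiliary information, and the gradient with respect to the params pytree.

```python
import jax
import jax.numpy as jnp

def toy_sim(params, field):
    # hypothetical stand-in for an FDTD run: a linear update repeated for 10 steps
    def step(f, _):
        return params["gain"] * jnp.roll(f, 1) + 0.1 * f, None
    field, _ = jax.lax.scan(step, field, None, length=10)
    return field

def loss_fn(params, field):
    out = toy_sim(params, field)
    energy = jnp.sum(out ** 2)
    return -energy, {"energy": energy}  # scalar loss (maximize energy) plus auxiliary info

grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
params = {"gain": jnp.asarray(0.9)}
field0 = jnp.zeros(8).at[0].set(1.0)
(loss, aux), grads = grad_fn(params, field0)
print(loss, grads["gain"])  # grads has the same structure as params
```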