Introduction to JAX
JAX is a high-performance numerical computing library developed by Google that brings together the familiar NumPy API with powerful features like automatic differentiation, just-in-time (JIT) compilation, and seamless GPU/TPU acceleration. Originally designed for machine learning research, JAX has become popular across scientific computing applications due to its speed and flexibility.
The JAX documentation provides good introductions here and here. What follows is a short crash course.
import jax
import jax.numpy as jnp
import fdtdx
Functional Programming Paradigm
JAX is built around a functional programming style: the functions you pass to its transformations should be pure, that is, free of side effects. This has several important implications:
Immutable data
Arrays and other data structures are treated as immutable. Operations create new objects rather than modifying existing ones, much like NumPy arithmetic such as a + b returns a new array instead of changing a or b.
This functional constraint enables JAX’s powerful transformations like jit (compilation), grad (automatic differentiation), vmap (vectorization), and pmap (parallelization). While the functional style requires some adjustment if you’re used to imperative programming, it unlocks JAX’s ability to automatically optimize and transform your numerical code in ways that would be impossible with stateful operations.
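To make this concrete, here is a minimal plain-JAX sketch (independent of FDTDX) that applies jit, grad, and vmap to the same pure function. The function f and its inputs are illustrative only.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Pure: the output depends only on the input; no mutation, no I/O
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])
f_jit = jax.jit(f)       # compiled version of f
df = jax.grad(f)         # gradient of f, which is 2 * x
batched = jax.vmap(f)    # applies f to each row of a batch

print(f_jit(x))                         # 14.0
print(df(x))                            # [2. 4. 6.]
print(batched(jnp.stack([x, 2 * x])))   # [14. 56.]
```

All three transformations work on the same unmodified function precisely because f has no hidden state.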
JAX arrays cannot be modified in place, and functions passed to JAX transformations should not maintain internal state. Instead of an assignment like array[0] = 5, use a functional equivalent such as array.at[0].set(5), which returns a new array.
# This won't work in JAX
def bad_function(x):
    x[0] = x[0] + 1  # In-place modification
    return x

# This is the JAX way
def good_function(x):
    return x.at[0].add(1)  # Returns new array

print(good_function(jnp.asarray([4.0])))
print(bad_function(jnp.asarray([4.0])))
[5.]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html
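Besides set and add, the .at[] property offers other functional update methods. A short sketch using only documented jax.numpy methods; note that the original array is never changed, and that repeated indices in .add accumulate.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

y = x.at[0].set(10.0)                  # replace a single element
z = x.at[1:3].multiply(2.0)            # scale a slice
w = x.at[jnp.array([0, 0])].add(1.0)   # repeated indices accumulate: x[0] + 1 + 1

print(y)  # [10.  2.  3.  4.]
print(z)  # [1. 4. 6. 4.]
print(w)  # [3. 2. 3. 4.]
print(x)  # [1. 2. 3. 4.]  (original unchanged)
```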
No Side Effects (pure functions)
Functions transformed by JAX should not print to the console, write to files, or modify global variables. JAX's just-in-time (JIT) compiler assumes that functions are deterministic and free of side effects: it runs the Python code once to trace it, then caches and reuses the compiled result. As a consequence, print statements execute only during tracing (the first call for a given input shape and dtype), not on later calls.
def example_function():
    x = jnp.ones((4,))
    print(x)

jitted_fn = jax.jit(example_function)
jitted_fn()  # first call traces the function, so print shows a tracer, not concrete values
jitted_fn()  # later calls run the cached compiled function and print nothing
JitTracer<float32[4]>
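The following sketch makes the trace-once behaviour visible with a global counter (a deliberate side effect, used here only for demonstration) and shows jax.debug.print, JAX's way of printing concrete values from inside a compiled function on every call. The function name double is illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # global state; mutating it inside a jitted function is a side effect

@jax.jit
def double(x):
    global trace_count
    trace_count += 1                   # runs only while JAX traces the function
    jax.debug.print("x = {}", x)       # runs on every call, with concrete values
    return 2 * x

double(jnp.arange(3.0))   # traced and compiled for float32[3]
double(jnp.arange(3.0))   # same shape and dtype: cached, no retrace
double(jnp.arange(4.0))   # new shape: traced and compiled again
print(trace_count)        # 2
```

The counter is incremented twice, not three times, which is exactly why side effects inside jitted code cannot be relied on.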
Static Shapes during computation
Inside a compiled function, the shape of every JAX array must be known at trace time; calling the function with a new input shape simply triggers a retrace and recompilation. This leads to a distinction between static and dynamic data. Static data, such as array shapes, Python constants, or arguments marked with static_argnums/static_argnames, is fixed when the function is traced, so it can be used in if-clauses or to determine the shapes of JAX arrays. Dynamic data are traced JAX arrays whose values are only known at run time; they cannot be used in Python if-clauses or to determine the shapes of other arrays. As a rule of thumb, the computational graph of a function can depend on static values, but not on the values of JAX arrays.
def if_clause(x):
    return 1.0 if x else 2.0  # computational graph changes depending on the value of x

print(jax.jit(if_clause)(jnp.asarray(True)))
---------------------------------------------------------------------------
TracerBoolConversionError Traceback (most recent call last)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function if_clause at /tmp/ipykernel_278296/1626192598.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
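There are several standard ways around the if_clause error above. This sketch shows three of them using only core JAX APIs; the function names are illustrative. jax.lax.cond chooses a branch at run time (both branches must return the same shape and dtype), jnp.where evaluates both options and selects elementwise, and static_argnames turns the argument into a Python value at trace time so an ordinary if works, at the cost of one compilation per distinct value.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def if_clause_cond(x):
    # Branch selected at run time by XLA
    return jax.lax.cond(x, lambda: 1.0, lambda: 2.0)

@jax.jit
def if_clause_where(x):
    # Both values computed, one selected
    return jnp.where(x, 1.0, 2.0)

@partial(jax.jit, static_argnames=["flag"])
def if_clause_static(flag):
    # flag is a plain Python bool during tracing
    return 1.0 if flag else 2.0

print(if_clause_cond(jnp.asarray(True)))    # 1.0
print(if_clause_where(jnp.asarray(False)))  # 2.0
print(if_clause_static(True))               # 1.0
```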
def indexing_fn(x):
    return jnp.asarray([4.0, 2.0, 1.0, 3.0])[:x]  # the output shape would depend on the value of x

print(jax.jit(indexing_fn)(jnp.asarray(1)))
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, JitTracer<~int32[]>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
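As the error message suggests, the fix depends on what you need. A minimal sketch of three options with core JAX only (function names illustrative): make the length static so the output shape is known at trace time, keep a fixed shape and mask out entries, or take a fixed-size window at a dynamic position with jax.lax.dynamic_slice (which clamps the start index so the window stays in bounds).

```python
from functools import partial

import jax
import jax.numpy as jnp

values = jnp.asarray([4.0, 2.0, 1.0, 3.0])

@partial(jax.jit, static_argnums=0)
def prefix_static(n):
    # n is a Python int at trace time; recompiles for each distinct n
    return values[:n]

@jax.jit
def prefix_masked(n):
    # fixed output shape (4,); entries at positions >= n are set to 0
    return jnp.where(jnp.arange(values.shape[0]) < n, values, 0.0)

@jax.jit
def window(start):
    # fixed length 2, dynamic start position
    return jax.lax.dynamic_slice(values, (start,), (2,))

print(prefix_static(1))               # [4.]
print(prefix_masked(jnp.asarray(1)))  # [4. 0. 0. 0.]
print(window(jnp.asarray(1)))         # [2. 1.]
```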
TreeClass Objects in FDTDX
FDTDX leverages JAX’s functional programming paradigm through a specialized TreeClass system that makes it easy to work with complex hierarchical data structures while maintaining JAX compatibility. The TreeClass provides a clean, object-oriented interface that automatically integrates with JAX’s pytree system, allowing for seamless use with JAX transformations.
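For comparison, this is roughly the bookkeeping that pytree integration involves when done by hand in plain JAX with jax.tree_util.register_pytree_node_class. Field is a hypothetical container, and this sketch is not a description of how FDTDX implements TreeClass internally. It only shows what registration means: telling JAX which attributes are leaves to be traced and which are static metadata.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Field:
    """Hypothetical container: `values` is a traced leaf, `name` is static metadata."""

    def __init__(self, values, name):
        self.values = values
        self.name = name

    def tree_flatten(self):
        # (children, aux_data): children are traced, aux_data must be hashable
        return (self.values,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)

@jax.jit
def scale(field):
    return Field(field.values * 2.0, field.name)

f = scale(Field(jnp.ones(3), "E"))
print(f.name, f.values)  # E [2. 2. 2.]
```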
TreeClass Structure
The TreeClass system uses dataclass-like syntax with the @fdtdx.autoinit decorator to automatically generate initialization methods. Here’s how it works:
@fdtdx.autoinit
class A(fdtdx.TreeClass):
    a: float = 2
    x: int = 5

@fdtdx.autoinit
class B(fdtdx.TreeClass):
    a1: A
    z: int = 7

@fdtdx.autoinit
class C(fdtdx.TreeClass):
    b_list: list[B]
    c: float = 2
These classes can be nested arbitrarily deeply and can contain lists, dictionaries, or other complex data structures. The @fdtdx.autoinit decorator automatically generates __init__ methods that handle default values and type checking.
Working with TreeClass Instances
# Create instances with default or custom values
b = B(a1=A()) # Uses defaults: A(a=2, x=5), z=7
print(b)
b = b.aset("z", 25) # Functional update
print(b)
B(a1=A(a=2, x=5), z=7)
B(a1=A(a=2, x=5), z=25)
# Collections of TreeClass instances
b2 = B(a1=A(a=10, x=11), z=12)
b3 = B(a1=A(a=20, x=21), z=22)
c = C(b_list=[b, b2])
print(c)
# Deep nested updates using path syntax
c2 = c.aset("b_list->[0]->a1->a", 100)
print(c2)
C(b_list=[B(a1=A(a=2, x=5), z=25), B(a1=A(a=10, x=11), z=12)], c=2)
C(b_list=[B(a1=A(a=100, x=5), z=25), B(a1=A(a=10, x=11), z=12)], c=2)
The aset Method: Functional Updates Made Easy
The aset method is the cornerstone of FDTDX's functional approach. JAX's .at[].set() only updates elements inside a single array; aset, by contrast, can replace any attribute at any level of nesting within a TreeClass hierarchy, addressed by a path such as "b_list->[0]->a1->a". Like .at[].set(), it returns a new instance instead of modifying the existing one.
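To see what aset saves you, here is the same nested update done in plain JAX on an ordinary pytree of dicts and lists that mirrors the C/B/A example. It is a sketch built on jax.tree_util.tree_map_with_path and jax.tree_util.keystr; set_leaf is a hypothetical helper, not part of FDTDX or JAX, and unlike aset it can only replace leaves, not whole sub-objects.

```python
import jax

# Plain dict/list pytree mirroring C(b_list=[B(...), B(...)], c=2)
tree = {
    "b_list": [
        {"a1": {"a": 2, "x": 5}, "z": 25},
        {"a1": {"a": 10, "x": 11}, "z": 12},
    ],
    "c": 2,
}

def set_leaf(tree, target, value):
    """Hypothetical helper: copy of `tree` with the leaf at key path `target` replaced."""
    return jax.tree_util.tree_map_with_path(
        lambda path, leaf: value if jax.tree_util.keystr(path) == target else leaf,
        tree,
    )

tree2 = set_leaf(tree, "['b_list'][0]['a1']['a']", 100)
print(tree2["b_list"][0]["a1"]["a"])  # 100
print(tree["b_list"][0]["a1"]["a"])   # 2 (original unchanged)
```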
How JAX is used in FDTDX
For a full example of how to use JAX with FDTDX, check out this example or this example. The script demonstrates FDTDX's integration with JAX's jit transformation. The core simulation function sim_fn takes FDTDX TreeClass structures as arguments and is JIT-compiled:
def sim_fn(
    params: fdtdx.ParameterContainer,
    arrays: fdtdx.ArrayContainer,
    key: jax.Array,
):
    # Complex FDTD simulation logic with TreeClass structures;
    # `objects` and `config` come from the enclosing script
    arrays, new_objects, info = fdtdx.apply_params(arrays, objects, params, key)
    final_state = fdtdx.run_fdtd(arrays=arrays, objects=new_objects, config=config, key=key)
    # ... more operations (elided here; they produce new_info)
    return arrays, new_info

jitted_loss = jax.jit(sim_fn, donate_argnames=["arrays"]).lower(params, arrays, key).compile()
JIT compilation with TreeClass arguments
Key Features:
TreeClass Compatibility: The ParameterContainer and ArrayContainer are FDTDX TreeClass structures, which are registered as JAX pytrees. jit therefore accepts these nested structures directly and flattens them into their leaves when tracing.
Memory Optimization: The donate_argnames=["arrays"] argument tells JAX that it may reuse the memory of the arrays argument for the outputs. This matters for the large electromagnetic field arrays in FDTD simulations; in exchange, the donated input must not be used after the call.
Compilation Pipeline: The script uses .lower().compile() to compile ahead of time, separating compilation from execution so that compile time can be measured on its own for performance analysis.
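The same pattern can be tried without FDTDX. In this minimal plain-JAX sketch, step is a hypothetical stand-in for one simulation update and the array size is arbitrary. It compiles ahead of time with .lower().compile(), times the compilation separately, and donates the fields argument. On backends without donation support, JAX may warn that the donated buffer could not be used; the result is still correct.

```python
import time

import jax
import jax.numpy as jnp

def step(fields, scale):
    # Hypothetical stand-in for an expensive update of a large field array
    return fields * scale + 1.0

fields = jnp.zeros((256, 256))
scale = jnp.asarray(0.5)

t0 = time.perf_counter()
compiled = jax.jit(step, donate_argnames=["fields"]).lower(fields, scale).compile()
print(f"compile time: {time.perf_counter() - t0:.3f} s")

# The compiled object only accepts inputs with the same shapes and dtypes used in lower()
new_fields = compiled(fields, scale)  # `fields` may be donated: do not reuse it afterwards
```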
While this specific example focuses on forward simulation, FDTDX is designed for gradient-based optimization. The GradientConfig setup configures how gradients are computed:
gradient_config = fdtdx.GradientConfig(
    recorder=fdtdx.Recorder(
        modules=[fdtdx.DtypeConversion(dtype=jnp.bfloat16)]
    )
)
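For gradient computation, you would typically reduce the simulation output to a scalar loss and differentiate it with jax.grad, since jax.grad requires a scalar-valued function:
# Hypothetical gradient computation; `objective` is a hypothetical scalar-valued loss
def loss_fn(params, arrays, key):
    arrays, info = sim_fn(params, arrays, key)
    return objective(arrays, info)
grad_fn = jax.grad(loss_fn, argnums=0)  # Gradient w.r.t. params
gradients = grad_fn(params, arrays, key)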
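Because sim_fn returns several values, the has_aux option of jax.grad is one way to differentiate a scalar loss while still getting diagnostics back. A self-contained plain-JAX sketch: simulate is a hypothetical stand-in for a simulation, and params is an ordinary dict pytree rather than an FDTDX ParameterContainer. The returned gradients have the same pytree structure as params.

```python
import jax
import jax.numpy as jnp

def simulate(params, x):
    # Hypothetical stand-in for a simulation: returns a field and diagnostics
    field = jnp.tanh(params["w"] * x + params["b"])
    info = {"max_field": jnp.max(field)}
    return field, info

def loss_fn(params, x):
    field, info = simulate(params, x)
    loss = jnp.mean(field ** 2)  # scalar objective, as jax.grad requires
    return loss, info            # info is passed through as auxiliary data

params = {"w": jnp.asarray(0.5), "b": jnp.asarray(0.1)}
x = jnp.linspace(-1.0, 1.0, 5)

grads, info = jax.grad(loss_fn, argnums=0, has_aux=True)(params, x)
print(grads)  # dict with keys "b" and "w", one scalar gradient each
```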