Source code for fdtdx.config
import math
from typing import Literal
import jax
import jax.numpy as jnp
from loguru import logger
from fdtdx import constants
from fdtdx.core.grid import RectilinearGrid, UniformGrid
from fdtdx.core.jax.pytrees import TreeClass, autoinit, field, frozen_field
from fdtdx.interfaces.recorder import Recorder
from fdtdx.typing import BackendOption
[docs]
@autoinit
class GradientConfig(TreeClass):
    """Configuration for gradient computation in simulations.

    Selects how gradients are obtained: either time-reversible automatic
    differentiation, which needs a ``recorder``, or checkpointing-based
    differentiation, which needs ``num_checkpoints``.

    Raises:
        ValueError: If ``method`` is ``"reversible"`` and ``recorder`` is
            ``None``, or if ``method`` is ``"checkpointed"`` and
            ``num_checkpoints`` is ``None``.
    """

    #: Method for gradient computation.
    #: Can be either "reversible" when using the time reversible autodiff, or
    #: "checkpointed" for the exact checkpointing algorithm.
    method: Literal["reversible", "checkpointed"] = frozen_field(default="reversible")

    #: Optional recorder for invertible differentiation. Needs to be provided
    #: for reversible autodiff. Defaults to None.
    recorder: Recorder | None = field(default=None)

    #: Optional number of checkpoints for checkpointing-based differentiation.
    #: Needs to be provided for checkpointing gradient computation. Defaults to None.
    num_checkpoints: int | None = frozen_field(default=None)

    def __post_init__(self):
        """Check that the option required by the selected ``method`` was supplied."""
        # ValueError is a subclass of Exception, so handlers written as
        # ``except Exception`` still catch these configuration errors.
        if self.method == "reversible" and self.recorder is None:
            raise ValueError("Need Recorder in gradient config to compute reversible gradients")
        if self.method == "checkpointed" and self.num_checkpoints is None:
            raise ValueError("Need Checkpoint Number in gradient config to compute checkpointed gradients")
[docs]
@autoinit
class SimulationConfig(TreeClass):
    """Settings that describe a single FDTD simulation run.

    Bundles the simulated duration, the spatial grid, the numeric data type,
    the JAX backend selection and the optional gradient configuration, and
    exposes quantities derived from them (time step, step count, ...).
    """

    #: Total simulation time in seconds.
    time: float = frozen_field()

    #: Spatial grid configuration.
    #:
    #: ``UniformGrid`` is an unresolved policy used while the final volume shape
    #: is still being inferred. ``RectilinearGrid`` is the realized solver grid
    #: with explicit physical edge coordinates. Placement resolves policies to
    #: ``RectilinearGrid`` so compiled FDTD code has exactly one metric source.
    grid: UniformGrid | RectilinearGrid = field()

    #: Computation backend ('gpu', 'tpu', 'cpu' or 'METAL'). Defaults to "gpu".
    backend: BackendOption = frozen_field(default="gpu")

    #: Data type for numerical computations. Defaults to jnp.float32.
    dtype: jnp.dtype = frozen_field(default=jnp.float32)

    #: Whether to use complex-valued field arrays.
    #: None (default): auto-detect based on boundary conditions (e.g. Bloch).
    #: True: force complex fields (complex64 if dtype=float32, complex128 if dtype=float64).
    #: False: force real fields (raises error if Bloch boundaries are present).
    use_complex_fields: bool | None = frozen_field(default=None)

    #: Safety factor for the Courant condition (default: 0.99).
    courant_factor: float = frozen_field(default=0.99)

    #: Optional configuration for gradient computation.
    gradient_config: GradientConfig | None = field(default=None)

    def __post_init__(self):
        """Pick the JAX platform matching ``backend``, falling back to CPU.

        A requested ``"gpu"`` backend is switched to ``"METAL"`` when the
        currently active JAX platform reports ``"METAL"``. If the device
        lookup (or the platform update) raises ``RuntimeError``, ``backend``
        is reset to ``"cpu"``. Finally the ``jax_platform_name`` option is
        set to ``"cpu"`` whenever the CPU backend ends up selected.
        """
        from jax import extend

        # True only when this module itself is executed as a script; the
        # informational log messages below are emitted only in that case.
        emit_logs = __name__ == "__main__"

        active_platform = extend.backend.get_backend().platform
        if self.backend == "gpu" and active_platform == "METAL":
            self.backend = "METAL"

        if self.backend == "METAL":
            try:
                # Raises RuntimeError when no device can be initialised.
                jax.devices()
                if emit_logs:
                    logger.info("METAL device found and will be used for computations")
                jax.config.update("jax_platform_name", "metal")
            except RuntimeError:
                if emit_logs:
                    logger.warning("METAL initialization failed, falling back to CPU!")
                self.backend = "cpu"
        elif self.backend in ("gpu", "tpu"):
            try:
                # Probe for a device of the requested kind.
                jax.devices(self.backend)
                if emit_logs:
                    logger.info(f"{str.upper(self.backend)} found and will be used for computations")
                jax.config.update("jax_platform_name", self.backend)
            except RuntimeError:
                if emit_logs:
                    logger.warning(f"{str.upper(self.backend)} not found, falling back to CPU!")
                self.backend = "cpu"

        # Covers both an explicit CPU request and any fallback from above.
        if self.backend == "cpu":
            jax.config.update("jax_platform_name", "cpu")

    @property
    def courant_number(self) -> float:
        """Courant number used for unresolved uniform grids.

        Returns:
            float: ``courant_factor`` divided by ``sqrt(3)``, i.e. the safety
            factor normalized for three spatial dimensions.
        """
        three_d_normalization = math.sqrt(3)
        return self.courant_factor / three_d_normalization

    def resolve_grid(self, shape: tuple[int, int, int] | None = None) -> RectilinearGrid:
        """Return a concrete solver grid.

        Args:
            shape: Grid shape used to resolve the grid. Required when ``grid``
                is not already a ``RectilinearGrid``; ignored otherwise.

        Returns:
            The stored ``RectilinearGrid``, or the result of
            ``grid.resolve(shape)`` for an unresolved grid.

        Raises:
            ValueError: If the grid is unresolved and ``shape`` is ``None``.
        """
        grid = self.grid
        if isinstance(grid, RectilinearGrid):
            return grid
        if shape is None:
            raise ValueError("A grid shape is required to resolve UniformGrid.")
        return grid.resolve(shape)

    @property
    def resolved_grid(self) -> RectilinearGrid | None:
        """Return the concrete solver grid, or ``None`` if not yet resolved.

        The grid is returned only when it already is a ``RectilinearGrid``;
        no resolution is attempted here.
        """
        grid = self.grid
        return grid if isinstance(grid, RectilinearGrid) else None

    @property
    def has_nonuniform_grid(self) -> bool:
        """Whether a realized solver grid exists and reports itself as non-uniform."""
        realized = self.resolved_grid
        if realized is None:
            return False
        return not realized.is_uniform

    @property
    def time_step_duration(self) -> float:
        """Duration of a single time step.

        Returns:
            float: For a ``RectilinearGrid``, the value of
            ``grid.cfl_time_step(courant_factor)``. Otherwise
            ``courant_number * grid.spacing / constants.c``.
        """
        grid = self.grid
        if not isinstance(grid, RectilinearGrid):
            return self.courant_number * grid.spacing / constants.c
        return grid.cfl_time_step(self.courant_factor)

    @property
    def time_steps_total(self) -> int:
        """Total number of time steps for the simulation.

        Returns:
            int: ``time / time_step_duration`` rounded to the nearest integer.
        """
        exact_steps = self.time / self.time_step_duration
        return round(exact_steps)

    @property
    def max_travel_distance(self) -> float:
        """Distance light can travel during the simulated time.

        Returns:
            float: ``constants.c * time``.
        """
        distance = constants.c * self.time
        return distance

    @property
    def only_forward(self) -> bool:
        """Whether the simulation is forward-only (no gradient computation).

        Returns:
            bool: True if no gradient configuration is specified, False otherwise.
        """
        grad_cfg = self.gradient_config
        return grad_cfg is None

    @property
    def invertible_optimization(self) -> bool:
        """Whether invertible optimization is enabled.

        Returns:
            bool: True if a gradient configuration is present and it carries a
            recorder, False otherwise.
        """
        grad_cfg = self.gradient_config
        return grad_cfg is not None and grad_cfg.recorder is not None
# Module-level placeholder configuration built with a negative simulation time
# and a uniform grid of spacing 1.
# NOTE(review): presumably a stand-in for objects that need a config before the
# real one is known — confirm against usages.
DUMMY_SIMULATION_CONFIG = SimulationConfig(time=-1, grid=UniformGrid(spacing=1))