Source code for fdtdx.fdtd.container

"""Container module for managing collections of simulation objects and arrays.

This module provides container classes for organizing and managing simulation objects
and array data within FDTD simulations. It includes support for different object types
like sources, detectors, PML boundaries, Bloch/periodic boundaries, and devices.
"""

from typing import Callable, Self

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import TreeClass, autoinit, frozen_field
from fdtdx.interfaces.state import RecordingState
from fdtdx.materials import Material
from fdtdx.objects.boundaries.bloch import BlochBoundary
from fdtdx.objects.boundaries.boundary import BaseBoundary
from fdtdx.objects.boundaries.pec import PerfectElectricConductor
from fdtdx.objects.boundaries.perfectly_matched_layer import PerfectlyMatchedLayer
from fdtdx.objects.boundaries.pmc import PerfectMagneticConductor
from fdtdx.objects.detectors.detector import Detector, DetectorState
from fdtdx.objects.device.device import Device
from fdtdx.objects.object import SimulationObject
from fdtdx.objects.sources.source import Source
from fdtdx.objects.static_material.static import StaticMultiMaterialObject, UniformMaterialObject

# Type alias for parameter dictionaries containing JAX arrays.
# Each value is either a single array or a nested ``str -> jax.Array`` dict.
# NOTE(review): presumably the outer key is an object name and the inner key a
# parameter name — confirm against the code that builds these containers.
ParameterContainer = dict[str, dict[str, jax.Array] | jax.Array]


@autoinit
class ObjectContainer(TreeClass):
    """Container for managing simulation objects and their relationships.

    This class provides a structured way to organize and access different types of simulation objects
    like sources, detectors, PML/periodic boundaries and devices. It maintains object lists and provides
    filtered access to specific object types.
    """

    #: List of all simulation objects in the container.
    object_list: list[SimulationObject]

    #: Index of the volume object in the object list.
    volume_idx: int = frozen_field()

    @property
    def volume(self) -> SimulationObject:
        """The volume object, i.e. ``object_list[volume_idx]``."""
        return self.object_list[self.volume_idx]

    @property
    def objects(self) -> list[SimulationObject]:
        """All objects in the container (the underlying ``object_list``)."""
        return self.object_list

    @property
    def static_material_objects(self) -> list[UniformMaterialObject | StaticMultiMaterialObject]:
        """Objects that are UniformMaterialObject or StaticMultiMaterialObject instances."""
        return [o for o in self.objects if isinstance(o, (UniformMaterialObject, StaticMultiMaterialObject))]

    @property
    def sources(self) -> list[Source]:
        """All Source objects, in list order."""
        return [o for o in self.objects if isinstance(o, Source)]

    @property
    def devices(self) -> list[Device]:
        """All Device objects, in list order."""
        return [o for o in self.objects if isinstance(o, Device)]

    @property
    def detectors(self) -> list[Detector]:
        """All Detector objects, in list order."""
        return [o for o in self.objects if isinstance(o, Detector)]

    @property
    def forward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is falsy."""
        return [o for o in self.detectors if not o.inverse]

    @property
    def backward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is truthy."""
        return [o for o in self.detectors if o.inverse]

    @property
    def pml_objects(self) -> list[PerfectlyMatchedLayer]:
        """All PerfectlyMatchedLayer boundaries."""
        return [o for o in self.objects if isinstance(o, PerfectlyMatchedLayer)]

    @property
    def periodic_objects(self) -> list[BlochBoundary]:
        """Bloch boundaries whose ``needs_complex_fields`` is falsy."""
        return [o for o in self.objects if isinstance(o, BlochBoundary) and not o.needs_complex_fields]

    @property
    def pec_objects(self) -> list[PerfectElectricConductor]:
        """All PerfectElectricConductor boundaries."""
        return [o for o in self.objects if isinstance(o, PerfectElectricConductor)]

    @property
    def pmc_objects(self) -> list[PerfectMagneticConductor]:
        """All PerfectMagneticConductor boundaries."""
        return [o for o in self.objects if isinstance(o, PerfectMagneticConductor)]

    @property
    def bloch_objects(self) -> list[BlochBoundary]:
        """All Bloch boundaries, regardless of ``needs_complex_fields``."""
        return [o for o in self.objects if isinstance(o, BlochBoundary)]

    @property
    def boundary_objects(self) -> list[BaseBoundary]:
        """All objects deriving from BaseBoundary."""
        return [o for o in self.objects if isinstance(o, BaseBoundary)]

    # --- Material-wide predicates -------------------------------------------------
    # Each property is True iff the predicate holds for every Material attached to a
    # UniformMaterialObject, Device or StaticMultiMaterialObject (vacuously True if none).

    @property
    def all_objects_non_magnetic(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: not m.is_magnetic)

    @property
    def all_objects_non_electrically_conductive(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: not m.is_electrically_conductive)

    @property
    def all_objects_non_magnetically_conductive(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: not m.is_magnetically_conductive)

    @property
    def all_objects_isotropic_permittivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_isotropic_permittivity)

    @property
    def all_objects_isotropic_permeability(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_isotropic_permeability)

    @property
    def all_objects_isotropic_electric_conductivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_isotropic_electric_conductivity)

    @property
    def all_objects_isotropic_magnetic_conductivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_isotropic_magnetic_conductivity)

    @property
    def all_objects_diagonally_anisotropic_permittivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_diagonally_anisotropic_permittivity)

    @property
    def all_objects_diagonally_anisotropic_permeability(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_diagonally_anisotropic_permeability)

    @property
    def all_objects_diagonally_anisotropic_electric_conductivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_diagonally_anisotropic_electric_conductivity)

    @property
    def all_objects_diagonally_anisotropic_magnetic_conductivity(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: m.is_diagonally_anisotropic_magnetic_conductivity)

    @property
    def all_objects_non_dispersive(self) -> bool:
        return self._is_material_fn_true_for_all(lambda m: not m.is_dispersive)

    @property
    def max_num_dispersive_poles(self) -> int:
        """Maximum number of dispersive poles required across all objects.

        Walks every object (UniformMaterialObject, Device, StaticMultiMaterialObject) and returns the
        largest pole count of any Material attached to them, or 0 if no material is dispersive. Drives
        the leading dimension of the per-cell dispersive coefficient and polarization arrays, which are
        zero-padded for materials with fewer poles.
        """
        return max(
            (m.dispersion.num_poles for m in self._attached_materials() if m.dispersion is not None),
            default=0,
        )

    def _attached_materials(self) -> list[Material]:
        """Collect every Material attached to a material-bearing object, in object order.

        UniformMaterialObject contributes its single ``material``; Device and StaticMultiMaterialObject
        contribute all values of their ``materials`` dict. Other objects are skipped.
        """
        collected: list[Material] = []
        for o in self.objects:
            if isinstance(o, UniformMaterialObject):
                m = o.material
            elif isinstance(o, (Device, StaticMultiMaterialObject)):
                m = o.materials
            else:
                continue
            if isinstance(m, Material):
                collected.append(m)
            elif isinstance(m, dict):
                collected.extend(m.values())
        return collected

    def _is_material_fn_true_for_all(
        self,
        fn: Callable[[Material], bool],
    ) -> bool:
        """Return True iff ``fn`` is truthy for every attached Material (True when there are none)."""
        return all(fn(m) for m in self._attached_materials())

    def __iter__(self):
        return iter(self.object_list)

    def __getitem__(
        self,
        key: str,
    ) -> SimulationObject:
        """Return the first object whose ``name`` equals ``key``.

        Raises:
            ValueError: If no object has that name.
        """
        for o in self.objects:
            if o.name == key:
                return o
        raise ValueError(f"Key {key} does not exist in object list: {[o.name for o in self.objects]}")

    def __contains__(
        self,
        key: str,
    ) -> bool:
        """Return True if any object's ``name`` equals ``key``."""
        return any(o.name == key for o in self.objects)

    def __setitem__(
        self,
        key: str,
        val: SimulationObject,
    ):
        """Replace the object named ``key`` with ``val``.

        Note: this mutates ``object_list`` in place rather than returning a new container.

        Raises:
            ValueError: If no object has that name.
        """
        idx = self.index(key)
        self.object_list[idx] = val

    def index(self, name: str) -> int:
        """Return the position in ``object_list`` of the first object named ``name``.

        Raises:
            ValueError: If no object has that name.
        """
        for idx, o in enumerate(self.object_list):
            if o.name == name:
                return idx
        raise ValueError(f"Object '{name}' does not exist in object list: {[o.name for o in self.objects]}")

    def copy(
        self,
    ) -> "ObjectContainer":
        """Return a new container with a shallow copy of ``object_list`` (objects are shared)."""
        new_list = self.object_list.copy()
        return ObjectContainer(
            object_list=new_list,
            volume_idx=self.volume_idx,
        )

    def replace_sources(
        self,
        sources: list[Source],
    ) -> Self:
        """Return a copy in which all Source objects are replaced by ``sources``.

        Non-source objects keep their relative order and the new sources are appended at the end.
        Because removing sources shifts the positions of the objects after them, ``volume_idx`` is
        remapped so that ``volume`` still refers to the same object afterwards.

        Args:
            sources: The sources to install in place of the current ones.

        Returns:
            A new container with the updated object list.
        """
        kept_indices = [i for i, o in enumerate(self.object_list) if not isinstance(o, Source)]
        new_objects = [self.object_list[i] for i in kept_indices] + list(sources)
        result = self.aset("object_list", new_objects)
        # Sources that preceded the volume were removed, so its position may have shifted left.
        if self.volume_idx in kept_indices:
            new_volume_idx = kept_indices.index(self.volume_idx)
            if new_volume_idx != self.volume_idx:
                result = result.aset("volume_idx", new_volume_idx)
        return result
@autoinit
class FieldState(TreeClass):
    """Dynamic electromagnetic field state that evolves each time step.

    Grouping these together makes it impossible to forget a field when resetting simulation state —
    ArrayContainer.reset() zeroes this entire struct at once (every leaf is replaced by
    ``jnp.zeros_like`` of itself).
    """

    # NOTE(review): field declaration order presumably defines the positional order of the
    # ``@autoinit``-generated constructor, so do not reorder these. Array shapes are not fixed here —
    # confirm against the initialization code.

    #: Electric field array.
    E: jax.Array
    #: Magnetic field array.
    H: jax.Array
    #: PML auxiliary electric field.
    psi_E: jax.Array
    #: PML auxiliary magnetic field.
    psi_H: jax.Array
@autoinit
class ArrayContainer(TreeClass):
    """Container for simulation field arrays and states.

    This class holds the electromagnetic field arrays and various state information needed during FDTD
    simulation. It includes the E and H fields, material properties, and states for boundaries,
    detectors and recordings.
    """

    #: Dynamic electromagnetic fields (E, H and PML auxiliaries).
    fields: FieldState

    #: Alpha array for PML calculations.
    alpha: jax.Array

    #: Kappa array for PML calculations.
    kappa: jax.Array

    #: Sigma array for PML calculations.
    sigma: jax.Array

    #: Inverse permittivity values array.
    inv_permittivities: jax.Array

    #: Inverse permeability values array.
    inv_permeabilities: jax.Array | float

    #: Dictionary mapping detector names to their states.
    detector_states: dict[str, DetectorState]

    #: Optional state for recording simulation data.
    recording_state: RecordingState | None

    #: Field for electric conductivity terms. Defaults to None.
    electric_conductivity: jax.Array | None = None

    #: Field for magnetic conductivity terms. Defaults to None.
    magnetic_conductivity: jax.Array | None = None

    #: Dispersive ADE polarization state at time step ``n``. Shape
    #: ``(num_poles, 3, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_P_curr: jax.Array | None = None

    #: Dispersive ADE polarization state at time step ``n-1``. Shape
    #: ``(num_poles, 3, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_P_prev: jax.Array | None = None

    #: Per-cell dispersive recurrence coefficient c1. Shape
    #: ``(num_poles, 1, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_c1: jax.Array | None = None

    #: Per-cell dispersive recurrence coefficient c2. Shape
    #: ``(num_poles, 1, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_c2: jax.Array | None = None

    #: Per-cell dispersive recurrence coefficient c3. Shape
    #: ``(num_poles, 1, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_c3: jax.Array | None = None

    #: Per-cell cached ``1 / c2`` with non-dispersive cells set to 0. Lets the
    #: reverse-time ADE update avoid a ``jnp.where`` + division per step.
    #: Derived from ``dispersive_c2``; never differentiated independently.
    #: Shape ``(num_poles, 1, Nx, Ny, Nz)``. ``None`` for non-dispersive simulations.
    dispersive_inv_c2: jax.Array | None = None

    def reset(
        self,
        reset_detector_states: bool = True,
        reset_recording_state: bool = False,
    ) -> "ArrayContainer":
        """Return a reset copy of this array container.

        Dynamic field arrays are zeroed while material arrays and conductivity arrays are preserved.
        Detector states are reset by default because they accumulate time-dependent measurements.
        Recording state is preserved by default so partial simulations can continue writing to the
        same buffers.

        All zeroing uses ``jnp.zeros_like``. This keeps each array's shape and dtype and yields exact
        zeros even when the old values contain ``inf``/``NaN``; multiplying by zero would produce
        ``NaN`` there and promote boolean arrays to integers.

        Args:
            reset_detector_states: Whether to zero all detector state arrays. Defaults to True.
            reset_recording_state: Whether to zero recording data and state arrays when a
                recording state is present. Defaults to False.

        Returns:
            A new ArrayContainer with reset dynamic state.
        """
        arrays = self.aset("fields", jax.tree.map(jnp.zeros_like, self.fields))

        # Dispersive ADE polarization is also dynamic per-timestep state and must
        # be zeroed alongside E/H. Coefficient arrays (c1/c2/c3/inv_c2) are
        # material properties and are preserved.
        if arrays.dispersive_P_curr is not None:
            arrays = arrays.aset("dispersive_P_curr", jnp.zeros_like(arrays.dispersive_P_curr))
        if arrays.dispersive_P_prev is not None:
            arrays = arrays.aset("dispersive_P_prev", jnp.zeros_like(arrays.dispersive_P_prev))

        if reset_detector_states:
            zeroed_detector_states = {
                name: {key: jnp.zeros_like(value) for key, value in state.items()}
                for name, state in self.detector_states.items()
            }
            arrays = arrays.aset("detector_states", zeroed_detector_states)

        if reset_recording_state and self.recording_state is not None:
            zeroed_recording_state = RecordingState(
                data={k: jnp.zeros_like(v) for k, v in self.recording_state.data.items()},
                state={k: jnp.zeros_like(v) for k, v in self.recording_state.state.items()},
            )
            arrays = arrays.aset("recording_state", zeroed_recording_state)

        return arrays
# time step and arrays SimulationState = tuple[jax.Array, ArrayContainer]