# Source code for fdtdx.objects.detectors.detector

from abc import ABC, abstractmethod
from typing import Self

import jax
import jax.numpy as jnp
import numpy as np
from matplotlib.figure import Figure
from rich.progress import Progress

from fdtdx.colors import XKCD_LIGHT_GREEN, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, frozen_private_field, private_field
from fdtdx.core.switch import OnOffSwitch
from fdtdx.objects.detectors.plotting.line_plot import plot_line_over_time, plot_waterfall_over_time
from fdtdx.objects.detectors.plotting.plot2d import plot_2d_from_slices
from fdtdx.objects.detectors.plotting.video import generate_video_from_slices, plot_from_slices
from fdtdx.objects.object import SimulationObject
from fdtdx.typing import SliceTuple3D

# State carried by a detector: maps the name of a recorded quantity to the
# array holding its recorded values. ``Detector.init_state`` allocates each
# array with a leading axis sized by the number of recorded time steps.
DetectorState = dict[str, jax.Array]


@autoinit
class Detector(SimulationObject, ABC):
    """Base class for electromagnetic field detectors in FDTD simulations.

    This class provides core functionality for recording and analyzing electromagnetic field data
    during FDTD simulations. It supports flexible timing control, data collection intervals,
    and visualization of results.

    Subclasses must implement ``_shape_dtype_single_time_step`` and ``update``.
    """

    #: Data type for detector arrays, defaults to float32.
    dtype: jnp.dtype = frozen_field(default=jnp.float32)

    #: Whether to use exact field interpolation. Defaults to True.
    exact_interpolation: bool = frozen_field(default=True)

    #: Whether to record fields in inverse time order. Defaults to false.
    inverse: bool = frozen_field(default=False)

    #: This switch controls the time steps that the detector is on, i.e. records data.
    #: Defaults to all time steps.
    switch: OnOffSwitch = frozen_field(default=OnOffSwitch())

    #: Whether to generate plots of recorded data. Defaults to true.
    plot: bool = frozen_field(default=True)

    #: Plot inverse data in reverse time order.
    if_inverse_plot_backwards: bool = frozen_field(default=True)

    #: Number of workers for video generation. If None (default), then no
    #: multiprocessing is used. Note that the combination of multiprocessing and matplotlib is known to produce
    #: problems and can cause the entire system to freeze. It does make the video generation much faster though.
    num_video_workers: int | None = frozen_field(default=None)  # only used when generating video

    #: RGB color for plotting. Defaults to light green.
    color: Color | None = frozen_field(default=XKCD_LIGHT_GREEN)

    #: Interpolation method for plots. Defaults to "gaussian".
    plot_interpolation: str = frozen_field(default="gaussian")

    #: DPI resolution for plots. Defaults to None.
    plot_dpi: int | None = frozen_field(default=None)

    # Private state, populated by ``place_on_grid``.
    _num_time_steps_on: int = frozen_private_field()
    _is_on_at_time_step_arr: jax.Array = private_field()
    _time_step_to_arr_idx: jax.Array = private_field()
    _cached_cell_volume_weights: jax.Array = private_field()

    @property
    def num_time_steps_recorded(self) -> int:
        """Gets the total number of time steps that will be recorded.

        Returns:
            int: Number of time steps where detector will record data.

        Raises:
            Exception: If detector is not yet initialized.
        """
        if self._num_time_steps_on is None:
            raise Exception("Detector is not yet initialized")
        return self._num_time_steps_on

    def _cell_volume_weights(self) -> jax.Array:
        """Return physical cell-volume weights for this detector's grid slice.

        The weights are the ones cached by ``place_on_grid``.
        """
        return self._cached_cell_volume_weights

    def _volume_weighted_spatial_mean(self, values: jax.Array, leading_dims: int) -> jax.Array:
        """Average spatial detector samples using physical cell volumes.

        Args:
            values: Array whose final three dimensions match ``grid_shape``.
            leading_dims: Number of leading non-spatial dimensions to preserve,
                such as component or frequency axes.

        Returns:
            ``values`` averaged over the three spatial dimensions.
        """
        weights = self._cell_volume_weights()
        # Prepend singleton axes so the weights broadcast over the leading dims.
        broadcast_shape = (1,) * leading_dims + weights.shape
        spatial_axes = tuple(range(leading_dims, values.ndim))
        weighted_sum = jnp.sum(values * weights.reshape(broadcast_shape), axis=spatial_axes)
        return weighted_sum / jnp.sum(weights)

    def _plot_axis_centers_um(self, axis: int) -> np.ndarray:
        """Return detector-local cell centers in micrometres for plotting.

        Rectilinear grids use the physical cell centers from ``RectilinearGrid``.
        The coordinates are shifted so plots start at the detector slice origin,
        matching the historical uniform-grid display convention.

        Args:
            axis: Index of the spatial axis (0, 1 or 2).

        Returns:
            One coordinate per detector cell along ``axis``.
        """
        grid = self._config.resolved_grid
        if grid is None:
            # Uniform grid: centers sit half a cell after each cell's lower edge.
            spacing = self._config.uniform_spacing()
            return (np.arange(self.grid_shape[axis]) + 0.5) * spacing / 1.0e-6

        start, stop = self.grid_slice_tuple[axis]
        edges = np.asarray(grid.edges(axis)[start : stop + 1])
        centers = 0.5 * (edges[:-1] + edges[1:])
        return (centers - edges[0]) / 1.0e-6

    def _plot_axis_edges_um(self, axis: int) -> np.ndarray:
        """Return detector-local cell edges in micrometres for slice plots.

        Args:
            axis: Index of the spatial axis (0, 1 or 2).

        Returns:
            Edge coordinates shifted so the first edge is at zero; one more
            entry than there are detector cells along ``axis``.
        """
        grid = self._config.resolved_grid
        if grid is None:
            spacing = self._config.uniform_spacing()
            return np.arange(self.grid_shape[axis] + 1) * spacing / 1.0e-6

        start, stop = self.grid_slice_tuple[axis]
        edges = np.asarray(grid.edges(axis)[start : stop + 1])
        return (edges - edges[0]) / 1.0e-6

    def _plot_resolutions(self) -> tuple[float, float, float]:
        """Return a scalar-resolution tuple for legacy plotting APIs.

        When a non-uniform ``RectilinearGrid`` is present this value is only a
        fallback for call signatures; rectilinear plots receive explicit edge
        arrays and do not use the scalar spacing to position cells.
        """
        if self._config.has_nonuniform_grid:
            assert self._config.resolved_grid is not None
            spacing = self._config.resolved_grid.min_spacing
        else:
            spacing = self._config.uniform_spacing()
        return (spacing, spacing, spacing)

    def _plot_coordinate_edges_um(self) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
        """Return rectilinear detector edge coordinates when needed for plots.

        Returns:
            ``None`` when the grid is not non-uniform, otherwise the edge
            coordinates of the three spatial axes.
        """
        if not self._config.has_nonuniform_grid:
            return None
        return (
            self._plot_axis_edges_um(0),
            self._plot_axis_edges_um(1),
            self._plot_axis_edges_um(2),
        )

    def _calculate_on_list(
        self,
    ) -> list[bool]:
        """Ask the switch which simulation time steps the detector records.

        Returns:
            list[bool]: The on/off list computed by ``switch`` from the total
            number of time steps and the time step duration of the config.
        """
        return self.switch.calculate_on_list(
            num_total_time_steps=self._config.time_steps_total,
            time_step_duration=self._config.time_step_duration,
        )

    def _num_latent_time_steps(self) -> int:
        """Calculates total number of time steps that will be recorded.

        Returns:
            int: Number of time steps where detector will be active.
        """
        return sum(self._calculate_on_list())

    @abstractmethod
    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
        """Gets shape and dtype information for a single time step recording.

        Returns:
            dict[str, jax.ShapeDtypeStruct]: Dictionary mapping field names to their
            shape and dtype specifications.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the detector on the grid and precompute its bookkeeping.

        After delegating to the parent implementation, this stores the on/off
        mask over time steps, the number of recorded time steps, the mapping
        from time step to recording index and the cell-volume weights.

        Args:
            grid_slice_tuple (SliceTuple3D): Grid slice occupied by the detector.
            config (SimulationConfig): Simulation configuration.
            key (jax.Array): Random key forwarded to the parent implementation.

        Returns:
            Self: Updated detector with the private fields set.
        """
        self = super().place_on_grid(
            grid_slice_tuple=grid_slice_tuple,
            config=config,
            key=key,
        )

        # determine number of time steps on
        on_list = self._calculate_on_list()
        self = self.aset(
            "_is_on_at_time_step_arr",
            jnp.asarray(on_list, dtype=jnp.bool),
            create_new_ok=True,
        )
        self = self.aset("_num_time_steps_on", sum(on_list), create_new_ok=True)

        # calculate mapping time step -> arr index; -1 marks steps that are not recorded
        num_t = self._config.time_steps_total
        time_to_arr_idx_list = [-1] * num_t
        next_arr_idx = 0
        for t in range(num_t):
            if on_list[t]:
                time_to_arr_idx_list[t] = next_arr_idx
                next_arr_idx += 1
        self = self.aset(
            "_time_step_to_arr_idx",
            jnp.asarray(time_to_arr_idx_list, dtype=jnp.int32),
            create_new_ok=True,
        )

        # compute cell volume weights now that config and grid_slice are available
        grid = self._config.resolved_grid
        if grid is not None:
            weights = grid.cell_volume(self.grid_slice_tuple)
        else:
            spacing = self._config.uniform_spacing()
            weights = jnp.ones(self.grid_shape, dtype=self.dtype) * spacing**3
        self = self.aset("_cached_cell_volume_weights", weights, create_new_ok=True)
        return self

    def init_state(
        self: Self,
    ) -> DetectorState:
        """Allocate zero-filled arrays for everything the detector records.

        Returns:
            DetectorState: One array per entry of
            ``_shape_dtype_single_time_step``, with a leading axis whose size is
            the number of recorded time steps.
        """
        shape_dtype_dict = self._shape_dtype_single_time_step()
        latent_time_size = self._num_latent_time_steps()
        return {
            name: jnp.zeros(
                shape=(latent_time_size, *shape_dtype.shape),
                dtype=shape_dtype.dtype,
            )
            for name, shape_dtype in shape_dtype_dict.items()
        }

    @abstractmethod
    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        """Updates detector state with current field values.

        Args:
            time_step (jax.Array): Current simulation time step.
            E (jax.Array): Electric field array.
            H (jax.Array): Magnetic field array.
            state (DetectorState): Current detector state.
            inv_permittivity (jax.Array): Inverse permittivity array.
            inv_permeability (jax.Array | float): Inverse permeability array.

        Returns:
            DetectorState: Updated detector state.
        """
        del (
            time_step,
            E,
            H,
            state,
            inv_permittivity,
            inv_permeability,
        )
        raise NotImplementedError()

    def _recorded_time_points(self) -> np.ndarray:
        """Return the recorded time steps multiplied by the time step duration."""
        on_steps = np.where(np.asarray(self._is_on_at_time_step_arr))[0]
        return on_steps * self._config.time_step_duration

    def _single_active_axis(self) -> int | None:
        """Return the only axis whose grid size exceeds one, or ``None`` if not exactly one."""
        active_axes = [axis for axis, size in enumerate(self.grid_shape) if size > 1]
        if len(active_axes) != 1:
            return None
        return active_axes[0]

    def _slice_plot_kwargs(self) -> dict:
        """Collect the keyword arguments shared by all slice-based plot calls."""
        return {
            "resolutions": self._plot_resolutions(),
            "coordinate_edges_um": self._plot_coordinate_edges_um(),
            "plot_dpi": self.plot_dpi,
            "plot_interpolation": self.plot_interpolation,
        }

    def draw_plot(
        self,
        state: dict[str, np.ndarray],
        progress: Progress | None = None,
    ) -> dict[str, Figure | str]:
        """Generates plots or videos from recorded detector data.

        Creates visualizations based on dimensionality of recorded data and detector
        configuration. Supports 1D line plots, 2D heatmaps, and video generation for
        time-varying data. Every array is squeezed first; the plot type then follows
        from the squeezed number of dimensions and the number of recorded time steps:

        * 1 dim, several steps: line plot over time.
        * 1 dim, one step: line plot over the single spatial axis.
        * 2 dims, several steps: waterfall plot (time versus one spatial axis).
        * 2 dims, one step: 2D plot, requires the three plane entries.
        * 3 dims, several steps: video, requires the three plane entries.
        * 3 dims, one step: 2D plot of the means along each spatial axis.
        * 4 dims, several steps: video of the means along each spatial axis.

        Args:
            state (dict[str, np.ndarray]): Dictionary containing recorded field data arrays.
            progress (Progress | None, optional): Optional progress bar for video generation.

        Returns:
            dict[str, Figure | str]: Dictionary mapping plot names to either
            matplotlib Figure objects or paths to generated video files.

        Raises:
            Exception: If ``state`` is empty, if the arrays differ in their squeezed
                number of dimensions, if required plane entries are missing, or if
                the data cannot be mapped to one of the supported plot types.
            AssertionError: If a single-step 1D recording does not correspond to
                exactly one spatial axis with more than one cell.
        """
        squeezed_arrs: dict[str, np.ndarray] = {}
        squeezed_ndim: int | None = None
        for k, v in state.items():
            v_squeezed = v.squeeze()
            if self.inverse and self.if_inverse_plot_backwards and self.num_time_steps_recorded > 1:
                # recorded in inverse time order: flip the time axis for display
                squeezed_arrs[k] = np.asarray(v_squeezed[::-1, ...])
            else:
                squeezed_arrs[k] = np.asarray(v_squeezed)
            if squeezed_ndim is None:
                squeezed_ndim = v_squeezed.ndim
            elif v_squeezed.ndim != squeezed_ndim:
                raise Exception("Cannot plot multiple arrays with different ndim")
        if squeezed_ndim is None:
            raise Exception(f"empty state: {state}")

        num_recorded = self.num_time_steps_recorded
        multi_step = num_recorded > 1
        single_step = num_recorded == 1
        plane_keys = ("XY Plane", "XZ Plane", "YZ Plane")
        has_all_planes = all(key in squeezed_arrs for key in plane_keys)

        figs: dict[str, Figure | str] = {}
        if squeezed_ndim == 1 and multi_step:
            # do line plot
            time_steps = self._recorded_time_points().tolist()
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(arr=v, time_steps=time_steps, metric_name=f"{self.name}: {k}")
        elif squeezed_ndim == 1 and single_step:
            axis = self._single_active_axis()
            if axis is None:
                # Raised explicitly (instead of via ``assert``) so the check
                # is not stripped when Python runs with -O.
                raise AssertionError("This should never happen")
            xlabel = f"{'XYZ'[axis]} axis (μm)"
            spatial_axis = self._plot_axis_centers_um(axis)
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(
                    arr=v, time_steps=spatial_axis, metric_name=f"{self.name}: {k}", xlabel=xlabel
                )
        elif squeezed_ndim == 2 and multi_step:
            # multiple time steps, 1d spatial data - visualize as 2D waterfall plot
            time_steps = self._recorded_time_points()
            axis = self._single_active_axis()
            if axis is None:
                raise Exception(f"Cannot infer one spatial plotting axis for grid shape {self.grid_shape}")
            spatial_points = self._plot_axis_centers_um(axis)
            for k, v in squeezed_arrs.items():
                if v.shape[1] <= 1:
                    # Transpose if needed so time is always first dimension
                    v = v.T
                figs[k] = plot_waterfall_over_time(
                    arr=v,
                    time_steps=time_steps,
                    spatial_steps=spatial_points,
                    metric_name=f"{self.name}: {k}",
                    spatial_unit="μm",
                )
        elif squeezed_ndim == 2 and single_step:
            # single time step, 2d-plot
            if not has_all_planes:
                raise Exception(f"Cannot plot {squeezed_arrs.keys()}")
            figs["sliced_plot"] = plot_2d_from_slices(
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                **self._slice_plot_kwargs(),
            )
        elif squeezed_ndim == 3 and multi_step:
            # multiple time steps, 2d-plots
            if not has_all_planes:
                raise Exception(
                    f"Cannot plot {squeezed_arrs.keys()}. "
                    f"Consider setting plot=False for Object {self.name} ({self.__class__=})"
                )
            figs["sliced_video"] = generate_video_from_slices(
                plt_fn=plot_from_slices,
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                progress=progress,
                num_worker=self.num_video_workers,
                **self._slice_plot_kwargs(),
            )
        elif squeezed_ndim == 3 and single_step:
            # single step, 3d-plot: show the mean along each spatial axis
            for k, volume in squeezed_arrs.items():
                figs[k] = plot_2d_from_slices(
                    xy_slice=volume.mean(axis=2),
                    xz_slice=volume.mean(axis=1),
                    yz_slice=volume.mean(axis=0),
                    **self._slice_plot_kwargs(),
                )
        elif squeezed_ndim == 4 and multi_step:
            # video with 3d-volume in each time step. plot as slices
            # (axis 0 is time, so the spatial axes are 1, 2 and 3)
            for k, volume in squeezed_arrs.items():
                figs[k] = generate_video_from_slices(
                    plt_fn=plot_from_slices,
                    xy_slice=volume.mean(axis=3),
                    xz_slice=volume.mean(axis=2),
                    yz_slice=volume.mean(axis=1),
                    progress=progress,
                    num_worker=self.num_video_workers,
                    **self._slice_plot_kwargs(),
                )
        else:
            # Also reached for other unsupported combinations (e.g. scalar data
            # with one time step), so report what was actually seen.
            raise Exception(
                "Cannot plot detector with more than three dimensions "
                f"(squeezed ndim={squeezed_ndim}, recorded time steps={num_recorded})"
            )
        return figs