from abc import ABC, abstractmethod
from typing import Self
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib.figure import Figure
from rich.progress import Progress
from fdtdx.colors import XKCD_LIGHT_GREEN, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, frozen_private_field, private_field
from fdtdx.core.switch import OnOffSwitch
from fdtdx.objects.detectors.plotting.line_plot import plot_line_over_time, plot_waterfall_over_time
from fdtdx.objects.detectors.plotting.plot2d import plot_2d_from_slices
from fdtdx.objects.detectors.plotting.video import generate_video_from_slices, plot_from_slices
from fdtdx.objects.object import SimulationObject
from fdtdx.typing import SliceTuple3D
# State carried by a detector during a simulation: maps each recorded quantity name
# (the keys returned by ``_shape_dtype_single_time_step``) to its array, whose leading
# axis has length ``_num_latent_time_steps()`` (see ``Detector.init_state``).
DetectorState = dict[str, jax.Array]
@autoinit
class Detector(SimulationObject, ABC):
    """Base class for electromagnetic field detectors in FDTD simulations.

    This class provides core functionality for recording and analyzing electromagnetic field data
    during FDTD simulations. It supports flexible timing control, data collection intervals,
    and visualization of results.
    """

    #: Data type for detector arrays, defaults to float32.
    dtype: jnp.dtype = frozen_field(default=jnp.float32)
    #: Whether to use exact field interpolation. Defaults to True.
    exact_interpolation: bool = frozen_field(default=True)
    #: Whether to record fields in inverse time order. Defaults to false.
    inverse: bool = frozen_field(default=False)
    #: This switch controls the time steps that the detector is on, i.e. records data.
    #: Defaults to all time steps.
    switch: OnOffSwitch = frozen_field(default=OnOffSwitch())
    #: Whether to generate plots of recorded data. Defaults to true.
    plot: bool = frozen_field(default=True)
    #: Plot inverse data in reverse time order.
    if_inverse_plot_backwards: bool = frozen_field(default=True)
    #: Number of workers for video generation. If None (default), then no
    #: multiprocessing is used. Note that the combination of multiprocessing and matplotlib is known to produce
    #: problems and can cause the entire system to freeze. It does make the video generation much faster though.
    num_video_workers: int | None = frozen_field(default=None)  # only used when generating video
    #: RGB color for plotting. Defaults to light green.
    color: Color | None = frozen_field(default=XKCD_LIGHT_GREEN)
    #: Interpolation method for plots. Defaults to "gaussian".
    plot_interpolation: str = frozen_field(default="gaussian")
    #: DPI resolution for plots. Defaults to None.
    plot_dpi: int | None = frozen_field(default=None)
    # Number of time steps at which the switch turns the detector on; set in ``place_on_grid``.
    _num_time_steps_on: int = frozen_private_field()
    # Boolean mask over all simulation time steps (True = recording); set in ``place_on_grid``.
    _is_on_at_time_step_arr: jax.Array = private_field()
    # Maps a simulation time step to its row in the recording arrays, -1 when off.
    _time_step_to_arr_idx: jax.Array = private_field()
    # Physical cell volumes of this detector's grid slice; set in ``place_on_grid``.
    _cached_cell_volume_weights: jax.Array = private_field()

    @property
    def num_time_steps_recorded(self) -> int:
        """Gets the total number of time steps that will be recorded.

        Returns:
            int: Number of time steps where detector will record data.

        Raises:
            Exception: If detector is not yet initialized.
        """
        if self._num_time_steps_on is None:
            raise Exception("Detector is not yet initialized")
        return self._num_time_steps_on

    def _cell_volume_weights(self) -> jax.Array:
        """Return physical cell-volume weights for this detector's grid slice."""
        return self._cached_cell_volume_weights

    def _volume_weighted_spatial_mean(self, values: jax.Array, leading_dims: int) -> jax.Array:
        """Average spatial detector samples using physical cell volumes.

        Args:
            values: Array whose final three dimensions match ``grid_shape``.
            leading_dims: Number of leading non-spatial dimensions to preserve,
                such as component or frequency axes.

        Returns:
            ``values`` averaged over the three spatial dimensions.
        """
        weights = self._cell_volume_weights()
        # Prepend singleton axes so the weights broadcast over the leading dimensions.
        weight_shape = (1,) * leading_dims + weights.shape
        spatial_axes = tuple(range(leading_dims, values.ndim))
        return jnp.sum(values * weights.reshape(weight_shape), axis=spatial_axes) / jnp.sum(weights)

    def _plot_axis_centers_um(self, axis: int) -> np.ndarray:
        """Return detector-local cell centers in micrometres for plotting.

        Rectilinear grids use the physical cell centers from ``RectilinearGrid``. The
        coordinates are shifted so plots start at the detector slice origin,
        matching the historical uniform-grid display convention.
        """
        grid = self._config.resolved_grid
        if grid is not None:
            start, stop = self.grid_slice_tuple[axis]
            edges = np.asarray(grid.edges(axis)[start : stop + 1])
            centers = 0.5 * (edges[:-1] + edges[1:])
            return (centers - edges[0]) / 1.0e-6
        spacing = self._config.uniform_spacing()
        return (np.arange(self.grid_shape[axis]) + 0.5) * spacing / 1.0e-6

    def _plot_axis_edges_um(self, axis: int) -> np.ndarray:
        """Return detector-local cell edges in micrometres for slice plots."""
        grid = self._config.resolved_grid
        if grid is not None:
            start, stop = self.grid_slice_tuple[axis]
            edges = np.asarray(grid.edges(axis)[start : stop + 1])
            return (edges - edges[0]) / 1.0e-6
        spacing = self._config.uniform_spacing()
        return np.arange(self.grid_shape[axis] + 1) * spacing / 1.0e-6

    def _plot_resolutions(self) -> tuple[float, float, float]:
        """Return a scalar-resolution tuple for legacy plotting APIs.

        When a non-uniform ``RectilinearGrid`` is present this value is only a fallback
        for call signatures; rectilinear plots receive explicit edge arrays and
        do not use the scalar spacing to position cells.
        """
        if self._config.has_nonuniform_grid:
            assert self._config.resolved_grid is not None
            spacing = self._config.resolved_grid.min_spacing
            return (spacing, spacing, spacing)
        spacing = self._config.uniform_spacing()
        return (spacing, spacing, spacing)

    def _plot_coordinate_edges_um(self) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
        """Return rectilinear detector edge coordinates when needed for plots."""
        if not self._config.has_nonuniform_grid:
            return None
        return (
            self._plot_axis_edges_um(0),
            self._plot_axis_edges_um(1),
            self._plot_axis_edges_um(2),
        )

    def _calculate_on_list(
        self,
    ) -> list[bool]:
        """Evaluate ``switch`` for every simulation time step.

        Returns:
            list[bool]: One flag per time step, True where the detector records.
        """
        return self.switch.calculate_on_list(
            num_total_time_steps=self._config.time_steps_total,
            time_step_duration=self._config.time_step_duration,
        )

    def _num_latent_time_steps(self) -> int:
        """Calculates total number of time steps that will be recorded.

        Returns:
            int: Number of time steps where detector will be active.
        """
        on_list = self._calculate_on_list()
        return sum(on_list)

    @abstractmethod
    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
        """Gets shape and dtype information for a single time step recording.

        Returns:
            dict[str, jax.ShapeDtypeStruct]: Dictionary mapping field names to their
                shape and dtype specifications.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the detector on the grid and precompute its recording bookkeeping.

        Besides the base-class placement, this stores the per-time-step on/off mask,
        the number of recorded steps, the time-step -> array-row mapping and the
        cell-volume weights of the detector's grid slice.

        Args:
            grid_slice_tuple (SliceTuple3D): Grid index ranges occupied by the detector.
            config (SimulationConfig): Simulation configuration.
            key (jax.Array): Random key forwarded to the base-class placement.

        Returns:
            Self: Updated copy of the detector.
        """
        self = super().place_on_grid(
            grid_slice_tuple=grid_slice_tuple,
            config=config,
            key=key,
        )
        # determine number of time steps on
        on_list = self._calculate_on_list()
        on_arr = jnp.asarray(on_list, dtype=jnp.bool)
        self = self.aset("_is_on_at_time_step_arr", on_arr, create_new_ok=True)
        self = self.aset("_num_time_steps_on", sum(on_list), create_new_ok=True)

        # Map time step -> array row (-1 when off). The inclusive running count of
        # active steps minus one is exactly the row index of every active step.
        num_t = self._config.time_steps_total
        on_np = np.asarray(on_list[:num_t], dtype=bool)
        time_to_arr_idx = np.where(on_np, np.cumsum(on_np) - 1, -1)
        time_to_arr_idx_arr = jnp.asarray(time_to_arr_idx, dtype=jnp.int32)
        self = self.aset("_time_step_to_arr_idx", time_to_arr_idx_arr, create_new_ok=True)

        # compute cell volume weights now that config and grid_slice are available
        grid = self._config.resolved_grid
        if grid is not None:
            weights = grid.cell_volume(self.grid_slice_tuple)
        else:
            spacing = self._config.uniform_spacing()
            weights = jnp.ones(self.grid_shape, dtype=self.dtype) * spacing**3
        self = self.aset("_cached_cell_volume_weights", weights, create_new_ok=True)
        return self

    def init_state(
        self: Self,
    ) -> DetectorState:
        """Allocate zero-initialized recording arrays.

        Returns:
            DetectorState: One array per entry of ``_shape_dtype_single_time_step``,
                with an extra leading axis of length ``_num_latent_time_steps()``.
        """
        shape_dtype_dict = self._shape_dtype_single_time_step()
        latent_time_size = self._num_latent_time_steps()
        return {
            name: jnp.zeros(
                shape=(latent_time_size, *shape_dtype.shape),
                dtype=shape_dtype.dtype,
            )
            for name, shape_dtype in shape_dtype_dict.items()
        }

    @abstractmethod
    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        """Updates detector state with current field values.

        Args:
            time_step (jax.Array): Current simulation time step.
            E (jax.Array): Electric field array.
            H (jax.Array): Magnetic field array.
            state (DetectorState): Current detector state.
            inv_permittivity (jax.Array): Inverse permittivity array.
            inv_permeability (jax.Array | float): Inverse permeability array.

        Returns:
            DetectorState: Updated detector state.
        """
        del (
            time_step,
            E,
            H,
            state,
            inv_permittivity,
            inv_permeability,
        )
        raise NotImplementedError()

    def _recorded_time_axis(self) -> np.ndarray:
        """Return the indices of all recorded time steps scaled by ``time_step_duration``."""
        on_indices = np.flatnonzero(np.asarray(self._is_on_at_time_step_arr))
        return on_indices * self._config.time_step_duration

    def _single_active_axis(self) -> int:
        """Return the only grid axis along which the detector spans more than one cell.

        Raises:
            Exception: If zero or several axes span more than one cell.
        """
        active_axes = [axis for axis, size in enumerate(self.grid_shape) if size > 1]
        if len(active_axes) != 1:
            raise Exception(f"Cannot infer one spatial plotting axis for grid shape {self.grid_shape}")
        return active_axes[0]

    def draw_plot(
        self,
        state: dict[str, np.ndarray],
        progress: Progress | None = None,
    ) -> dict[str, Figure | str]:
        """Generates plots or videos from recorded detector data.

        Creates visualizations based on dimensionality of recorded data and detector
        configuration. Supports 1D line plots, 2D heatmaps, and video generation
        for time-varying data.

        Args:
            state (dict[str, np.ndarray]): Dictionary containing recorded field data arrays.
            progress (Progress | None, optional): Optional progress bar for video generation.

        Returns:
            dict[str, Figure | str]: Dictionary mapping plot names to either
                matplotlib Figure objects or paths to generated video files.

        Raises:
            Exception: If the state is empty, its arrays differ in dimensionality, or the
                data layout has no supported visualization.
        """
        # Reverse along the leading axis only when several steps were recorded, since only
        # then does that axis survive squeezing as the time axis.
        reverse_time = self.inverse and self.if_inverse_plot_backwards and self.num_time_steps_recorded > 1
        squeezed_arrs: dict[str, np.ndarray] = {}
        squeezed_ndim = None
        for k, v in state.items():
            v_squeezed = v.squeeze()
            if reverse_time:
                squeezed_arrs[k] = np.asarray(v_squeezed[::-1, ...])
            else:
                squeezed_arrs[k] = np.asarray(v_squeezed)
            if squeezed_ndim is None:
                squeezed_ndim = v_squeezed.ndim
            elif v_squeezed.ndim != squeezed_ndim:
                raise Exception("Cannot plot multiple arrays with different ndim")
        if squeezed_ndim is None:
            raise Exception(f"empty state: {state}")

        num_steps = self.num_time_steps_recorded
        has_planes = all(key in squeezed_arrs for key in ("XY Plane", "XZ Plane", "YZ Plane"))
        figs: dict[str, Figure | str] = {}

        if squeezed_ndim == 1 and num_steps > 1:
            # scalar value per time step - line plot over time
            time_steps = self._recorded_time_axis().tolist()
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(arr=v, time_steps=time_steps, metric_name=f"{self.name}: {k}")
        elif squeezed_ndim == 1 and num_steps == 1:
            # single time step along a 1D detector - line plot over space
            axis = self._single_active_axis()
            xlabel = ("X axis (μm)", "Y axis (μm)", "Z axis (μm)")[axis]
            spatial_axis = self._plot_axis_centers_um(axis)
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(
                    arr=v, time_steps=spatial_axis, metric_name=f"{self.name}: {k}", xlabel=xlabel
                )
        elif squeezed_ndim == 2 and num_steps > 1:
            # multiple time steps, 1d spatial data - visualize as 2D waterfall plot
            time_steps = self._recorded_time_axis()
            spatial_points = self._plot_axis_centers_um(self._single_active_axis())
            for k, v in squeezed_arrs.items():
                # Keep time as the first dimension. After squeezing, the second axis only
                # fails to exceed one cell in degenerate (size-0) cases.
                if v.shape[1] <= 1:
                    v = v.T
                figs[k] = plot_waterfall_over_time(
                    arr=v,
                    time_steps=time_steps,
                    spatial_steps=spatial_points,
                    metric_name=f"{self.name}: {k}",
                    spatial_unit="μm",
                )
        elif squeezed_ndim == 2 and num_steps == 1:
            # single time step, 2d-plot
            if not has_planes:
                raise Exception(f"Cannot plot {squeezed_arrs.keys()}")
            figs["sliced_plot"] = plot_2d_from_slices(
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                resolutions=self._plot_resolutions(),
                coordinate_edges_um=self._plot_coordinate_edges_um(),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
        elif squeezed_ndim == 3 and num_steps > 1:
            # multiple time steps, 2d-plots
            if not has_planes:
                raise Exception(
                    f"Cannot plot {squeezed_arrs.keys()}. "
                    f"Consider setting plot=False for Object {self.name} ({self.__class__=})"
                )
            figs["sliced_video"] = generate_video_from_slices(
                plt_fn=plot_from_slices,
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                progress=progress,
                num_worker=self.num_video_workers,
                resolutions=self._plot_resolutions(),
                coordinate_edges_um=self._plot_coordinate_edges_um(),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
        elif squeezed_ndim == 3 and num_steps == 1:
            # single step, 3d volume - plot the mean over each axis as slices
            for k, v in squeezed_arrs.items():
                figs[k] = plot_2d_from_slices(
                    xy_slice=v.mean(axis=2),
                    xz_slice=v.mean(axis=1),
                    yz_slice=v.mean(axis=0),
                    resolutions=self._plot_resolutions(),
                    coordinate_edges_um=self._plot_coordinate_edges_um(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
        elif squeezed_ndim == 4 and num_steps > 1:
            # video with 3d-volume in each time step. plot as slices
            for k, v in squeezed_arrs.items():
                figs[k] = generate_video_from_slices(
                    plt_fn=plot_from_slices,
                    xy_slice=v.mean(axis=3),
                    xz_slice=v.mean(axis=2),
                    yz_slice=v.mean(axis=1),
                    progress=progress,
                    num_worker=self.num_video_workers,
                    resolutions=self._plot_resolutions(),
                    coordinate_edges_um=self._plot_coordinate_edges_um(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
        elif squeezed_ndim <= 3:
            # e.g. scalar data at a single step, or zero recorded steps
            raise Exception(
                f"Cannot plot {squeezed_ndim}-dimensional data of detector {self.name} "
                f"with {num_steps} recorded time step(s)"
            )
        else:
            raise Exception("Cannot plot detector with more than three dimensions")
        return figs