from abc import ABC, abstractmethod
from typing import Self
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib.figure import Figure
from rich.progress import Progress
from fdtdx.colors import XKCD_LIGHT_GREEN, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, frozen_private_field, private_field
from fdtdx.core.switch import OnOffSwitch
from fdtdx.objects.detectors.plotting.line_plot import plot_line_over_time, plot_waterfall_over_time
from fdtdx.objects.detectors.plotting.plot2d import plot_2d_from_slices
from fdtdx.objects.detectors.plotting.video import generate_video_from_slices, plot_from_slices
from fdtdx.objects.object import SimulationObject
from fdtdx.typing import SliceTuple3D
#: Recorded detector data: maps each recorded quantity's name to its array of recorded values.
DetectorState = dict[str, jax.Array]
@autoinit
class Detector(SimulationObject, ABC):
    """Base class for electromagnetic field detectors in FDTD simulations.

    This class provides core functionality for recording and analyzing electromagnetic field data
    during FDTD simulations. It supports flexible timing control, data collection intervals,
    and visualization of results.

    Subclasses implement ``_shape_dtype_single_time_step`` (layout of the data recorded per time
    step) and ``update`` (how the fields of a time step are written into the detector state).
    """

    #: Data type for detector arrays, defaults to float32.
    dtype: jnp.dtype = frozen_field(default=jnp.float32)
    #: Whether to use exact field interpolation. Defaults to True.
    exact_interpolation: bool = frozen_field(default=True)
    #: Whether to record fields in inverse time order. Defaults to false.
    inverse: bool = frozen_field(default=False)
    #: This switch controls the time steps that the detector is on, i.e. records data.
    #: Defaults to all time steps.
    switch: OnOffSwitch = frozen_field(default=OnOffSwitch())
    #: Whether to generate plots of recorded data. Defaults to true.
    plot: bool = frozen_field(default=True)
    #: Plot inverse data in reverse time order.
    if_inverse_plot_backwards: bool = frozen_field(default=True)
    #: Number of workers for video generation. If None (default), then no
    #: multiprocessing is used. Note that the combination of multiprocessing and matplotlib is known to produce
    #: problems and can cause the entire system to freeze. It does make the video generation much faster though.
    num_video_workers: int | None = frozen_field(default=None)  # only used when generating video
    #: RGB color for plotting. Defaults to light green.
    color: Color | None = frozen_field(default=XKCD_LIGHT_GREEN)
    #: Interpolation method for plots. Defaults to "gaussian".
    plot_interpolation: str = frozen_field(default="gaussian")
    #: DPI resolution for plots. Defaults to None.
    plot_dpi: int | None = frozen_field(default=None)

    # Set in place_on_grid: number of time steps at which the switch is on.
    _num_time_steps_on: int = frozen_private_field()
    # Set in place_on_grid: boolean array, True at time steps where the detector records.
    _is_on_at_time_step_arr: jax.Array = private_field()
    # Set in place_on_grid: int32 array mapping time step -> index into recorded arrays (-1 if off).
    _time_step_to_arr_idx: jax.Array = private_field()

    @property
    def num_time_steps_recorded(self) -> int:
        """Gets the total number of time steps that will be recorded.

        Returns:
            int: Number of time steps where detector will record data.

        Raises:
            Exception: If detector is not yet initialized.
        """
        if self._num_time_steps_on is None:
            raise Exception("Detector is not yet initialized")
        return self._num_time_steps_on

    def _calculate_on_list(
        self,
    ) -> list[bool]:
        """Evaluates ``switch`` for the configured simulation.

        Returns:
            list[bool]: On/off flag per time step, as computed by ``OnOffSwitch.calculate_on_list``
            from ``time_steps_total`` and ``time_step_duration`` of the simulation config.
        """
        return self.switch.calculate_on_list(
            num_total_time_steps=self._config.time_steps_total,
            time_step_duration=self._config.time_step_duration,
        )

    def _num_latent_time_steps(self) -> int:
        """Calculates total number of time steps that will be recorded.

        ``init_state`` uses this value as the size of the leading (time) axis of every state array.

        Returns:
            int: Number of time steps where detector will be active.
        """
        return sum(self._calculate_on_list())

    @abstractmethod
    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
        """Gets shape and dtype information for a single time step recording.

        Returns:
            dict[str, jax.ShapeDtypeStruct]: Dictionary mapping field names to their
                shape and dtype specifications.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Places the detector on the grid and precomputes its recording schedule.

        After the base-class placement, the switch is evaluated and three private fields are set:
        ``_is_on_at_time_step_arr`` (bool per time step), ``_num_time_steps_on`` (count of on-steps)
        and ``_time_step_to_arr_idx`` (int32 index of each on-step in the recorded arrays, -1 when off).

        Args:
            grid_slice_tuple (SliceTuple3D): Grid slices occupied by the detector.
            config (SimulationConfig): Simulation configuration.
            key (jax.Array): Random key forwarded to the base-class placement.

        Returns:
            Self: Updated detector instance.
        """
        self = super().place_on_grid(
            grid_slice_tuple=grid_slice_tuple,
            config=config,
            key=key,
        )

        # determine number of time steps on
        on_list = self._calculate_on_list()
        on_arr = jnp.asarray(on_list, dtype=jnp.bool)
        self = self.aset("_is_on_at_time_step_arr", on_arr, create_new_ok=True)
        self = self.aset("_num_time_steps_on", sum(on_list), create_new_ok=True)

        # Mapping time step -> array index: the k-th "on" step is stored at index k, "off" steps get -1.
        # NOTE(review): assumes on_list has exactly one entry per simulation time step (time_steps_total).
        on_np = np.asarray(on_list, dtype=bool)
        time_to_arr_idx = np.where(on_np, np.cumsum(on_np) - 1, -1)
        time_to_arr_idx_arr = jnp.asarray(time_to_arr_idx, dtype=jnp.int32)
        self = self.aset("_time_step_to_arr_idx", time_to_arr_idx_arr, create_new_ok=True)
        return self

    def init_state(
        self: Self,
    ) -> DetectorState:
        """Allocates zero-initialized recording arrays.

        Returns:
            DetectorState: One array per entry of ``_shape_dtype_single_time_step()``, with shape
                ``(self._num_latent_time_steps(), *per_step_shape)`` and the per-step dtype.
        """
        shape_dtype_dict = self._shape_dtype_single_time_step()
        latent_time_size = self._num_latent_time_steps()
        return {
            name: jnp.zeros(
                shape=(latent_time_size, *shape_dtype.shape),
                dtype=shape_dtype.dtype,
            )
            for name, shape_dtype in shape_dtype_dict.items()
        }

    @abstractmethod
    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        """Updates detector state with current field values.

        Args:
            time_step (jax.Array): Current simulation time step.
            E (jax.Array): Electric field array.
            H (jax.Array): Magnetic field array.
            state (DetectorState): Current detector state.
            inv_permittivity (jax.Array): Inverse permittivity array.
            inv_permeability (jax.Array | float): Inverse permeability array.

        Returns:
            DetectorState: Updated detector state.

        Raises:
            NotImplementedError: Must be implemented by subclasses.
        """
        del (
            time_step,
            E,
            H,
            state,
            inv_permittivity,
            inv_permeability,
        )
        raise NotImplementedError()

    def _recorded_times(self) -> np.ndarray:
        """Returns the simulation time of each recorded step (step index * config.time_step_duration)."""
        on_steps = np.where(np.asarray(self._is_on_at_time_step_arr))[0]
        return on_steps * self._config.time_step_duration

    def _spatial_axis_um(self, num_points: int) -> np.ndarray:
        """Returns positions of ``num_points`` consecutive grid points in micrometers.

        NOTE(review): assumes ``config.resolution`` is the grid spacing in meters.
        """
        return np.arange(num_points) * (self._config.resolution * 1e6)

    def _plot_resolutions(self) -> tuple[float, float, float]:
        """Returns the per-axis resolution tuple handed to the slice plotting helpers."""
        res = self._config.resolution
        return (res, res, res)

    def _line_axis_label(self) -> str | None:
        """Returns the x-label for a detector spanning exactly one grid axis, else None."""
        extended = [size > 1 for size in self.grid_shape[:3]]
        if sum(extended) != 1:
            return None
        return ("X axis (μm)", "Y axis (μm)", "Z axis (μm)")[extended.index(True)]

    def draw_plot(
        self,
        state: dict[str, np.ndarray],
        progress: Progress | None = None,
    ) -> dict[str, Figure | str]:
        """Generates plots or videos from recorded detector data.

        Creates visualizations based on dimensionality of recorded data and detector
        configuration. Supports 1D line plots, 2D heatmaps, waterfall plots and video
        generation for time-varying data. Arrays are squeezed first, so singleton axes
        do not count towards the dimensionality.

        Args:
            state (dict[str, np.ndarray]): Dictionary containing recorded field data arrays.
            progress (Progress | None, optional): Optional progress bar for video generation.

        Returns:
            dict[str, Figure | str]: Dictionary mapping plot names to either
                matplotlib Figure objects or paths to generated video files.

        Raises:
            Exception: If the state is empty, arrays differ in ndim, or the data layout
                cannot be plotted.
        """
        num_recorded = self.num_time_steps_recorded
        multi_step = num_recorded > 1
        single_step = num_recorded == 1
        # With more than one recorded step the time axis survives squeezing as axis 0.
        reverse_time = self.inverse and self.if_inverse_plot_backwards and multi_step
        plane_names = ("XY Plane", "XZ Plane", "YZ Plane")

        squeezed_arrs: dict[str, np.ndarray] = {}
        squeezed_ndim: int | None = None
        for k, v in state.items():
            v_squeezed = v.squeeze()
            squeezed_arrs[k] = np.asarray(v_squeezed[::-1, ...] if reverse_time else v_squeezed)
            if squeezed_ndim is None:
                squeezed_ndim = v_squeezed.ndim
            elif v_squeezed.ndim != squeezed_ndim:
                raise Exception("Cannot plot multiple arrays with different ndim")
        if squeezed_ndim is None:
            raise Exception(f"empty state: {state}")

        figs: dict[str, Figure | str] = {}
        if squeezed_ndim == 1 and multi_step:
            # one value per recorded time step: line plot over simulation time
            time_steps = self._recorded_times().tolist()
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(arr=v, time_steps=time_steps, metric_name=f"{self.name}: {k}")
        elif squeezed_ndim == 1 and single_step:
            # single time step of a line detector: plot values along its spatial axis
            xlabel = self._line_axis_label()
            if xlabel is None:
                raise Exception(
                    f"Cannot plot 1D data of {self.name}: grid shape {self.grid_shape} does not extend "
                    f"along exactly one axis. Consider setting plot=False for Object {self.name} "
                    f"({self.__class__=})"
                )
            for k, v in squeezed_arrs.items():
                figs[k] = plot_line_over_time(
                    arr=v,
                    time_steps=self._spatial_axis_um(len(v)),
                    metric_name=f"{self.name}: {k}",
                    xlabel=xlabel,
                )
        elif squeezed_ndim == 2 and multi_step:
            # multiple time steps of 1D spatial data (time on axis 0): waterfall plot
            time_steps = self._recorded_times()
            for k, v in squeezed_arrs.items():
                figs[k] = plot_waterfall_over_time(
                    arr=v,
                    time_steps=time_steps,
                    spatial_steps=self._spatial_axis_um(v.shape[1]),
                    metric_name=f"{self.name}: {k}",
                    spatial_unit="μm",
                )
        elif squeezed_ndim == 2 and single_step:
            # single time step, 2d-plot of the three orthogonal planes
            if all(name in squeezed_arrs for name in plane_names):
                figs["sliced_plot"] = plot_2d_from_slices(
                    xy_slice=squeezed_arrs["XY Plane"],
                    xz_slice=squeezed_arrs["XZ Plane"],
                    yz_slice=squeezed_arrs["YZ Plane"],
                    resolutions=self._plot_resolutions(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
            else:
                raise Exception(f"Cannot plot {squeezed_arrs.keys()}")
        elif squeezed_ndim == 3 and multi_step:
            # multiple time steps of the three orthogonal planes: video
            if all(name in squeezed_arrs for name in plane_names):
                figs["sliced_video"] = generate_video_from_slices(
                    plt_fn=plot_from_slices,
                    xy_slice=squeezed_arrs["XY Plane"],
                    xz_slice=squeezed_arrs["XZ Plane"],
                    yz_slice=squeezed_arrs["YZ Plane"],
                    progress=progress,
                    num_worker=self.num_video_workers,
                    resolutions=self._plot_resolutions(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
            else:
                raise Exception(
                    f"Cannot plot {squeezed_arrs.keys()}. "
                    f"Consider setting plot=False for Object {self.name} ({self.__class__=})"
                )
        elif squeezed_ndim == 3 and single_step:
            # single step 3d volume: plot the mean over each axis as three planes
            for k, v in squeezed_arrs.items():
                figs[k] = plot_2d_from_slices(
                    xy_slice=v.mean(axis=2),
                    xz_slice=v.mean(axis=1),
                    yz_slice=v.mean(axis=0),
                    resolutions=self._plot_resolutions(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
        elif squeezed_ndim == 4 and multi_step:
            # 3d volume per time step (time on axis 0): video of per-axis means
            for k, v in squeezed_arrs.items():
                figs[k] = generate_video_from_slices(
                    plt_fn=plot_from_slices,
                    xy_slice=v.mean(axis=3),
                    xz_slice=v.mean(axis=2),
                    yz_slice=v.mean(axis=1),
                    progress=progress,
                    num_worker=self.num_video_workers,
                    resolutions=self._plot_resolutions(),
                    plot_dpi=self.plot_dpi,
                    plot_interpolation=self.plot_interpolation,
                )
        else:
            raise Exception(
                f"Cannot plot detector data with {squeezed_ndim} dimensions after squeezing "
                f"({num_recorded} time steps recorded). "
                f"Consider setting plot=False for Object {self.name} ({self.__class__=})"
            )
        return figs