# Source code for fdtdx.objects.detectors.energy

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.core.physics.metrics import compute_energy
from fdtdx.objects.detectors.detector import Detector, DetectorState


[docs] @autoinit class EnergyDetector(Detector): """Detector for measuring electromagnetic energy distribution. This detector computes and records the electromagnetic energy density at specified points in the simulation volume. It can operate in different modes to either record full 3D data, 2D slices, or reduced volume measurements. """ #: If True, returns energy measurements as 2D slices through the volume. #: Defaults to False. as_slices: bool = frozen_field(default=False) #: If True, reduces the volume data to a single energy value. #: Defaults to False. reduce_volume: bool = frozen_field(default=False) #: real-world positions for slice extraction. #: Defaults to None. x_slice: float | None = frozen_field(default=None) #: real-world positions for slice extraction. #: Defaults to None. y_slice: float | None = frozen_field(default=None) #: real-world positions for slice extraction. #: Defaults to None. z_slice: float | None = frozen_field(default=None) #: If "mean", aggregates slices by averaging instead of using position. #: If None, mean is used. Defaults to None. aggregate: str | None = frozen_field(default=None) # e.g., "mean" def _slice_position_to_index(self, axis: int, real_pos: float | None, axis_len: int) -> int | jax.Array: """Map a requested physical slice position to a local energy index. Uniform grids keep the historical origin-plus-spacing conversion. On a rectilinear grid, slice positions are compared against cell centers in the placed detector interval. This avoids interpreting physical metres through a single global resolution and makes the selected slice stable under local grid stretching. Returns a plain Python int for uniform grids (safe as a static index under jit) and a 0-d JAX array for non-uniform grids (safe as a dynamic index under jit via JAX's dynamic indexing semantics). 
""" if real_pos is None: return axis_len // 2 grid = self._config.resolved_grid if grid is not None: start, stop = self.grid_slice_tuple[axis] centers = grid.centers(axis)[start:stop] idx = jnp.clip(jnp.argmin(jnp.abs(centers - real_pos)), 0, axis_len - 1) return idx spacing = self._config.uniform_spacing() origin = self.grid_slice[axis].start * spacing idx = int((real_pos - origin) / spacing) return max(0, min(idx, axis_len - 1)) def _shape_dtype_single_time_step( self, ) -> dict[str, jax.ShapeDtypeStruct]: if self.as_slices and self.reduce_volume: raise Exception("Cannot both reduce volume and save slices!") gs = self.grid_shape if self.as_slices: return { "XY Plane": jax.ShapeDtypeStruct((gs[0], gs[1]), self.dtype), "XZ Plane": jax.ShapeDtypeStruct((gs[0], gs[2]), self.dtype), "YZ Plane": jax.ShapeDtypeStruct((gs[1], gs[2]), self.dtype), } if self.reduce_volume: return {"energy": jax.ShapeDtypeStruct((1,), self.dtype)} return {"energy": jax.ShapeDtypeStruct(self.grid_shape, self.dtype)}
[docs] def update( self, time_step: jax.Array, E: jax.Array, H: jax.Array, state: DetectorState, inv_permittivity: jax.Array, inv_permeability: jax.Array | float, ) -> DetectorState: cur_E = E[:, *self.grid_slice] cur_H = H[:, *self.grid_slice] cur_inv_permittivity = inv_permittivity[:, *self.grid_slice] if isinstance(inv_permeability, jax.Array) and inv_permeability.ndim > 0: cur_inv_permeability = inv_permeability[:, *self.grid_slice] else: cur_inv_permeability = inv_permeability energy = compute_energy( E=cur_E, H=cur_H, inv_permittivity=cur_inv_permittivity, inv_permeability=cur_inv_permeability, ) arr_idx = self._time_step_to_arr_idx[time_step] if self.as_slices: use_mean = self.aggregate == "mean" or any( slice_ is None for slice_ in (self.x_slice, self.y_slice, self.z_slice) ) if use_mean: energy_xy = energy.mean(axis=2) energy_xz = energy.mean(axis=1) energy_yz = energy.mean(axis=0) else: x_idx = self._slice_position_to_index(0, self.x_slice, energy.shape[0]) y_idx = self._slice_position_to_index(1, self.y_slice, energy.shape[1]) z_idx = self._slice_position_to_index(2, self.z_slice, energy.shape[2]) energy_xy = energy[:, :, z_idx] energy_xz = energy[:, y_idx, :] energy_yz = energy[x_idx, :, :] new_xy = state["XY Plane"].at[arr_idx].set(energy_xy) new_xz = state["XZ Plane"].at[arr_idx].set(energy_xz) new_yz = state["YZ Plane"].at[arr_idx].set(energy_yz) return { "XY Plane": new_xy, "XZ Plane": new_xz, "YZ Plane": new_yz, } if self.reduce_volume: energy = energy * self._cell_volume_weights() total_energy = energy.sum() new_arr = state["energy"].at[arr_idx].set(total_energy) return {"energy": new_arr} new_arr = state["energy"].at[arr_idx].set(energy) return {"energy": new_arr}