Source code for fdtdx.objects.detectors.energy

import jax

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.core.physics.metrics import compute_energy
from fdtdx.objects.detectors.detector import Detector, DetectorState


[docs] @autoinit class EnergyDetector(Detector): """Detector for measuring electromagnetic energy distribution. This detector computes and records the electromagnetic energy density at specified points in the simulation volume. It can operate in different modes to either record full 3D data, 2D slices, or reduced volume measurements. """ #: If True, returns energy measurements as 2D slices through the volume. #: Defaults to False. as_slices: bool = frozen_field(default=False) #: If True, reduces the volume data to a single energy value. #: Defaults to False. reduce_volume: bool = frozen_field(default=False) #: real-world positions for slice extraction. #: Defaults to None. x_slice: float | None = frozen_field(default=None) #: real-world positions for slice extraction. #: Defaults to None. y_slice: float | None = frozen_field(default=None) #: real-world positions for slice extraction. #: Defaults to None. z_slice: float | None = frozen_field(default=None) #: If "mean", aggregates slices by averaging instead of using position. #: If None, mean is used. Defaults to None. aggregate: str | None = frozen_field(default=None) # e.g., "mean" def _shape_dtype_single_time_step( self, ) -> dict[str, jax.ShapeDtypeStruct]: if self.as_slices and self.reduce_volume: raise Exception("Cannot both reduce volume and save slices!") gs = self.grid_shape if self.as_slices: return { "XY Plane": jax.ShapeDtypeStruct((gs[0], gs[1]), self.dtype), "XZ Plane": jax.ShapeDtypeStruct((gs[0], gs[2]), self.dtype), "YZ Plane": jax.ShapeDtypeStruct((gs[1], gs[2]), self.dtype), } if self.reduce_volume: return {"energy": jax.ShapeDtypeStruct((1,), self.dtype)} return {"energy": jax.ShapeDtypeStruct(self.grid_shape, self.dtype)}
[docs] def update( self, time_step: jax.Array, E: jax.Array, H: jax.Array, state: DetectorState, inv_permittivity: jax.Array, inv_permeability: jax.Array | float, ) -> DetectorState: cur_E = E[:, *self.grid_slice] cur_H = H[:, *self.grid_slice] cur_inv_permittivity = inv_permittivity[:, *self.grid_slice] if isinstance(inv_permeability, jax.Array) and inv_permeability.ndim > 0: cur_inv_permeability = inv_permeability[:, *self.grid_slice] else: cur_inv_permeability = inv_permeability energy = compute_energy( E=cur_E, H=cur_H, inv_permittivity=cur_inv_permittivity, inv_permeability=cur_inv_permeability, ) arr_idx = self._time_step_to_arr_idx[time_step] if self.as_slices: use_mean = self.aggregate == "mean" or any( slice_ is None for slice_ in (self.x_slice, self.y_slice, self.z_slice) ) if use_mean: energy_xy = energy.mean(axis=2) energy_xz = energy.mean(axis=1) energy_yz = energy.mean(axis=0) else: # Convert real-world positions to indices origin_x = self.grid_slice[0].start * self._config.resolution origin_y = self.grid_slice[1].start * self._config.resolution origin_z = self.grid_slice[2].start * self._config.resolution def to_index(real_pos, origin, axis_len): if real_pos is not None: idx = int((real_pos - origin) / self._config.resolution) return max(0, min(idx, axis_len - 1)) return axis_len // 2 x_idx = to_index(self.x_slice, origin_x, energy.shape[0]) y_idx = to_index(self.y_slice, origin_y, energy.shape[1]) z_idx = to_index(self.z_slice, origin_z, energy.shape[2]) energy_xy = energy[:, :, z_idx] energy_xz = energy[:, y_idx, :] energy_yz = energy[x_idx, :, :] new_xy = state["XY Plane"].at[arr_idx].set(energy_xy) new_xz = state["XZ Plane"].at[arr_idx].set(energy_xz) new_yz = state["YZ Plane"].at[arr_idx].set(energy_yz) return { "XY Plane": new_xy, "XZ Plane": new_xz, "YZ Plane": new_yz, } if self.reduce_volume: total_energy = energy.sum() new_arr = state["energy"].at[arr_idx].set(total_energy) return {"energy": new_arr} new_arr = state["energy"].at[arr_idx].set(energy) return {"energy": new_arr}