Source code for fdtdx.objects.detectors.poynting_flux

from typing import Literal

import jax

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.core.physics.metrics import compute_poynting_flux
from fdtdx.objects.detectors.detector import Detector, DetectorState


[docs] @autoinit class PoyntingFluxDetector(Detector): """Detector for measuring Poynting flux in electromagnetic simulations. This detector computes the Poynting flux (power flow) through a specified surface in the simulation volume. It can measure flux in either positive or negative direction along the propagation axis, and optionally reduce measurements to a single value by summing over the detection surface. """ #: Direction of flux measurement, either "+" for positive or "-" for negative along the propagation axis. direction: Literal["+", "-"] = frozen_field() #: If True, reduces measurements to a single value by summing over the detection surface. #: If False, maintains spatial distribution. Defaults to True. reduce_volume: bool = frozen_field(default=True) #: By default, the propagation axis for calculating the poynting #: flux is the axis, where the detector has a grid shape of 1. If the detector has a shape of 1 in more than #: one axes or a different axis should be used, then this attribute can/has to be set. Defaults to None. fixed_propagation_axis: int | None = frozen_field(default=None) #: By default, only the poynting flux component for the propagation axis #: is returned (scalar). If true, all three vector components are returned. Defaults to False. keep_all_components: bool = frozen_field(default=False) @property def propagation_axis(self) -> int: """Determines the axis along which Poynting flux is measured. The propagation axis is identified as the dimension with size 1 in the detector's grid shape, representing a plane perpendicular to the flux measurement direction. Returns: int: Index of the propagation axis (0 for x, 1 for y, 2 for z) Raises: Exception: If detector shape does not have exactly one dimension of size 1 """ if self.fixed_propagation_axis is not None: if self.fixed_propagation_axis not in [0, 1, 2]: raise Exception(f"Invalid: {self.fixed_propagation_axis=}") return self.fixed_propagation_axis if sum([a == 1 for a in self.grid_shape]) != 1: raise Exception(f"Invalid poynting flux detector shape: {self.grid_shape}") return self.grid_shape.index(1) def _shape_dtype_single_time_step( self, ) -> dict[str, jax.ShapeDtypeStruct]: if self.keep_all_components: shape = (3,) if self.reduce_volume else (3, *self.grid_shape) else: shape = (1,) if self.reduce_volume else self.grid_shape return {"poynting_flux": jax.ShapeDtypeStruct(shape, self.dtype)}
[docs] def update( self, time_step: jax.Array, E: jax.Array, H: jax.Array, state: DetectorState, inv_permittivity: jax.Array, inv_permeability: jax.Array | float, ) -> DetectorState: del inv_permeability, inv_permittivity cur_E = E[:, *self.grid_slice] cur_H = H[:, *self.grid_slice] pf = compute_poynting_flux(cur_E, cur_H).real if not self.keep_all_components: pf = pf[self.propagation_axis] if self.direction == "-": pf = -pf if self.reduce_volume: if self.keep_all_components: pf = pf.sum(axis=(1, 2, 3)) else: pf = pf.sum() arr_idx = self._time_step_to_arr_idx[time_step] new_full_arr = state["poynting_flux"].at[arr_idx].set(pf) new_state = {"poynting_flux": new_full_arr} return new_state