# Source code for fdtdx.objects.detectors.poynting_flux

from typing import Literal, Self

import jax
import jax.numpy as jnp

from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, private_field
from fdtdx.core.physics.metrics import compute_poynting_flux
from fdtdx.objects.detectors.detector import Detector, DetectorState
from fdtdx.typing import SliceTuple3D


[docs] @autoinit class PoyntingFluxDetector(Detector): """Detector for measuring Poynting flux in electromagnetic simulations. This detector computes the Poynting flux (power flow) through a specified surface in the simulation volume. It can measure flux in either positive or negative direction along the propagation axis, and optionally reduce measurements to a single value by summing over the detection surface. """ #: Direction of flux measurement, either "+" for positive or "-" for negative along the propagation axis. direction: Literal["+", "-"] = frozen_field() #: If True, reduces measurements to a single value by summing over the detection surface. #: If False, maintains spatial distribution. Defaults to True. reduce_volume: bool = frozen_field(default=True) #: By default, the propagation axis for calculating the poynting #: flux is the axis, where the detector has a grid shape of 1. If the detector has a shape of 1 in more than #: one axes or a different axis should be used, then this attribute can/has to be set. Defaults to None. fixed_propagation_axis: int | None = frozen_field(default=None) #: By default, only the poynting flux component for the propagation axis #: is returned (scalar). If true, all three vector components are returned. Defaults to False. keep_all_components: bool = frozen_field(default=False) _cached_face_area_weights: jax.Array = private_field() @property def propagation_axis(self) -> int: """Determines the axis along which Poynting flux is measured. The propagation axis is identified as the dimension with size 1 in the detector's grid shape, representing a plane perpendicular to the flux measurement direction. 
Returns: int: Index of the propagation axis (0 for x, 1 for y, 2 for z) Raises: Exception: If detector shape does not have exactly one dimension of size 1 """ if self.fixed_propagation_axis is not None: if self.fixed_propagation_axis not in [0, 1, 2]: raise Exception(f"Invalid: {self.fixed_propagation_axis=}") return self.fixed_propagation_axis if sum([a == 1 for a in self.grid_shape]) != 1: raise Exception(f"Invalid poynting flux detector shape: {self.grid_shape}") return self.grid_shape.index(1)
[docs] def place_on_grid( self: Self, grid_slice_tuple: SliceTuple3D, config: SimulationConfig, key: jax.Array, ) -> Self: self = super().place_on_grid(grid_slice_tuple=grid_slice_tuple, config=config, key=key) can_determine_axis = self.keep_all_components or ( self.fixed_propagation_axis is not None or sum(a == 1 for a in self.grid_shape) == 1 ) if can_determine_axis: grid = self._config.resolved_grid if grid is not None: if self.keep_all_components: weights = jnp.stack( [grid.face_area(axis=axis, slice_tuple=self.grid_slice_tuple) for axis in range(3)] ) else: weights = grid.face_area(axis=self.propagation_axis, slice_tuple=self.grid_slice_tuple) else: spacing = self._config.uniform_spacing() area = jnp.ones(self.grid_shape, dtype=self.dtype) * spacing * spacing weights = jnp.stack([area, area, area]) if self.keep_all_components else area self = self.aset("_cached_face_area_weights", weights, create_new_ok=True) return self
def _shape_dtype_single_time_step( self, ) -> dict[str, jax.ShapeDtypeStruct]: if self.keep_all_components: shape = (3,) if self.reduce_volume else (3, *self.grid_shape) else: shape = (1,) if self.reduce_volume else self.grid_shape return {"poynting_flux": jax.ShapeDtypeStruct(shape, self.dtype)} def _face_area_weights(self) -> jax.Array: """Return face-area weights matching this detector's grid slice.""" return self._cached_face_area_weights
[docs] def update( self, time_step: jax.Array, E: jax.Array, H: jax.Array, state: DetectorState, inv_permittivity: jax.Array, inv_permeability: jax.Array | float, ) -> DetectorState: del inv_permeability, inv_permittivity cur_E = E[:, *self.grid_slice] cur_H = H[:, *self.grid_slice] pf = compute_poynting_flux(cur_E, cur_H).real if not self.keep_all_components: pf = pf[self.propagation_axis] if self.direction == "-": pf = -pf if self.reduce_volume: pf = pf * self._face_area_weights() if self.keep_all_components: pf = pf.sum(axis=(1, 2, 3)) else: pf = pf.sum() arr_idx = self._time_step_to_arr_idx[time_step] new_full_arr = state["poynting_flux"].at[arr_idx].set(pf) new_state = {"poynting_flux": new_full_arr} return new_state