# Source code for fdtdx.objects.detectors.field

from typing import Literal, Sequence

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.objects.detectors.detector import Detector, DetectorState


[docs] @autoinit class FieldDetector(Detector): """Detector for measuring field components of electromagnetic fields in the time domain.""" #: If True, reduces the volume of recorded data. Defaults to False. reduce_volume: bool = frozen_field(default=False) #: Sequence of field components to #: measure. Can include any of: "Ex", "Ey", "Ez", "Hx", "Hy", "Hz". #: Defaults to ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"). components: Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]] = frozen_field( default=("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"), ) def _shape_dtype_single_time_step( self, ) -> dict[str, jax.ShapeDtypeStruct]: num_components = len(self.components) component_shape = (num_components,) if self.reduce_volume else (num_components, *self.grid_shape) return {"fields": jax.ShapeDtypeStruct(shape=component_shape, dtype=self.dtype)}
[docs] def update( self, time_step: jax.Array, E: jax.Array, H: jax.Array, state: DetectorState, inv_permittivity: jax.Array, inv_permeability: jax.Array | float, ) -> DetectorState: del inv_permeability, inv_permittivity E, H = E[:, *self.grid_slice], H[:, *self.grid_slice] fields = [] if "Ex" in self.components: fields.append(E[0]) if "Ey" in self.components: fields.append(E[1]) if "Ez" in self.components: fields.append(E[2]) if "Hx" in self.components: fields.append(H[0]) if "Hy" in self.components: fields.append(H[1]) if "Hz" in self.components: fields.append(H[2]) EH = jnp.stack(fields, axis=0) if self.reduce_volume: EH = self._volume_weighted_spatial_mean(EH, leading_dims=1) arr_idx = self._time_step_to_arr_idx[time_step] new_full_arr = state["fields"].at[arr_idx].set(EH) new_state = {"fields": new_full_arr} return new_state