from typing import Literal, Sequence
import jax
import jax.numpy as jnp
from fdtdx.core.jax.pytrees import autoinit, field, frozen_field
from fdtdx.core.wavelength import WaveCharacter
from fdtdx.objects.detectors.detector import Detector, DetectorState
[docs]
@autoinit
class PhasorDetector(Detector):
"""Detector for measuring frequency components of electromagnetic fields using an efficient Phasor Implementation.
This detector computes complex phasor representations of the field components at specified
frequencies, enabling frequency-domain analysis of the electromagnetic fields.
The amplitude and phase of the original phase can be reconstructed using jnp.abs(phasor) and jnp.angle(phasor).
The reconstruction itself can then be achieved using amplitude * jnp.cos(2 * jnp.pi * freq * t + phase).
"""
#: WaveCharacters to analyze.
wave_characters: Sequence[WaveCharacter] = field()
#: If True, reduces the volume of recorded data. Defaults to False.
reduce_volume: bool = frozen_field(default=False)
#: Sequence of field components to measure.
#: Can include any of: "Ex", "Ey", "Ez", "Hx", "Hy", "Hz".
components: Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]] = frozen_field(
default=("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"),
)
dtype: jnp.dtype = frozen_field(
default=jnp.complex64,
)
#: Whether to plot the measured data. Defaults to False.
plot: bool = frozen_field(default=False)
#: Scaling of the resulting phasor. In continuous mode, the result is scaled by a factor of 2 / N, where N is
#: the number of time steps recorded. This allows accurate reconstruction of a continuous signal.
#: In pulse mode, the result is not scaled.
scaling_mode: Literal["continuous", "pulse"] = frozen_field(default="continuous")
def __post_init__(
self,
):
if self.dtype not in [jnp.complex64, jnp.complex128]:
raise Exception(f"Invalid dtype in PhasorDetector: {self.dtype}")
@property
def _angular_frequencies(self) -> jax.Array:
freqs = [wc.get_frequency() for wc in self.wave_characters]
return 2 * jnp.pi * jnp.array(freqs)
def _num_latent_time_steps(self) -> int:
return 1
def _shape_dtype_single_time_step(
self,
) -> dict[str, jax.ShapeDtypeStruct]:
field_dtype = jnp.complex128 if self.dtype == jnp.complex128 else jnp.complex64
num_components = len(self.components)
num_frequencies = len(self._angular_frequencies)
grid_shape = self.grid_shape if not self.reduce_volume else tuple([])
phasor_shape = (num_frequencies, num_components, *grid_shape)
return {"phasor": jax.ShapeDtypeStruct(shape=phasor_shape, dtype=field_dtype)}
[docs]
def update(
self,
time_step: jax.Array,
E: jax.Array,
H: jax.Array,
state: DetectorState,
inv_permittivity: jax.Array,
inv_permeability: jax.Array | float,
) -> DetectorState:
del inv_permeability, inv_permittivity
time_passed = time_step * self._config.time_step_duration
if self.scaling_mode == "continuous":
static_scale = 2 / self.num_time_steps_recorded
elif self.scaling_mode == "pulse":
static_scale = 1
else:
raise Exception(f"Invalid scaling mode: {self.scaling_mode=}")
E, H = E[:, *self.grid_slice], H[:, *self.grid_slice]
fields = []
if "Ex" in self.components:
fields.append(E[0])
if "Ey" in self.components:
fields.append(E[1])
if "Ez" in self.components:
fields.append(E[2])
if "Hx" in self.components:
fields.append(H[0])
if "Hy" in self.components:
fields.append(H[1])
if "Hz" in self.components:
fields.append(H[2])
EH = jnp.stack(fields, axis=0)
# Vectorized phasor calculation for all frequencies
phase_angles = self._angular_frequencies * time_passed # Shape: (num_freqs,)
phasors = jnp.exp(1j * phase_angles) # Shape: (num_freqs,)
# Reshape phasors to (num_freqs, 1, 1, 1, 1) for proper broadcasting with EH (num_components, x, y, z)
phasors = phasors.reshape((len(self._angular_frequencies),) + (1,) * EH.ndim)
new_phasors = EH * phasors * static_scale # Shape: (num_freqs, num_components, *grid_shape)
if self.reduce_volume:
# Average over spatial dimensions using physical cell volumes.
new_phasors = self._volume_weighted_spatial_mean(new_phasors, leading_dims=2)
if self.inverse:
result = state["phasor"] - new_phasors[None, ...]
else:
result = state["phasor"] + new_phasors[None, ...]
return {"phasor": result.astype(self.dtype)}