from typing import Literal, Self, Sequence
import jax
import jax.numpy as jnp
import numpy as np
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, private_field
from fdtdx.core.null import Null
from fdtdx.core.physics.modes import compute_mode
from fdtdx.dispersion import effective_inv_permittivity
from fdtdx.objects.detectors.detector import DetectorState
from fdtdx.objects.detectors.phasor import PhasorDetector
from fdtdx.typing import SliceTuple3D
[docs]
@autoinit
class ModeOverlapDetector(PhasorDetector):
"""
Detector for measuring the overlap of a waveguide mode with the simulation fields.
This detector computes the overlap of a mode with the phasor fields at a specified
frequency, enabling frequency-domain analysis of the electromagnetic fields.
The mode overlap is calculated by integrating the cross product of the mode fields
with the simulation fields over a cross-sectional plane. This is useful for
analyzing waveguide coupling efficiency, transmission coefficients, and modal
decomposition of electromagnetic fields.
"""
#: Direction of mode propagation, either "+" (forward) or "-" (backward).
#: Determines which direction along the waveguide axis the mode is assumed to propagate.
direction: Literal["+", "-"] = frozen_field()
#: Index of the waveguide mode to use for overlap calculation.
#: Defaults to 0 (fundamental mode). Higher indices correspond to higher-order modes.
mode_index: int = frozen_field(default=0)
#: Optional polarization filter for the mode calculation.
#: Can be "te" (transverse electric), "tm" (transverse magnetic), or None (no filtering).
#: When specified, only modes of the given polarization type are considered. Defaults to None.
filter_pol: Literal["te", "tm"] | None = frozen_field(default=None)
#: Bend radius of the waveguide in meters. When set, the mode solver accounts for the conformal
#: transformation introduced by the bend. Must be set together with bend_axis. Defaults to None
#: (straight waveguide).
bend_radius: float | None = frozen_field(default=None)
#: Physical axis index (0=x, 1=y, 2=z) pointing from the waveguide center toward the center of
#: curvature. Must differ from the propagation axis. Required when bend_radius is set.
bend_axis: int | None = frozen_field(default=None)
#: Cannot be specified here since the detector needs all components.
components: Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]] = frozen_field(
default=("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"),
init=False, # in this detector, we always want all components. Do not give user a choice
)
#: Cannot be specified here since plotting a single scalar is useless.
plot: bool = frozen_field(default=False, init=False) # single scalar is useless for plotting
_mode_E: jax.Array = private_field()
_mode_H: jax.Array = private_field()
_mode_neff: jax.Array = private_field() # not required for detection, used for inspection
_cached_face_area_weights: jax.Array = private_field()
@property
def propagation_axis(self) -> int:
if sum([a == 1 for a in self.grid_shape]) != 1:
raise Exception(f"Invalid ModeOverlapDetector shape: {self.grid_shape}")
return self.grid_shape.index(1)
[docs]
def place_on_grid(
self: Self,
grid_slice_tuple: SliceTuple3D,
config: SimulationConfig,
key: jax.Array,
) -> Self:
self = super().place_on_grid(
grid_slice_tuple=grid_slice_tuple,
config=config,
key=key,
)
if len(self.wave_characters) > 1:
raise NotImplementedError()
if (self.bend_radius is None) != (self.bend_axis is None):
raise ValueError("bend_radius and bend_axis must both be set or both be None")
if self.bend_axis is not None and self.bend_axis == self.propagation_axis:
raise ValueError(
f"bend_axis ({self.bend_axis}) must differ from the propagation axis ({self.propagation_axis})"
)
grid = self._config.resolved_grid
if grid is not None:
weights = grid.face_area(axis=self.propagation_axis, slice_tuple=self.grid_slice_tuple)
else:
spacing = self._config.uniform_spacing()
weights = jnp.ones(self.grid_shape, dtype=jnp.float32) * spacing * spacing
self = self.aset("_cached_face_area_weights", weights, create_new_ok=True)
return self
def _face_area_weights(self) -> jax.Array:
"""Return detector-plane face areas for mode-overlap integration."""
return self._cached_face_area_weights
def _transverse_edge_coordinates(self) -> tuple[jax.Array, jax.Array] | None:
"""Return physical transverse edge coordinates for the mode solver.
Tidy3D can solve modes on rectilinear non-uniform grids when supplied
with edge-coordinate arrays. Returning ``None`` keeps the uniform scalar
spacing path for legacy configurations and older tests.
"""
grid = self._config.resolved_grid
if grid is None:
return None
transverse_edges = []
for axis in range(3):
if axis == self.propagation_axis:
continue
lower, upper = self.grid_slice_tuple[axis]
transverse_edges.append(grid.edges(axis)[lower : upper + 1])
e0, e1 = transverse_edges
return e0, e1
def _mode_solver_resolution(self) -> float:
"""Return scalar resolution only for legacy uniform mode-solver setup.
``compute_mode`` ignores this value when explicit transverse coordinates
are supplied. For non-uniform grids we pass a harmless finite value so
the compatibility argument does not force a uniform-grid check.
"""
if self._config.has_nonuniform_grid:
assert self._config.resolved_grid is not None
return self._config.resolved_grid.min_spacing
return self._config.uniform_spacing()
[docs]
def apply(
self,
key: jax.Array,
inv_permittivities: jax.Array,
inv_permeabilities: jax.Array | float,
dispersive_c1: jax.Array | None = None,
dispersive_c2: jax.Array | None = None,
dispersive_c3: jax.Array | None = None,
) -> Self:
del key
inv_permittivity_slice = inv_permittivities[:, *self.grid_slice]
if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
inv_permeability_slice = inv_permeabilities[:, *self.grid_slice]
else:
inv_permeability_slice = inv_permeabilities
# Frequency-correct the permittivity seen by the mode solver so the
# reference mode profile reflects ε(ω_c) of any dispersive medium the
# detector sits in, not just ε∞. Matches the pattern in
# ModePlaneSource.apply. Cells with no pole have c1=c2=c3=0, in which
# case effective_inv_permittivity returns inv_eps unchanged.
if dispersive_c1 is not None and dispersive_c2 is not None and dispersive_c3 is not None:
c1_slice = dispersive_c1[:, :, *self.grid_slice]
c2_slice = dispersive_c2[:, :, *self.grid_slice]
c3_slice = dispersive_c3[:, :, *self.grid_slice]
inv_permittivity_slice = effective_inv_permittivity(
inv_eps=inv_permittivity_slice,
c1=c1_slice,
c2=c2_slice,
c3=c3_slice,
omega=2.0 * np.pi * self.wave_characters[0].get_frequency(),
dt=self._config.time_step_duration,
)
mode_E, mode_H, mode_neff = compute_mode(
frequency=self.wave_characters[0].get_frequency(),
inv_permittivities=inv_permittivity_slice,
inv_permeabilities=inv_permeability_slice,
resolution=self._mode_solver_resolution(),
direction=self.direction,
mode_index=self.mode_index,
filter_pol=self.filter_pol,
dtype=self._config.dtype,
bend_radius=self.bend_radius,
bend_axis=self.bend_axis,
transverse_coords=self._transverse_edge_coordinates(),
)
self = self.aset("_mode_E", mode_E, create_new_ok=True)
self = self.aset("_mode_H", mode_H, create_new_ok=True)
self = self.aset("_mode_neff", mode_neff, create_new_ok=True)
return self
[docs]
def compute_overlap_to_mode(
self,
state: DetectorState,
mode_E: jax.Array,
mode_H: jax.Array,
) -> jax.Array:
# shape (time step, num_freqs, num_components, *spatial)
# time steps is always 1 and num_components always 6
phasors = state["phasor"]
phasors_E, phasors_H = phasors[0, 0, :3], phasors[0, 0, 3:]
E_cross_H_star_sim = jnp.cross(
mode_E,
jnp.conj(phasors_H),
axis=0,
)[self.propagation_axis]
E_star_cross_H_sim = jnp.cross(
jnp.conj(phasors_E),
mode_H,
axis=0,
)[self.propagation_axis]
integrand = E_cross_H_star_sim + E_star_cross_H_sim
integrand = integrand * self._face_area_weights()
alpha_coeff = jnp.sum(integrand)
# in pulsed mode return unscaled coefficient
if self.scaling_mode != "pulse":
alpha_coeff = alpha_coeff / 4.0
return alpha_coeff
[docs]
def compute_overlap(
self,
state: DetectorState,
) -> jax.Array:
if isinstance(self._mode_E, Null) or isinstance(self._mode_H, Null):
raise Exception("Need to call apply on ModeOverlapDetector before calling compute_mode_overlap!")
return self.compute_overlap_to_mode(
state=state,
mode_E=self._mode_E,
mode_H=self._mode_H,
)