from collections import namedtuple
from types import SimpleNamespace
from typing import List, Literal, Sequence
import jax
import jax.numpy as jnp
import numpy as np
import tidy3d
from jax.typing import ArrayLike
from tidy3d.components.mode.solver import compute_modes as _compute_modes
from fdtdx.core.misc import expand_to_3x3
from fdtdx.core.physics.metrics import normalize_by_poynting_flux
# Record describing one solved waveguide mode: the effective index followed by the six field components.
# Instances are built by tidy3d_mode_computation_wrapper, so the components are expressed in the
# tidy3d frame (propagation along z) until compute_mode maps them back to the physical axes.
ModeTupleType = namedtuple("ModeTupleType", ["neff", "Ex", "Ey", "Ez", "Hx", "Hy", "Hz"])
"""A named tuple containing the mode fields and effective index.
Attributes:
neff: Complex effective refractive index of the mode
Ex: x-component of the electric field
Ey: y-component of the electric field
Ez: z-component of the electric field
Hx: x-component of the magnetic field
Hy: y-component of the magnetic field
Hz: z-component of the magnetic field
"""
def compute_mode_polarization_fraction(
    mode: ModeTupleType,
    tangential_axes: tuple[int, int],
    pol: Literal["te", "tm"],
) -> float:
    """Return the share of transverse electric-field energy carried by one tangential component.

    The two tangential components are called E1 (``tangential_axes[0]``) and E2 (``tangential_axes[1]``).
    For ``pol="te"`` the result is ``sum|E1|^2 / (sum(|E1|^2 + |E2|^2) + 1e-18)``; for ``pol="tm"`` the
    numerator uses E2 instead.

    Args:
        mode (ModeTupleType): Mode whose ``Ex``, ``Ey`` and ``Ez`` fields are inspected.
        tangential_axes (tuple[int, int]): Indices (0 = x, 1 = y, 2 = z) of the two transverse E-field components.
        pol (Literal["te", "tm"]): Selects the component placed in the numerator ("te" -> E1, "tm" -> E2).

    Returns:
        float: Polarization fraction between 0 and 1.

    Raises:
        ValueError: If ``pol`` is neither "te" nor "tm".
    """
    electric_components = (mode.Ex, mode.Ey, mode.Ez)
    first_component = electric_components[tangential_axes[0]]
    second_component = electric_components[tangential_axes[1]]

    if pol not in ("te", "tm"):
        raise ValueError(f"pol must be 'te' or 'tm', but got {pol}")

    selected_component = first_component if pol == "te" else second_component
    selected_energy = np.sum(np.abs(selected_component) ** 2)

    # The tiny offset keeps the ratio finite (zero) when both tangential components vanish.
    total_energy = np.sum(np.abs(first_component) ** 2 + np.abs(second_component) ** 2) + 1e-18
    return selected_energy / total_energy
def sort_modes(
    modes: list[ModeTupleType],
    filter_pol: Literal["te", "tm"] | None,
    tangential_axes: tuple[int, int],
) -> list[ModeTupleType]:
    """Order modes by descending real effective index, optionally grouping by polarization.

    Without a polarization filter the modes are simply sorted by ``real(neff)`` in descending order.
    With a filter, modes whose polarization fraction for ``filter_pol`` is at least 0.5 come first
    (sorted by descending ``real(neff)``), followed by the remaining modes (sorted the same way).
    No mode is dropped, and the input list is not modified.

    Args:
        modes (list[ModeTupleType]): Modes to sort.
        filter_pol (Literal["te", "tm"] | None): Polarization to put first, or None to sort by index only.
        tangential_axes (tuple[int, int]): Indices of the transverse E-field component axes, forwarded to
            ``compute_mode_polarization_fraction``.

    Returns:
        list[ModeTupleType]: New list containing all input modes in sorted order.
    """

    def by_descending_neff(group):
        return sorted(group, key=lambda m: float(np.real(m.neff)), reverse=True)

    if filter_pol is None:
        return by_descending_neff(modes)

    # Single partition pass: the polarization fraction is computed exactly once per mode.
    matching = []
    non_matching = []
    for candidate in modes:
        fraction = compute_mode_polarization_fraction(candidate, tangential_axes, filter_pol)
        if fraction >= 0.5:
            matching.append(candidate)
        else:
            non_matching.append(candidate)

    return by_descending_neff(matching) + by_descending_neff(non_matching)
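# NOTE(review): a stray "[docs]" marker stood here, presumably left over from a rendered documentation page; it is not code.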
def _invert_full_tensor(inv_tensor: jax.Array) -> jax.Array:
    """Invert a fully anisotropic material tensor independently at every grid cell.

    Args:
        inv_tensor (jax.Array): Array with 9 leading components (flattened 3x3 tensor) followed by the
            spatial axes.

    Returns:
        jax.Array: Array of the same shape holding the flattened 3x3 matrix inverse per cell.
    """
    tensor = expand_to_3x3(inv_tensor)
    to_matrix_last = (2, 3, 4, 0, 1)  # (3, 3, nx, ny, nz) -> (nx, ny, nz, 3, 3)
    to_matrix_first = (3, 4, 0, 1, 2)  # (nx, ny, nz, 3, 3) -> (3, 3, nx, ny, nz)
    inverted = jnp.linalg.inv(tensor.transpose(to_matrix_last)).transpose(to_matrix_first)
    return inverted.reshape(9, *inv_tensor.shape[1:])


def _to_tidy3d_cross_section(
    values: jax.Array,
    propagation_axis: int,
    diagonal_order: list[int],
    full_tensor_order: list[int],
) -> jax.Array:
    """Drop the singleton propagation axis and reorder tensor components into the tidy3d frame.

    Isotropic input (one component) is returned unrotated; 3-component input is reordered with
    ``diagonal_order`` and 9-component input with ``full_tensor_order``.
    """
    cross_section = jnp.take(values, indices=0, axis=propagation_axis + 1)
    if cross_section.shape[0] == 3:
        cross_section = cross_section[jnp.array(diagonal_order), :, :]
    elif cross_section.shape[0] == 9:
        cross_section = cross_section[jnp.array(full_tensor_order), :, :]
    return cross_section


def compute_mode(
    frequency: float,
    inv_permittivities: jax.Array,  # shape (c, nx, ny, nz) with c in {1, 3, 9}
    inv_permeabilities: jax.Array | float,
    resolution: float | None = None,
    direction: Literal["+", "-"] = "+",
    mode_index: int = 0,
    filter_pol: Literal["te", "tm"] | None = None,
    dtype: jnp.dtype = jnp.float32,
    bend_radius: float | None = None,
    bend_axis: int | None = None,
    transverse_coords: Sequence[jax.Array] | None = None,
) -> tuple[
    jax.Array,  # E
    jax.Array,  # H
    jax.Array,  # complex effective index
]:
    """Compute optical modes of a waveguide cross-section.

    This function uses the Tidy3D mode solver to compute the optical modes of a given waveguide cross-section defined
    by its permittivity distribution.
    By default modes are sorted by their effective index. The mode_index argument indexes this sorted list of modes and
    returns the desired mode. With filter_pol, modes of the requested polarization are placed first in that list.

    Args:
        frequency (float): Operating frequency in Hz
        inv_permittivities (jax.Array): 4D array of inverse relative permittivity values with shape
            (c, nx, ny, nz), where c is 1, 3 or 9 (9 is treated as a full 3x3 tensor). Exactly one of the
            spatial axes must have size 1; that axis is taken as the propagation axis.
        inv_permeabilities (jax.Array | float): Array of inverse relative permeability values with the same shape
            rules as ``inv_permittivities``, or a single float for a uniform permeability distribution.
        resolution (float | None): Uniform-grid spacing in metres. Required when ``transverse_coords`` is not
            provided (uniform-grid path). Ignored when ``transverse_coords`` is given. Defaults to None.
        direction (Literal["+", "-"]): Propagation direction, either "+" or "-".
        mode_index (int, optional): Index of the mode to compute. Defaults to 0.
        filter_pol (Literal["te", "tm"] | None, optional): If not None, modes of this polarization are sorted
            before all other modes.
        dtype (jnp.dtype, optional): Float dtype of the simulation. Controls whether mode fields are returned
            as complex64 (float32) or complex128 (float64). Defaults to jnp.float32.
        bend_radius (float | None, optional): Bend radius of the waveguide in meters. Must be set together with
            bend_axis. When set, the mode solver uses a conformal transformation to account for the bend. Defaults to
            None (straight waveguide).
        bend_axis (int | None, optional): Physical axis index (0/1/2) pointing from the waveguide toward the center
            of curvature. Must differ from the propagation axis. Required when bend_radius is set. Defaults to None.
        transverse_coords: Optional pair of physical edge-coordinate arrays, in metres, for the two axes transverse
            to propagation. Each array must have one more entry than the corresponding transverse cell count.
            When provided, the Tidy3D mode solver receives the non-uniform rectilinear grid directly.
            JAX arrays are accepted; the numpy conversion happens inside the tidy3d callback so the function
            remains compatible with ``jax.jit``.

    Returns:
        Tuple[jax.Array, jax.Array, jax.Array]:
            Tuple of E, H field and the effective index as complex-valued jax arrays. The fields have shape
            (3, nx, ny, nz) and are normalized with ``normalize_by_poynting_flux``.

    Raises:
        ValueError: If the material arrays have an invalid shape, if only one of ``bend_radius`` / ``bend_axis``
            is given, if ``bend_axis`` is not one of the two transverse axes, if neither ``resolution`` nor
            ``transverse_coords`` is provided, or if ``transverse_coords`` does not match the grid.
    """
    # ---- input validation -------------------------------------------------------------------------------------
    if (
        not (inv_permittivities.ndim == 4 and inv_permittivities.shape[0] in [1, 3, 9])
        or sum(dim == 1 for dim in inv_permittivities.shape[1:]) != 1
    ):
        raise ValueError(f"Invalid shape of inv_permittivities: {inv_permittivities.shape}")

    # A Python float or 0-d array denotes a uniform permeability and skips the array handling below.
    inv_mu_array = (
        inv_permeabilities if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0 else None
    )
    if inv_mu_array is not None and (
        not (inv_mu_array.ndim == 4 and inv_mu_array.shape[0] in [1, 3, 9])
        or sum(dim == 1 for dim in inv_mu_array.shape[1:]) != 1
    ):
        raise ValueError(f"Invalid shape of inv_permeabilities: {inv_mu_array.shape}")

    if (bend_radius is None) != (bend_axis is None):
        raise ValueError("bend_radius and bend_axis must both be set or both be None")

    # The singleton spatial axis is the propagation axis; the other two span the cross-section.
    # Indices in ``other_axes`` refer to the 4D array (offset by one for the component axis).
    propagation_axis = inv_permittivities.shape[1:].index(1)
    other_axes = [a for a in range(1, 4) if inv_permittivities.shape[a] != 1]

    if bend_radius is not None and bend_axis is not None:
        # Fail early with a clear message instead of an opaque error raised inside the host callback.
        if bend_axis not in (0, 1, 2) or bend_axis == propagation_axis:
            raise ValueError(
                f"bend_axis must be one of the two axes transverse to the propagation axis "
                f"{propagation_axis}, but got {bend_axis}"
            )
        transverse_axes = [ax for ax in range(3) if ax != propagation_axis]
        tidy3d_bend_axis = transverse_axes.index(bend_axis)  # index within tidy3d's transverse (x, y) frame
        bend_radius_um = bend_radius / 1e-6
    else:
        tidy3d_bend_axis = None
        bend_radius_um = None

    np_complex_dtype = np.complex128 if dtype == jnp.float64 else np.complex64
    jnp_complex_dtype = jnp.complex128 if dtype == jnp.float64 else jnp.complex64

    def mode_helper(permittivity, permeability, c0_um, c1_um):
        # c0_um, c1_um are concrete numpy arrays here (materialised by pure_callback)
        coords = [np.asarray(c0_um), np.asarray(c1_um)]
        plane_center = None
        if bend_radius_um is not None:
            # Midpoint of the mode plane along each transverse axis, in the same units as coords.
            plane_center = (
                float(0.5 * (coords[0][0] + coords[0][-1])),
                float(0.5 * (coords[1][0] + coords[1][-1])),
            )

        modes = tidy3d_mode_computation_wrapper(
            frequency=frequency,
            permittivity_cross_section=permittivity,
            permeability_cross_section=permeability,
            coords=coords,
            direction=direction,
            # Request more modes than needed so that enough candidates remain after polarization sorting.
            num_modes=2 * (mode_index + 1) + 10,
            bend_radius=bend_radius_um,
            bend_axis=tidy3d_bend_axis,
            plane_center=plane_center,
        )
        # sort modes by polarization
        # tidy3d assumes propagation in the z-direction. The tangential axes are therefore x and y.
        modes = sort_modes(modes, filter_pol, (0, 1))
        mode = modes[mode_index]

        # Map the tidy3d components back onto the physical (x, y, z) component order.
        if propagation_axis == 0:
            mode_E, mode_H = (
                np.stack([mode.Ez, mode.Ex, mode.Ey], axis=0).astype(np_complex_dtype),
                np.stack([mode.Hz, mode.Hx, mode.Hy], axis=0).astype(np_complex_dtype),
            )
        elif propagation_axis == 1:
            # This axis mapping swaps two axes, hence the sign flip on H.
            mode_E, mode_H = (
                np.stack([mode.Ex, mode.Ez, mode.Ey], axis=0).astype(np_complex_dtype),
                -np.stack([mode.Hx, mode.Hz, mode.Hy], axis=0).astype(np_complex_dtype),
            )
        elif propagation_axis == 2:
            mode_E, mode_H = (
                np.stack([mode.Ex, mode.Ey, mode.Ez], axis=0).astype(np_complex_dtype),
                np.stack([mode.Hx, mode.Hy, mode.Hz], axis=0).astype(np_complex_dtype),
            )
        else:
            raise Exception("This should never happen")
        neff = np.asarray(mode.neff).astype(np_complex_dtype)
        return mode_E, mode_H, neff

    # ---- compute input to tidy3d Mode solver ------------------------------------------------------------------
    if inv_permittivities.shape[0] == 9:
        permittivities = _invert_full_tensor(inv_permittivities)
    else:
        permittivities = 1 / inv_permittivities

    if transverse_coords is None:
        if resolution is None:
            raise ValueError("resolution is required when transverse_coords is not provided")
        # Uniform grid: build concrete coordinate arrays in µm and pass as callback args.
        c0_um = jnp.asarray(np.arange(permittivities.shape[other_axes[0]] + 1) * resolution / 1e-6)
        c1_um = jnp.asarray(np.arange(permittivities.shape[other_axes[1]] + 1) * resolution / 1e-6)
        normalization_area_weights = None
    else:
        if len(transverse_coords) != 2:
            raise ValueError(
                f"transverse_coords must contain exactly two coordinate arrays, got {len(transverse_coords)}"
            )
        # Coerce once so that lists / numpy arrays are accepted alongside JAX arrays.
        coords_m = [jnp.asarray(coord) for coord in transverse_coords]
        # Shape validation uses .shape which is always concrete, even for JAX tracers.
        expected_lengths = [permittivities.shape[dim] + 1 for dim in other_axes]
        for axis_idx, (coord, expected_length) in enumerate(zip(coords_m, expected_lengths, strict=True)):
            if coord.ndim != 1 or coord.shape[0] != expected_length:
                raise ValueError(
                    f"transverse_coords[{axis_idx}] must be 1D with length {expected_length}, got {coord.shape}"
                )
        # Convert to µm for tidy3d; keep as JAX arrays so jax.jit can trace through.
        c0_um = coords_m[0] / 1e-6
        c1_um = coords_m[1] / 1e-6
        # area_2d in m²: use jnp.diff so this works with traced JAX arrays.
        area_2d = (jnp.diff(coords_m[0])[:, None] * jnp.diff(coords_m[1])[None, :]).astype(dtype)
        weight_shape = [1, 1, 1]
        weight_shape[other_axes[0] - 1] = area_2d.shape[0]
        weight_shape[other_axes[1] - 1] = area_2d.shape[1]
        normalization_area_weights = area_2d.reshape(weight_shape)

    # Rotate material components to match tidy3d coordinate system
    # tidy3d assumes propagation along z, so we need to map physical axes to tidy3d axes:
    # - tidy3d x → first transverse axis
    # - tidy3d y → second transverse axis
    # - tidy3d z → propagation axis
    if propagation_axis == 0:
        # propagation along x: tidy3d (x,y,z) → physical (y,z,x)
        perm_idx = [1, 2, 0]
        perm_idx_full_anisotropy = [4, 5, 3, 7, 8, 6, 1, 2, 0]
    elif propagation_axis == 1:
        # propagation along y: tidy3d (x,y,z) → physical (x,z,y)
        perm_idx = [0, 2, 1]
        perm_idx_full_anisotropy = [0, 2, 1, 6, 8, 7, 3, 5, 4]
    else:  # propagation_axis == 2
        # propagation along z: tidy3d (x,y,z) → physical (x,y,z)
        perm_idx = [0, 1, 2]
        perm_idx_full_anisotropy = [0, 1, 2, 3, 4, 5, 6, 7, 8]

    permittivity_squeezed = _to_tidy3d_cross_section(
        permittivities, propagation_axis, perm_idx, perm_idx_full_anisotropy
    )

    if inv_mu_array is not None:
        if inv_mu_array.shape[0] == 9:
            permeabilities = _invert_full_tensor(inv_mu_array)
        else:
            permeabilities = 1 / inv_mu_array
        # Apply the same squeeze and rotation as for the permittivity.
        permeability_squeezed = _to_tidy3d_cross_section(
            permeabilities, propagation_axis, perm_idx, perm_idx_full_anisotropy
        )
    else:  # float (or 0-d array): uniform permeability
        permeability_squeezed = 1 / inv_permeabilities

    # Templates describing shape and dtype of the three callback results (E, H, neff).
    cross_section_shape = permittivity_squeezed.shape[1:]
    result_shape_dtype = (
        jnp.zeros((3, *cross_section_shape), dtype=jnp_complex_dtype),
        jnp.zeros((3, *cross_section_shape), dtype=jnp_complex_dtype),
        jnp.zeros(shape=(), dtype=jnp_complex_dtype),
    )

    # pure callback to tidy3d is necessary to work in jitted environment.
    # c0_um and c1_um are passed as explicit args so JAX materialises them to
    # concrete numpy arrays before calling mode_helper, allowing np.asarray()
    # inside the callback without raising TracerArrayConversionError.
    mode_E_raw, mode_H_raw, eff_idx = jax.pure_callback(
        mode_helper,
        result_shape_dtype,
        jax.lax.stop_gradient(permittivity_squeezed),
        jax.lax.stop_gradient(permeability_squeezed),
        jax.lax.stop_gradient(c0_um),
        jax.lax.stop_gradient(c1_um),
    )

    # Re-insert the singleton propagation axis so the fields have shape (3, nx, ny, nz).
    mode_E = jnp.expand_dims(mode_E_raw, axis=propagation_axis + 1)
    mode_H = jnp.expand_dims(mode_H_raw, axis=propagation_axis + 1)

    # Tidy3D uses different scaling internally, so convert back
    mode_H = mode_H * tidy3d.constants.ETA_0

    mode_E_norm, mode_H_norm = normalize_by_poynting_flux(
        mode_E,
        mode_H,
        axis=propagation_axis,
        area_weights=normalization_area_weights,
    )
    return mode_E_norm, mode_H_norm, eff_idx
def _split_material_tensor(cross_section: ArrayLike) -> list:
    """Expand a material cross-section to a full tensor and return its nine components as a list.

    The input is passed through ``expand_to_3x3``; its two leading axes are then flattened into nine
    entries, each keeping the remaining (spatial) axes.
    """
    tensor = expand_to_3x3(jnp.asarray(cross_section))
    tensor = tensor.reshape(9, *tensor.shape[2:])
    return [tensor[component] for component in range(9)]


def tidy3d_mode_computation_wrapper(
    frequency: float,
    permittivity_cross_section: ArrayLike,
    coords: List[np.ndarray],
    direction: Literal["+", "-"],
    permeability_cross_section: ArrayLike | float | None = None,
    target_neff: float | None = None,
    angle_theta: float = 0.0,
    angle_phi: float = 0.0,
    num_modes: int = 10,
    precision: Literal["single", "double"] = "double",
    bend_radius: float | None = None,
    bend_axis: int | None = None,
    plane_center: tuple[float, float] | None = None,
) -> List[ModeTupleType]:
    """Compute optical modes of a waveguide cross-section.

    This function uses the Tidy3D mode solver to compute the optical modes of a given
    waveguide cross-section defined by its permittivity distribution.

    Args:
        frequency (float): Operating frequency in Hz
        permittivity_cross_section (ArrayLike): Relative permittivity values of the cross-section, in any
            layout accepted by ``expand_to_3x3``.
        coords (List[np.ndarray]): List of coordinate arrays [x, y] defining the grid
        direction (Literal["+", "-"]): Propagation direction, either "+" or "-"
        permeability_cross_section (ArrayLike | float | None, optional): Relative permeability values of the
            cross-section. Defaults to None, in which case no permeability is passed to the solver.
        target_neff (float | None, optional): Target effective index to search around. Defaults to None.
        angle_theta (float, optional): Polar angle in radians. Defaults to 0.0.
        angle_phi (float, optional): Azimuthal angle in radians. Defaults to 0.0.
        num_modes (int, optional): Number of modes to compute. Defaults to 10.
        precision (Literal["single", "double"], optional): Numerical precision. Defaults to "double".
        bend_radius (float | None, optional): Bend radius in microns (tidy3d units). Defaults to None.
        bend_axis (int | None, optional): Axis index (0 or 1) of the center of curvature in tidy3d's transverse
            coordinate frame. Defaults to None.
        plane_center (tuple[float, float] | None, optional): Center of the mode plane in the same units as coords.
            Required by tidy3d when bend_radius is set. Defaults to None.

    Notes:
        tidy3d assumes propagation in z-direction. The output fields should be handled accordingly.

    Returns:
        List[ModeTupleType]: ``num_modes`` modes in the order delivered by the solver (no sorting is applied
            here). Each mode contains the field components and effective index.
    """
    # see https://docs.flexcompute.com/projects/tidy3d/en/latest/_autosummary/tidy3d.ModeSpec.html#tidy3d.ModeSpec
    # A plain namespace stands in for tidy3d.ModeSpec and carries the attributes set below.
    mode_spec = SimpleNamespace(
        # Note that the filter_pol argument is not used here since it does not work from tidy3d
        num_modes=num_modes,
        target_neff=target_neff,
        num_pml=(0, 0),
        angle_theta=angle_theta,
        angle_phi=angle_phi,
        bend_radius=bend_radius,
        bend_axis=bend_axis,
        precision=precision,
        track_freq="central",
        group_index_step=False,
    )

    eps_cross = _split_material_tensor(permittivity_cross_section)
    mu_cross = None
    if permeability_cross_section is not None:
        mu_cross = _split_material_tensor(permeability_cross_section)

    EH, neffs, _ = _compute_modes(
        eps_cross=eps_cross,
        coords=coords,
        freq=frequency,
        precision=precision,
        mode_spec=mode_spec,
        direction=direction,
        mu_cross=mu_cross,
        plane_center=plane_center,
    )

    # NOTE(review): squeeze() removes every singleton axis, so it presumably also collapses a transverse axis
    # of length one if such a cross-section is ever passed in — confirm against the solver's output layout.
    ((Ex, Ey, Ez), (Hx, Hy, Hz)) = EH.squeeze()

    if num_modes == 1:
        # With a single mode the trailing mode axis has been squeezed away, so the fields are used as they are.
        # .item() extracts the one effective index without relying on implicit size-1 array -> scalar
        # conversion, which NumPy deprecates for arrays with ndim > 0.
        return [
            ModeTupleType(
                neff=complex(np.asarray(neffs).item()),
                Ex=Ex,
                Ey=Ey,
                Ez=Ez,
                Hx=Hx,
                Hy=Hy,
                Hz=Hz,
            )
        ]

    # The last axis of every field component enumerates the modes.
    return [
        ModeTupleType(
            neff=neffs[i],
            Ex=Ex[..., i],
            Ey=Ey[..., i],
            Ez=Ez[..., i],
            Hx=Hx[..., i],
            Hy=Hy[..., i],
            Hz=Hz[..., i],
        )
        for i in range(num_modes)
    ]