from abc import ABC, abstractmethod
from typing import Self
import jax
import jax.numpy as jnp
import numpy as np
from fdtdx.core.grid import calculate_time_offset_yee
from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.core.linalg import get_wave_vector_raw, rotate_vector
from fdtdx.core.misc import expand_to_3x3, linear_interpolated_indexing, normalize_polarization_for_source
from fdtdx.core.physics.metrics import compute_energy
from fdtdx.dispersion import effective_inv_permittivity
from fdtdx.objects.sources.tfsf import TFSFPlaneSource, _build_dispersive_H_filter
@autoinit
class LinearlyPolarizedPlaneSource(TFSFPlaneSource, ABC):
#: the electric polarization vector
fixed_E_polarization_vector: tuple[float, float, float] | None = frozen_field(default=None)
#: the magnetic polarization vector
fixed_H_polarization_vector: tuple[float, float, float] | None = frozen_field(default=None)
#: whether to normalize the polarization vector
normalize_by_energy: bool = frozen_field(default=True)
def apply(
self: Self,
key: jax.Array,
inv_permittivities: jax.Array,
inv_permeabilities: jax.Array | float,
dispersive_c1: jax.Array | None = None,
dispersive_c2: jax.Array | None = None,
dispersive_c3: jax.Array | None = None,
):
# inv_permittivities shape: (3, Nx, Ny, Nz) - slice with component dimension
inv_permittivities = inv_permittivities[:, *self.grid_slice]
if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
# inv_permeabilities shape: (3, Nx, Ny, Nz) - slice with component dimension
inv_permeabilities = inv_permeabilities[:, *self.grid_slice]
# Keep a handle to the raw (ε∞) inverse permittivity before any
# carrier-frequency correction — the broadband impedance filter
# computed below needs ε∞ to reconstruct the full ε(ω) spectrum.
inv_eps_inf_slice = inv_permittivities
# If the simulation is dispersive, evaluate the real effective inverse
# permittivity at the source carrier frequency so that the impedance and
# energy normalization reflect the true medium the source sits in,
# not just the high-frequency permittivity epsilon_infinity.
c1_slice = c2_slice = c3_slice = None
if dispersive_c1 is not None and dispersive_c2 is not None and dispersive_c3 is not None:
# dispersive_c* shape: (num_poles, 1, Nx, Ny, Nz) → slice spatial axes
c1_slice = dispersive_c1[:, :, *self.grid_slice]
c2_slice = dispersive_c2[:, :, *self.grid_slice]
c3_slice = dispersive_c3[:, :, *self.grid_slice]
inv_permittivities = effective_inv_permittivity(
inv_eps=inv_permittivities,
c1=c1_slice,
c2=c2_slice,
c3=c3_slice,
omega=2.0 * np.pi * self.wave_character.get_frequency(),
dt=self._config.time_step_duration,
)
# determine E/H polarization
e_pol_raw, h_pol_raw = normalize_polarization_for_source(
direction=self.direction,
propagation_axis=self.propagation_axis,
fixed_E_polarization_vector=self.fixed_E_polarization_vector,
fixed_H_polarization_vector=self.fixed_H_polarization_vector,
dtype=self._config.dtype,
)
wave_vector_raw = get_wave_vector_raw(
direction=self.direction,
propagation_axis=self.propagation_axis,
dtype=self._config.dtype,
)
center, azimuth, elevation = self._get_random_parts(key)
# tilt polarizations
axes_tpl = (self.horizontal_axis, self.vertical_axis, self.propagation_axis)
wave_vector = rotate_vector(wave_vector_raw, azimuth, elevation, axes_tpl)
e_pol = rotate_vector(e_pol_raw, azimuth, elevation, axes_tpl)
h_pol = rotate_vector(h_pol_raw, azimuth, elevation, axes_tpl)
# update is amplitude multiplied by polarization
amplitude_raw = self._get_amplitude_raw(center)[None, ...]
# map amplitude to propagation plane
w, h = jnp.meshgrid(
jnp.arange(self.grid_shape[self.horizontal_axis]),
jnp.arange(self.grid_shape[self.vertical_axis]),
indexing="ij",
)
wh_indices = jnp.stack((w, h), axis=-1)
wh_indices -= center
# basis in plane
h_list = [0, 0, 0]
h_list[self.horizontal_axis] = 1
h_axis = jnp.asarray(h_list, dtype=self._config.dtype)
u_basis = h_axis - jnp.dot(h_axis, wave_vector) * wave_vector
u_basis = u_basis / jnp.linalg.norm(u_basis)
v_basis = jnp.cross(wave_vector, u_basis)
# projection
def project(point):
point_list = [point[0], point[1]]
point_list.insert(self.propagation_axis, 0)
point = jnp.asarray(point_list, dtype=self._config.dtype)
projection = point - jnp.dot(point, wave_vector) * wave_vector
# Convert to plane coordinates
u = jnp.dot(projection, u_basis)
v = jnp.dot(projection, v_basis)
return jnp.asarray((u, v), dtype=self._config.dtype)
float_projected = jax.vmap(project)(wh_indices.reshape(-1, 2))
float_projected += center
# interpolate floating indices in original array
index_fn = jax.vmap(linear_interpolated_indexing, in_axes=(0, None))
interp = index_fn(float_projected, amplitude_raw.squeeze())
amplitude = interp.reshape(*amplitude_raw.shape)
E = amplitude * e_pol[:, None, None, None]
H = amplitude * h_pol[:, None, None, None]
if self.normalize_by_energy:
energy = compute_energy(
E=E,
H=H,
inv_permittivity=inv_permittivities,
inv_permeability=inv_permeabilities,
)
total_energy_root = jnp.sqrt(energy.sum())
E = E / total_energy_root
H = H / total_energy_root
# adjust H for impedance of the medium
# check if fully anisotropic
if (
isinstance(inv_permittivities, jax.Array)
and inv_permittivities.ndim >= 1
and inv_permittivities.shape[0] == 9
) or (
isinstance(inv_permeabilities, jax.Array)
and inv_permeabilities.ndim >= 1
and inv_permeabilities.shape[0] == 9
):
# convert to 3x3 tensors
inv_eps_tensor = expand_to_3x3(inv_permittivities) # shape: (3, 3, Nx, Ny, Nz)
inv_mu_tensor = expand_to_3x3(inv_permeabilities) # shape: (3, 3, Nx, Ny, Nz)
# invert to get eps and mu tensors
perm = (2, 3, 4, 0, 1) # (3, 3, nx, ny, nz) -> (nx, ny, nz, 3, 3)
inv_perm = (3, 4, 0, 1, 2) # (nx, ny, nz, 3, 3) -> (3, 3, nx, ny, nz)
eps = jnp.linalg.inv(inv_eps_tensor.transpose(perm)).transpose(inv_perm)
mu = jnp.linalg.inv(inv_mu_tensor.transpose(perm)).transpose(inv_perm)
# compute effective permittivity and permeability along polarization directions
eps_eff = jnp.einsum("i,ijxyz,j->xyz", e_pol, eps, e_pol)
mu_eff = jnp.einsum("i,ijxyz,j->xyz", h_pol, mu, h_pol)
impedance = jnp.sqrt(mu_eff / eps_eff)
else:
impedance = jnp.sqrt(inv_permittivities / inv_permeabilities)
H = H / impedance
time_offset_E, time_offset_H = calculate_time_offset_yee(
center=center,
wave_vector=wave_vector,
inv_permittivities=inv_permittivities,
inv_permeabilities=inv_permeabilities,
resolution=self._config.resolution,
time_step_duration=self._config.time_step_duration,
e_polarization=e_pol,
h_polarization=h_pol,
)
self = self.aset("_E", E, create_new_ok=True)
self = self.aset("_H", H, create_new_ok=True)
self = self.aset("_time_offset_E", time_offset_E, create_new_ok=True)
self = self.aset("_time_offset_H", time_offset_H, create_new_ok=True)
# Broadband impedance correction. The carrier-frequency rescale above
# only matches η at ω_c; a wide-bandwidth pulse (e.g. GaussianPulseProfile)
# sees a frequency-dependent impedance in a dispersive medium and the
# TFSF boundary leaks unphysical reflections for frequencies away from
# ω_c. Precompute a filtered H-side temporal profile s_H(t) whose
# spectrum is S(ω)·√(ε(ω)/ε(ω_c)) so that the injected H field has
# the frequency-dependent impedance correction baked in.
if c1_slice is not None and c2_slice is not None and c3_slice is not None:
filtered = _build_dispersive_H_filter(
temporal_profile=self.temporal_profile,
wave_character=self.wave_character,
dt=self._config.time_step_duration,
num_time_steps=self._config.time_steps_total,
c1_slice=c1_slice,
c2_slice=c2_slice,
c3_slice=c3_slice,
inv_eps_inf_slice=inv_eps_inf_slice,
dtype=self._config.dtype,
)
self = self.aset("_temporal_H_filter", filtered, create_new_ok=True)
else:
# Reused source applied in a non-dispersive context: clear any stale
# H-side filter left over from a previous dispersive apply, otherwise
# the TFSF inner loop would keep injecting filtered amplitudes.
self = self.aset("_temporal_H_filter", None, create_new_ok=True)
return self
@abstractmethod
def _get_amplitude_raw(
self,
center: jax.Array,
) -> jax.Array: # shape (*grid_shape)
# in normal coordinates, not yee grid
del center
raise NotImplementedError()
[docs]
@autoinit
class GaussianPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a truncated Gaussian transverse profile.

    The profile is zero outside ``radius`` and is normalized so that its cells
    sum to one.
    """

    #: radius of the gaussian source; divided by ``resolution`` to get grid cells,
    #: and the profile is zero outside this radius
    radius: float = frozen_field()
    #: standard deviation of the gaussian source, relative to ``radius``
    std: float = frozen_field(default=1 / 3)  # relative to radius

    @staticmethod
    def _gauss_profile(
        width: int,
        height: int,
        axis: int,
        center: tuple[float, float] | jax.Array,
        radii: tuple[float, float],
        std: float,
    ) -> jax.Array:  # shape (*grid_shape)
        """Build a truncated, sum-normalized Gaussian on a width x height plane.

        Args:
            width: number of cells along the horizontal plane axis.
            height: number of cells along the vertical plane axis.
            axis: position at which the singleton (propagation) axis is inserted.
            center: (horizontal, vertical) center of the Gaussian, in cells.
            radii: (horizontal, vertical) truncation radii, in cells.
            std: standard deviation relative to the radii.

        Returns:
            Array of shape (width, height) with a size-1 axis inserted at ``axis``.
        """
        # coords[i, j] = (i, j): i runs along width (horizontal) and j along
        # height (vertical). This matches the (horizontal, vertical) ordering of
        # ``center`` and ``radii``. The previous "xy" meshgrid produced pairs in
        # (vertical, horizontal) order and so swapped the center components.
        coords = jnp.stack(
            jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="ij"),
            axis=-1,
        )
        grid = (coords - jnp.asarray(center)) / jnp.asarray(radii)
        # squared normalized distance from the center (1.0 lies on the truncation edge)
        sq_dist = (grid**2).sum(axis=-1)

        mask = jnp.expand_dims(sq_dist < 1, axis=axis)
        exp_part = jnp.expand_dims(jnp.exp(-0.5 * sq_dist / std**2), axis=axis)

        profile = jnp.where(mask, exp_part, 0)
        # normalize so the profile sums to one
        return profile / profile.sum()

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return the Gaussian amplitude profile centered at ``center`` (in cells)."""
        grid_radius = self.radius / self._config.resolution
        return self._gauss_profile(
            width=self.grid_shape[self.horizontal_axis],
            height=self.grid_shape[self.vertical_axis],
            axis=self.propagation_axis,
            center=center,
            radii=(grid_radius, grid_radius),
            std=self.std,
        )