Source code for fdtdx.objects.sources.linear_polarization

from abc import ABC, abstractmethod
from typing import Self

import jax
import jax.numpy as jnp
import numpy as np

from fdtdx.core.grid import calculate_time_offset_yee
from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.core.linalg import get_wave_vector_raw, rotate_vector
from fdtdx.core.misc import expand_to_3x3, linear_interpolated_indexing, normalize_polarization_for_source
from fdtdx.core.physics.metrics import compute_energy
from fdtdx.dispersion import effective_inv_permittivity
from fdtdx.objects.sources.tfsf import TFSFPlaneSource, _build_dispersive_H_filter


@autoinit
class LinearlyPolarizedPlaneSource(TFSFPlaneSource, ABC):
    """Abstract TFSF plane source that injects a linearly polarized wave.

    Subclasses provide the transverse amplitude profile through
    ``_get_amplitude_raw``. ``apply`` does the rest:

    - tilts the polarization and wave vectors by random azimuth/elevation;
    - resamples the amplitude profile onto the tilted wavefront;
    - builds E and H along the polarization vectors, optionally normalized by energy;
    - scales H by the local wave impedance;
    - computes the Yee-grid time offsets;
    - in dispersive media, precomputes a broadband H-side temporal filter.
    """

    #: the electric polarization vector (optional; resolved together with the
    #: magnetic vector by ``normalize_polarization_for_source``)
    fixed_E_polarization_vector: tuple[float, float, float] | None = frozen_field(default=None)

    #: the magnetic polarization vector (optional; resolved together with the
    #: electric vector by ``normalize_polarization_for_source``)
    fixed_H_polarization_vector: tuple[float, float, float] | None = frozen_field(default=None)

    #: whether to divide E and H by the square root of their summed energy
    #: (``compute_energy``); this happens before H is scaled by the impedance
    normalize_by_energy: bool = frozen_field(default=True)

    def apply(
        self: Self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
        dispersive_c1: jax.Array | None = None,
        dispersive_c2: jax.Array | None = None,
        dispersive_c3: jax.Array | None = None,
    ) -> Self:
        """Precompute the injected E/H fields and timing for this source.

        Args:
            key: random key forwarded to ``_get_random_parts``, which yields the
                in-plane center and the azimuth/elevation tilt angles.
            inv_permittivities: inverse permittivity over the full grid with a
                leading component axis, e.g. (3, Nx, Ny, Nz). A leading size of 9
                is treated as a fully anisotropic tensor below.
            inv_permeabilities: inverse permeability with the same layout, or a
                scalar (a scalar / 0-d value is used without slicing).
            dispersive_c1, dispersive_c2, dispersive_c3: per-pole dispersion
                coefficients, e.g. (num_poles, 1, Nx, Ny, Nz). They are only used
                when all three are given.

        Returns:
            The source, updated via ``aset``, with ``_E``, ``_H``,
            ``_time_offset_E``, ``_time_offset_H`` and ``_temporal_H_filter``
            set. The filter is ``None`` in the non-dispersive case.
        """
        # inv_permittivities shape: (3, Nx, Ny, Nz) - slice with component dimension
        inv_permittivities = inv_permittivities[:, *self.grid_slice]
        if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
            # inv_permeabilities shape: (3, Nx, Ny, Nz) - slice with component dimension
            inv_permeabilities = inv_permeabilities[:, *self.grid_slice]

        # Keep a handle to the raw (ε∞) inverse permittivity before any
        # carrier-frequency correction — the broadband impedance filter
        # computed below needs ε∞ to reconstruct the full ε(ω) spectrum.
        inv_eps_inf_slice = inv_permittivities

        # If the simulation is dispersive, evaluate the real effective inverse
        # permittivity at the source carrier frequency so that the impedance and
        # energy normalization reflect the true medium the source sits in,
        # not just the high-frequency permittivity epsilon_infinity.
        c1_slice = c2_slice = c3_slice = None
        if dispersive_c1 is not None and dispersive_c2 is not None and dispersive_c3 is not None:
            # dispersive_c* shape: (num_poles, 1, Nx, Ny, Nz) → slice spatial axes
            c1_slice = dispersive_c1[:, :, *self.grid_slice]
            c2_slice = dispersive_c2[:, :, *self.grid_slice]
            c3_slice = dispersive_c3[:, :, *self.grid_slice]
            inv_permittivities = effective_inv_permittivity(
                inv_eps=inv_permittivities,
                c1=c1_slice,
                c2=c2_slice,
                c3=c3_slice,
                omega=2.0 * np.pi * self.wave_character.get_frequency(),
                dt=self._config.time_step_duration,
            )

        # determine E/H polarization
        e_pol_raw, h_pol_raw = normalize_polarization_for_source(
            direction=self.direction,
            propagation_axis=self.propagation_axis,
            fixed_E_polarization_vector=self.fixed_E_polarization_vector,
            fixed_H_polarization_vector=self.fixed_H_polarization_vector,
            dtype=self._config.dtype,
        )
        wave_vector_raw = get_wave_vector_raw(
            direction=self.direction,
            propagation_axis=self.propagation_axis,
            dtype=self._config.dtype,
        )

        center, azimuth, elevation = self._get_random_parts(key)

        # tilt polarizations (and the wave vector) by the same rotation so the
        # three stay mutually consistent
        axes_tpl = (self.horizontal_axis, self.vertical_axis, self.propagation_axis)
        wave_vector = rotate_vector(wave_vector_raw, azimuth, elevation, axes_tpl)
        e_pol = rotate_vector(e_pol_raw, azimuth, elevation, axes_tpl)
        h_pol = rotate_vector(h_pol_raw, azimuth, elevation, axes_tpl)

        # update is amplitude multiplied by polarization
        # leading singleton axis added so the profile broadcasts against the
        # component axis of E/H below
        amplitude_raw = self._get_amplitude_raw(center)[None, ...]

        # map amplitude to propagation plane
        # integer (horizontal, vertical) cell indices of the source plane,
        # expressed relative to the source center
        w, h = jnp.meshgrid(
            jnp.arange(self.grid_shape[self.horizontal_axis]),
            jnp.arange(self.grid_shape[self.vertical_axis]),
            indexing="ij",
        )
        wh_indices = jnp.stack((w, h), axis=-1)
        wh_indices -= center
        # basis in plane
        # u: horizontal unit axis with its wave-vector component removed, then
        # normalized; v = k x u completes an in-plane basis orthogonal to k
        h_list = [0, 0, 0]
        h_list[self.horizontal_axis] = 1
        h_axis = jnp.asarray(h_list, dtype=self._config.dtype)
        u_basis = h_axis - jnp.dot(h_axis, wave_vector) * wave_vector
        u_basis = u_basis / jnp.linalg.norm(u_basis)
        v_basis = jnp.cross(wave_vector, u_basis)

        # projection
        def project(point):
            # Embed the 2D (horizontal, vertical) offset in 3D with a 0 on the
            # propagation axis.
            # NOTE(review): inserting at propagation_axis places the two
            # in-plane coordinates in the right slots only if
            # horizontal_axis < vertical_axis — confirm this always holds.
            point_list = [point[0], point[1]]
            point_list.insert(self.propagation_axis, 0)
            point = jnp.asarray(point_list, dtype=self._config.dtype)
            projection = point - jnp.dot(point, wave_vector) * wave_vector
            # Convert to plane coordinates
            u = jnp.dot(projection, u_basis)
            v = jnp.dot(projection, v_basis)
            return jnp.asarray((u, v), dtype=self._config.dtype)

        float_projected = jax.vmap(project)(wh_indices.reshape(-1, 2))
        float_projected += center
        # interpolate floating indices in original array
        # NOTE(review): squeeze() drops every size-1 axis, not just the
        # propagation/leading ones; a plane that is one cell wide transversally
        # would yield a <2D array here — confirm linear_interpolated_indexing
        # copes with that.
        index_fn = jax.vmap(linear_interpolated_indexing, in_axes=(0, None))
        interp = index_fn(float_projected, amplitude_raw.squeeze())
        amplitude = interp.reshape(*amplitude_raw.shape)

        E = amplitude * e_pol[:, None, None, None]
        H = amplitude * h_pol[:, None, None, None]

        if self.normalize_by_energy:
            # Normalization uses H before the impedance scaling below, so the
            # final fields are not exactly energy-normalized unless the
            # impedance is 1.
            energy = compute_energy(
                E=E,
                H=H,
                inv_permittivity=inv_permittivities,
                inv_permeability=inv_permeabilities,
            )
            total_energy_root = jnp.sqrt(energy.sum())
            E = E / total_energy_root
            H = H / total_energy_root

        # adjust H for impedance of the medium
        # check if fully anisotropic
        # (a leading component axis of size 9 marks a full 3x3 tensor)
        if (
            isinstance(inv_permittivities, jax.Array)
            and inv_permittivities.ndim >= 1
            and inv_permittivities.shape[0] == 9
        ) or (
            isinstance(inv_permeabilities, jax.Array)
            and inv_permeabilities.ndim >= 1
            and inv_permeabilities.shape[0] == 9
        ):
            # convert to 3x3 tensors
            inv_eps_tensor = expand_to_3x3(inv_permittivities)  # shape: (3, 3, Nx, Ny, Nz)
            inv_mu_tensor = expand_to_3x3(inv_permeabilities)  # shape: (3, 3, Nx, Ny, Nz)

            # invert to get eps and mu tensors
            # (move the 3x3 axes last so jnp.linalg.inv inverts per cell)
            perm = (2, 3, 4, 0, 1)  # (3, 3, nx, ny, nz) -> (nx, ny, nz, 3, 3)
            inv_perm = (3, 4, 0, 1, 2)  # (nx, ny, nz, 3, 3) -> (3, 3, nx, ny, nz)
            eps = jnp.linalg.inv(inv_eps_tensor.transpose(perm)).transpose(inv_perm)
            mu = jnp.linalg.inv(inv_mu_tensor.transpose(perm)).transpose(inv_perm)

            # compute effective permittivity and permeability along polarization directions
            # (quadratic forms p^T eps p and q^T mu q per cell)
            eps_eff = jnp.einsum("i,ijxyz,j->xyz", e_pol, eps, e_pol)
            mu_eff = jnp.einsum("i,ijxyz,j->xyz", h_pol, mu, h_pol)
            impedance = jnp.sqrt(mu_eff / eps_eff)
        else:
            # impedance sqrt(mu/eps) written in terms of the inverse quantities
            impedance = jnp.sqrt(inv_permittivities / inv_permeabilities)

        H = H / impedance

        # Per-cell time offsets for the E and H Yee positions along the tilted
        # wave vector (computed by the grid helper).
        time_offset_E, time_offset_H = calculate_time_offset_yee(
            center=center,
            wave_vector=wave_vector,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
            resolution=self._config.resolution,
            time_step_duration=self._config.time_step_duration,
            e_polarization=e_pol,
            h_polarization=h_pol,
        )

        self = self.aset("_E", E, create_new_ok=True)
        self = self.aset("_H", H, create_new_ok=True)
        self = self.aset("_time_offset_E", time_offset_E, create_new_ok=True)
        self = self.aset("_time_offset_H", time_offset_H, create_new_ok=True)

        # Broadband impedance correction. The carrier-frequency rescale above
        # only matches η at ω_c; a wide-bandwidth pulse (e.g. GaussianPulseProfile)
        # sees a frequency-dependent impedance in a dispersive medium and the
        # TFSF boundary leaks unphysical reflections for frequencies away from
        # ω_c. Precompute a filtered H-side temporal profile s_H(t) whose
        # spectrum is S(ω)·√(ε(ω)/ε(ω_c)) so that the injected H field has
        # the frequency-dependent impedance correction baked in.
        if c1_slice is not None and c2_slice is not None and c3_slice is not None:
            filtered = _build_dispersive_H_filter(
                temporal_profile=self.temporal_profile,
                wave_character=self.wave_character,
                dt=self._config.time_step_duration,
                num_time_steps=self._config.time_steps_total,
                c1_slice=c1_slice,
                c2_slice=c2_slice,
                c3_slice=c3_slice,
                inv_eps_inf_slice=inv_eps_inf_slice,
                dtype=self._config.dtype,
            )
            self = self.aset("_temporal_H_filter", filtered, create_new_ok=True)
        else:
            # Reused source applied in a non-dispersive context: clear any stale
            # H-side filter left over from a previous dispersive apply, otherwise
            # the TFSF inner loop would keep injecting filtered amplitudes.
            self = self.aset("_temporal_H_filter", None, create_new_ok=True)

        return self

    @abstractmethod
    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:  # shape (*grid_shape)
        """Return the untilted scalar amplitude profile over the source grid.

        Args:
            center: in-plane source center as produced by ``_get_random_parts``;
                ``apply`` treats it as (horizontal, vertical) cell coordinates.

        Returns:
            Array of shape ``grid_shape``. ``apply`` later resamples it onto the
            tilted wavefront.
        """
        # in normal coordinates, not yee grid
        del center
        raise NotImplementedError()


@autoinit
class GaussianPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a truncated Gaussian amplitude profile.

    The profile is a Gaussian centered at the source center. Its standard
    deviation is ``std * radius``. It is set to zero outside ``radius`` and
    scaled so that its values sum to one.
    """

    #: the radius of the gaussian source (converted to grid cells by dividing
    #: by the simulation resolution)
    radius: float = frozen_field()

    #: the standard deviation of the gaussian source
    std: float = frozen_field(default=1 / 3)  # relative to radius

    @staticmethod
    def _gauss_profile(
        width: int,
        height: int,
        axis: int,
        center: tuple[float, float] | jax.Array,
        radii: tuple[float, float],
        std: float,
    ) -> jax.Array:  # shape (*grid_shape)
        """Return a truncated, sum-normalized 2D Gaussian with a singleton axis inserted.

        Args:
            width: number of cells along the horizontal axis.
            height: number of cells along the vertical axis.
            axis: index at which a size-1 axis is inserted (the propagation axis).
            center: Gaussian center in cells, ordered (horizontal, vertical).
                This is the same order the source uses when it subtracts the
                center from its (horizontal, vertical) plane indices.
            radii: truncation radii in cells, ordered (horizontal, vertical).
            std: standard deviation relative to ``radii``.

        Returns:
            Array of shape (width, height) with a size-1 axis inserted at ``axis``.
        """
        # Build per-cell coordinates ordered (horizontal, vertical). The previous
        # "xy" meshgrid over (height, width) produced (vertical, horizontal)
        # components. Subtracting the (horizontal, vertical) center from those
        # transposed the Gaussian center on non-square planes.
        h_coords, v_coords = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="ij")
        offsets = jnp.stack((h_coords, v_coords), axis=-1) - jnp.asarray(center)
        grid = offsets / jnp.asarray(radii)
        # squared distance from the center, in units of the radius
        euc_dist = (grid**2).sum(axis=-1)
        mask = euc_dist < 1
        mask = jnp.expand_dims(mask, axis=axis)
        exp_part = jnp.exp(-0.5 * euc_dist / std**2)
        exp_part = jnp.expand_dims(exp_part, axis=axis)
        profile = jnp.where(mask, exp_part, 0)
        profile = profile / profile.sum()
        return profile

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return the Gaussian profile over the source grid, centered at ``center``."""
        grid_radius = self.radius / self._config.resolution
        profile = self._gauss_profile(
            width=self.grid_shape[self.horizontal_axis],
            height=self.grid_shape[self.vertical_axis],
            axis=self.propagation_axis,
            center=center,
            radii=(grid_radius, grid_radius),
            std=self.std,
        )
        return profile
@autoinit
class UniformPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source whose amplitude is constant across the plane."""

    #: the amplitude of the uniform source
    amplitude: float = frozen_field(default=1.0)

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return ``amplitude`` replicated over every cell of the source grid.

        The profile does not depend on position, so ``center`` is ignored.
        """
        del center  # a flat profile has no notion of a center
        ones = jnp.ones(shape=self.grid_shape, dtype=self._config.dtype)
        return ones * self.amplitude