Source code for fdtdx.core.physics.metrics

"""Electromagnetic field metrics and normalization utilities.

This module provides functions for computing various electromagnetic field metrics
like energy density and Poynting flux, as well as field normalization operations.
All functions support JAX's automatic differentiation and work with the standard
FDTD field array shapes.
"""

import jax
import jax.numpy as jnp

from fdtdx.core.misc import expand_to_3x3


[docs] def compute_energy( E: jax.Array, H: jax.Array, inv_permittivity: jax.Array | float, inv_permeability: jax.Array | float, axis: int = 0, ) -> jax.Array: """Computes the total electromagnetic energy density of the field. Args: E (jax.Array): Electric field array with shape (3, nx, ny, nz) H (jax.Array): Magnetic field array with shape (3, nx, ny, nz) inv_permittivity (jax.Array | float): Inverse permittivity. Shape (3, nx, ny, nz) for anisotropic or scalar inv_permeability (jax.Array | float): Inverse permeability. Shape (3, nx, ny, nz) for anisotropic or scalar axis (int, optional): Axis index of the X,Y,Z component for the E and H field. Defaults to 0. Returns: jax.Array: Total energy density array with shape (nx, ny, nz) """ inv_eps_shape = getattr(inv_permittivity, "shape", ()) inv_mu_shape = getattr(inv_permeability, "shape", ()) if (inv_eps_shape and inv_eps_shape[0] == 9) or (inv_mu_shape and inv_mu_shape[0] == 9): inv_permittivity = expand_to_3x3(inv_permittivity) inv_permeability = expand_to_3x3(inv_permeability) # Invert the 3x3 matrices to get eps and mu perm = (2, 3, 4, 0, 1) # (3, 3, nx, ny, nz) -> (nx, ny, nz, 3, 3) inv_perm = (3, 4, 0, 1, 2) # (nx, ny, nz, 3, 3) -> (3, 3, nx, ny, nz) eps = jnp.linalg.inv(inv_permittivity.transpose(perm)).transpose(inv_perm) mu = jnp.linalg.inv(inv_permeability.transpose(perm)).transpose(inv_perm) # For fully anisotropic materials: 0.5 * sum_ij E_i* ε_ij E_j # E/H has shape (3, nx, ny, nz), eps/mu has shape (3, 3, nx, ny, nz) energy_E = 0.5 * jnp.real(jnp.einsum("ixyz,ijxyz,jxyz->xyz", jnp.conj(E), eps, E)) energy_H = 0.5 * jnp.real(jnp.einsum("ixyz,ijxyz,jxyz->xyz", jnp.conj(H), mu, H)) total_energy = energy_E + energy_H return total_energy else: # For anisotropic materials: energy = 0.5 * sum_i(ε_i * |E_i|² + μ_i * |H_i|²) # Component-wise calculation then sum E_squared = jnp.square(jnp.abs(E)) # shape: (3, nx, ny, nz) energy_E = 0.5 * (1 / inv_permittivity) * E_squared # component-wise multiplication energy_E = jnp.sum(energy_E, axis=axis) # sum over components H_squared = jnp.square(jnp.abs(H)) # shape: (3, nx, ny, nz) energy_H = 0.5 * (1 / inv_permeability) * H_squared # component-wise multiplication energy_H = jnp.sum(energy_H, axis=axis) # sum over components total_energy = energy_E + energy_H return total_energy
[docs] def normalize_by_energy( E: jax.Array, H: jax.Array, inv_permittivity: jax.Array | float, inv_permeability: jax.Array | float, ) -> tuple[jax.Array, jax.Array]: """Normalizes electromagnetic fields by their total energy. Args: E (jax.Array): Electric field array with shape (3, nx, ny, nz) H (jax.Array): Magnetic field array with shape (3, nx, ny, nz) inv_permittivity (jax.Array | float): Inverse of the electric permittivity array inv_permeability (jax.Array | float): Inverse of the magnetic permeability array Returns: tuple[jax.Array, jax.Array]: Tuple of (normalized E field, normalized H field) """ total_energy = compute_energy( E=E, H=H, inv_permittivity=inv_permittivity, inv_permeability=inv_permeability, ) energy_root = jnp.sqrt(jnp.sum(total_energy)) norm_E = E / energy_root norm_H = H / energy_root return norm_E, norm_H
[docs] def compute_poynting_flux(E: jax.Array, H: jax.Array, axis: int = 0) -> jax.Array: """Calculates the Poynting vector (energy flux) from E and H fields. Args: E (jax.Array): Electric field array with shape (3, nx, ny, nz) H (jax.Array): Magnetic field array with shape (3, nx, ny, nz) axis (int, optional): Axis for computing the poynting flux. Defaults to 0. Returns: jax.Array: Poynting vector array with shape (3, nx, ny, nz) representing energy flux in each direction """ return jnp.cross( E, jnp.conj(H), axisa=axis, axisb=axis, axisc=axis, )
[docs] def normalize_by_poynting_flux( E: jax.Array, H: jax.Array, axis: int, area_weights: jax.Array | None = None, ) -> tuple[jax.Array, jax.Array]: """Normalize fields so the integrated Poynting flux along ``axis`` is one. Args: E: Electric field array with component axis first. H: Magnetic field array with component axis first. axis: Physical propagation axis whose Poynting component is integrated. area_weights: Optional detector-plane area weights broadcastable to ``E[axis]``. Uniform-grid callers may omit this for the historical raw-sum normalization; non-uniform callers should provide weights so refinement alone does not change the normalization. """ # Compute Poynting vector components S_complex = jnp.cross(jnp.conj(E), H, axisa=0, axisb=0, axisc=0) S_real = 0.5 * jnp.real(S_complex[axis]) # power flow in desired direction if area_weights is not None: S_real = S_real * area_weights # Integrate over transverse plane (axis orthogonal to `axis`) power = jnp.abs(jnp.sum(S_real)) # Normalize norm_factor = jnp.sqrt(power) E_norm = E / norm_factor H_norm = H / norm_factor return E_norm, H_norm
def resample_to_uniform_2d( field: jax.Array, x_centers: jax.Array, y_centers: jax.Array, ) -> tuple[jax.Array, jax.Array | float, jax.Array | float]: """Resample a ``(component, x, y)`` field onto a uniform transverse grid. Interpolates each component from the given non-uniform physical center coordinates onto a uniform grid with the same extent and number of points. Returns the resampled field and the uniform grid spacings. """ nx, ny = field.shape[1], field.shape[2] target_x = jnp.linspace(x_centers[0], x_centers[-1], nx) target_y = jnp.linspace(y_centers[0], y_centers[-1], ny) def interp_x(component): return jax.vmap(lambda column: jnp.interp(target_x, x_centers, column), in_axes=1, out_axes=1)(component) def interp_y(component): return jax.vmap(lambda row: jnp.interp(target_y, y_centers, row))(component) field = jax.vmap(interp_x)(field) field = jax.vmap(interp_y)(field) dx = (target_x[1] - target_x[0]) if nx > 1 else 1.0 dy = (target_y[1] - target_y[0]) if ny > 1 else 1.0 return field, dx, dy