Source code for fdtdx.objects.device.parameters.continuous

import math
from typing import Sequence

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import autoinit, frozen_field, frozen_private_field
from fdtdx.objects.device.parameters.transform import ParameterTransformation, SameShapeTypeParameterTransform
from fdtdx.typing import ParameterType


[docs] @autoinit class StandardToInversePermittivityRange(ParameterTransformation): """Maps standard [0,1] range to inverse permittivity range. Linearly maps values from [0,1] to the range between minimum and maximum inverse permittivity values allowed by the material configuration. For anisotropic materials, each axis (x, y, z) is interpolated independently within its own min/max range, producing output with shape ``(3, *input_shape)``. """ _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field( default=ParameterType.CONTINUOUS ) def _get_input_shape_impl( self, output_shape: dict[str, tuple[int, ...]], ) -> dict[str, tuple[int, ...]]: # Input shape is same as output for isotropic, or without leading 3 for anisotropic # Since we don't have access to materials yet at this point in some cases, # we return the output shape as-is (input param has spatial shape, output may have component dim) return output_shape def _get_output_type_impl( self, input_type: dict[str, ParameterType], ) -> dict[str, ParameterType]: # Output type is same as input (continuous -> continuous) return input_type def __call__( self, params: dict[str, jax.Array], **kwargs, ) -> dict[str, jax.Array]: del kwargs is_isotropic = all(mat.is_isotropic_permittivity for mat in self._materials.values()) is_diagonally_anisotropic = all(mat.is_diagonally_anisotropic_permittivity for mat in self._materials.values()) if is_isotropic: # Isotropic case: all materials have same permittivity on all axes max_inv_perm, min_inv_perm = -math.inf, math.inf for v in self._materials.values(): # For isotropic, all components are equal, just use first p = 1 / v.permittivity[0] if p > max_inv_perm: max_inv_perm = p if p < min_inv_perm: min_inv_perm = p result = {} for k, v in params.items(): mapped = v * (max_inv_perm - min_inv_perm) + min_inv_perm result[k] = mapped return result elif is_diagonally_anisotropic: # Compute min/max for each axis separately max_inv_perm = [-math.inf, -math.inf, -math.inf] min_inv_perm = [math.inf, math.inf, math.inf] for v in self._materials.values(): # v.permittivity is 9-tuple (εxx,...,εyy,...,εzz); diagonal at indices 0, 4, 8 for axis in range(3): p = 1 / v.permittivity[axis * 4] if p > max_inv_perm[axis]: max_inv_perm[axis] = p if p < min_inv_perm[axis]: min_inv_perm[axis] = p max_inv_perm_arr = jnp.asarray(max_inv_perm)[:, None, None, None] min_inv_perm_arr = jnp.asarray(min_inv_perm)[:, None, None, None] # Transform: broadcast input to (3, ...) and interpolate each axis result = {} for k, v in params.items(): # v has shape (Nx, Ny, Nz), expand to (3, Nx, Ny, Nz) v_expanded = v[None, ...] # (1, Nx, Ny, Nz) mapped = v_expanded * (max_inv_perm_arr - min_inv_perm_arr) + min_inv_perm_arr result[k] = mapped return result else: # fully anisotropic # Compute min/max for each tensor element separately max_inv_perm = [ -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, -math.inf, ] min_inv_perm = [math.inf, math.inf, math.inf, math.inf, math.inf, math.inf, math.inf, math.inf, math.inf] for v in self._materials.values(): # v.permittivity is tuple (εxx, εxy, εxz, εyx, εyy, εyz, εzx, εzy, εzz) inv_perm = jnp.linalg.inv(jnp.array(v.permittivity).reshape(3, 3)).flatten() for i in range(9): if inv_perm[i] > max_inv_perm[i]: max_inv_perm[i] = inv_perm[i] if inv_perm[i] < min_inv_perm[i]: min_inv_perm[i] = inv_perm[i] max_inv_perm_arr = jnp.asarray(max_inv_perm)[:, None, None, None] min_inv_perm_arr = jnp.asarray(min_inv_perm)[:, None, None, None] # Transform: broadcast input to (9, ...) and interpolate each element result = {} for k, v in params.items(): # v has shape (Nx, Ny, Nz), expand to (9, Nx, Ny, Nz) v_expanded = v[None, ...] # (1, Nx, Ny, Nz) mapped = v_expanded * (max_inv_perm_arr - min_inv_perm_arr) + min_inv_perm_arr result[k] = mapped return result
[docs] @autoinit class StandardToCustomRange(SameShapeTypeParameterTransform): """Maps standard [0,1] range to custom range [min_value, max_value]. Linearly maps values from [0,1] to a custom range specified by min_value and max_value parameters. """ #: Minimum value of target range. Defaults to zero. min_value: float = frozen_field(default=0) #: Maximum value of target range. Defaults to one. max_value: float = frozen_field(default=1) _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field( default=ParameterType.CONTINUOUS ) def __call__( self, params: dict[str, jax.Array], **kwargs, ) -> dict[str, jax.Array]: del kwargs result = {} for k, v in params.items(): mapped = v * (self.max_value - self.min_value) + self.min_value result[k] = mapped return result
[docs] @autoinit class StandardToPlusOneMinusOneRange(StandardToCustomRange): """Maps standard [0,1] range to [-1,1] range. Special case of StandardToCustomRange that maps to [-1,1] range. Used for symmetric value ranges around zero. """ min_value: float = frozen_private_field(default=-1) max_value: float = frozen_private_field(default=1)
[docs] @autoinit class GaussianSmoothing2D(SameShapeTypeParameterTransform): """ Applies Gaussian smoothing to 2D parameter arrays. This transform convolves the input with a 2D Gaussian kernel, which helps reduce noise and smooth the data. """ #: Integer specifying the standard deviation of the Gaussian kernel in discrete units. std_discrete: int = frozen_field() #: 1D array of shape ``(ny,)`` used as padding before axis 0. ``None`` falls back to edge-repeat. padding_low_axis0: jax.Array | None = frozen_field(default=None) #: 1D array of shape ``(ny,)`` used as padding after axis 0. ``None`` falls back to edge-repeat. padding_high_axis0: jax.Array | None = frozen_field(default=None) #: 1D array of shape ``(nx,)`` used as padding before axis 1. ``None`` falls back to edge-repeat. padding_low_axis1: jax.Array | None = frozen_field(default=None) #: 1D array of shape ``(nx,)`` used as padding after axis 1. ``None`` falls back to edge-repeat. padding_high_axis1: jax.Array | None = frozen_field(default=None) _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field( default=ParameterType.CONTINUOUS ) _all_arrays_2d: bool = frozen_private_field(default=True) def __call__( self, params: dict[str, jax.Array], **kwargs, ) -> dict[str, jax.Array]: del kwargs return {k: self._apply_smoothing(v) for k, v in params.items()} def _apply_smoothing(self, x: jax.Array) -> jax.Array: vertical_axis = x.shape.index(1) x_squeezed = x.squeeze(vertical_axis) if x_squeezed.ndim != 2: raise ValueError(f"Expected 2D array, got shape {x_squeezed.shape}") kernel_size = 6 * self.std_discrete + 1 kernel = self._create_gaussian_kernel(kernel_size, self.std_discrete) pad_w = kernel_size // 2 nx, ny = x_squeezed.shape # Pad axis 0 (row dimension) if self.padding_low_axis0 is not None: block_low0 = jnp.tile(self.padding_low_axis0[jnp.newaxis, :], (pad_w, 1)) else: block_low0 = jnp.tile(x_squeezed[0:1, :], (pad_w, 1)) if self.padding_high_axis0 is not None: block_high0 = jnp.tile(self.padding_high_axis0[jnp.newaxis, :], (pad_w, 1)) else: block_high0 = jnp.tile(x_squeezed[-1:, :], (pad_w, 1)) arr = jnp.concatenate([block_low0, x_squeezed, block_high0], axis=0) # Pad axis 1 (column dimension); extend 1D arrays with their edge values to cover corners if self.padding_low_axis1 is not None: corners_lo = jnp.full((pad_w,), self.padding_low_axis1[0]) corners_hi = jnp.full((pad_w,), self.padding_low_axis1[-1]) extended = jnp.concatenate([corners_lo, self.padding_low_axis1, corners_hi]) block_low1 = jnp.tile(extended[:, jnp.newaxis], (1, pad_w)) else: block_low1 = jnp.tile(arr[:, 0:1], (1, pad_w)) if self.padding_high_axis1 is not None: corners_lo = jnp.full((pad_w,), self.padding_high_axis1[0]) corners_hi = jnp.full((pad_w,), self.padding_high_axis1[-1]) extended = jnp.concatenate([corners_lo, self.padding_high_axis1, corners_hi]) block_high1 = jnp.tile(extended[:, jnp.newaxis], (1, pad_w)) else: block_high1 = jnp.tile(arr[:, -1:], (1, pad_w)) arr = jnp.concatenate([block_low1, arr, block_high1], axis=1) result = jax.scipy.signal.convolve(arr, kernel, mode="same") result = result[pad_w : pad_w + nx, pad_w : pad_w + ny] return result.reshape(x.shape) def _create_gaussian_kernel(self, size: int, sigma: float) -> jax.Array: # Create a coordinate grid coords = jnp.arange(-(size // 2), size // 2 + 1) x, y = jnp.meshgrid(coords, coords) # Create the Gaussian kernel kernel = jnp.exp(-(x**2 + y**2) / (2 * sigma**2)) # Normalize the kernel to sum to 1 kernel = kernel / jnp.sum(kernel) return kernel