Source code for fdtdx.objects.device.device

import math
from abc import ABC
from typing import Self, Sequence, cast

import jax
import jax.numpy as jnp

from fdtdx.colors import XKCD_LIGHT_PINK, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, field, frozen_field, frozen_private_field
from fdtdx.core.jax.utils import check_specs
from fdtdx.core.misc import expand_matrix, is_float_divisible
from fdtdx.materials import Material
from fdtdx.objects.device.parameters.transform import ParameterTransformation
from fdtdx.objects.object import OrderableObject
from fdtdx.typing import (
    INVALID_SHAPE_3D,
    UNDEFINED_SHAPE_3D,
    ParameterType,
    PartialGridShape3D,
    PartialRealShape3D,
    SliceTuple3D,
)


@autoinit
class Device(OrderableObject, ABC):
    """Abstract base class for devices with optimizable permittivity distributions.

    This class defines the common interface and functionality for both discrete and continuous devices that can be
    optimized through gradient-based methods.

    A device holds a matrix of latent design parameters. Each entry of that matrix (a "design voxel") spans
    ``single_voxel_grid_shape`` simulation cells per axis, and the matrix itself has shape
    ``matrix_voxel_grid_shape``. The parameters are mapped to simulation materials by the ``param_transforms`` chain.
    """

    #: Dictionary of materials to be used in the device.
    materials: dict[str, Material] = field()
    #: A Sequence of parameter transformation to be applied to the parameters when mapping them to simulation materials.
    param_transforms: Sequence[ParameterTransformation] = field()
    #: Color of the object when plotted. Defaults to XKCD_LIGHT_PINK.
    color: Color | None = frozen_field(default=XKCD_LIGHT_PINK)
    #: Size of the material voxels used within the device in simulation voxels (number of grid cells per axis).
    #: Defaults to undefined shape. For all three axes, either the voxel grid or real shape needs to be defined.
    partial_voxel_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)
    #: Size of the material voxels used within the device in metrical units (meter). Note that this is independent of
    #: the simulation voxel size. Defaults to undefined shape. For all three axes, either the voxel grid or real shape
    #: needs to be defined.
    partial_voxel_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)

    # Per-axis size of one design voxel in simulation cells; INVALID_SHAPE_3D until ``place_on_grid`` sets it.
    _single_voxel_grid_shape: tuple[int, int, int] = frozen_private_field(default=INVALID_SHAPE_3D)
    # Explicit design-matrix shape; when set it replaces ``grid_shape / single_voxel_grid_shape``.
    _matrix_voxel_grid_shape_override: tuple[int, int, int] = frozen_private_field(default=INVALID_SHAPE_3D)
    # Physical design-voxel size and design-domain extent. Only set for physical design grids on non-uniform meshes,
    # which ``place_on_grid`` currently rejects with NotImplementedError.
    _physical_design_voxel_shape: tuple[float, float, float] | None = frozen_private_field(default=None)
    _physical_design_domain_shape: tuple[float, float, float] | None = frozen_private_field(default=None)

    # TODO(teevee112): support physical-unit design voxels on non-uniform grids — requires a resampling layer
    #   or snapping to RectilinearGrid cell boundaries; currently only grid-cell-count voxels are
    #   reliable on non-uniform grids (see PR #312)

    @property
    def matrix_voxel_grid_shape(self) -> tuple[int, int, int]:
        """Calculate the shape of the voxel matrix in grid coordinates.

        Returns:
            tuple[int, int, int]: Tuple of (x,y,z) dimensions representing how many voxels fit in each direction of
            the grid shape when divided by the single voxel shape (rounded to the nearest integer). If an explicit
            override was stored during placement, that override is returned instead.
        """
        return (
            self._matrix_voxel_grid_shape_override
            if self._matrix_voxel_grid_shape_override != INVALID_SHAPE_3D
            else (
                round(self.grid_shape[0] / self.single_voxel_grid_shape[0]),
                round(self.grid_shape[1] / self.single_voxel_grid_shape[1]),
                round(self.grid_shape[2] / self.single_voxel_grid_shape[2]),
            )
        )

    @property
    def single_voxel_grid_shape(self) -> tuple[int, int, int]:
        """Get the shape of a single voxel in grid coordinates.

        Returns:
            tuple[int, int, int]: Tuple of (x,y,z) dimensions for one voxel.

        Raises:
            Exception: If the device has not been placed on the grid yet.
        """
        if self._single_voxel_grid_shape == INVALID_SHAPE_3D:
            raise Exception(f"{self} is not initialized yet")
        return self._single_voxel_grid_shape

    @property
    def single_voxel_real_shape(self) -> tuple[float, float, float]:
        """Calculate the representative physical size of one design voxel.

        Returns:
            Tuple of ``(x, y, z)`` dimensions in metres.

        Notes:
            On uniform simulation grids this is the exact size of each design voxel. On non-uniform grids, devices
            are currently supported only when design voxels are specified by simulation-cell counts. The returned
            physical size is then the average design-voxel extent over the placed device, suitable for transforms
            that need a representative scale. True physical-size design voxels still require a resampling layer and
            are rejected during placement.
        """
        if self._physical_design_voxel_shape is not None:
            return self._physical_design_voxel_shape
        grid = self._config.resolved_grid
        if grid is not None and not grid.is_uniform:
            return (
                self.real_shape[0] / self.matrix_voxel_grid_shape[0],
                self.real_shape[1] / self.matrix_voxel_grid_shape[1],
                self.real_shape[2] / self.matrix_voxel_grid_shape[2],
            )
        single_voxel_shape = self.single_voxel_grid_shape
        spacing = self._config.uniform_spacing()
        return (
            single_voxel_shape[0] * spacing,
            single_voxel_shape[1] * spacing,
            single_voxel_shape[2] * spacing,
        )

    @property
    def output_type(self) -> ParameterType:
        """Parameter type produced by the end of the ``param_transforms`` chain.

        Returns:
            ParameterType: ``ParameterType.CONTINUOUS`` when there are no transformations; otherwise the output type
            of the last transformation, unwrapped if it is a single-entry dict.

        Raises:
            Exception: If the last transformation does not output a single array.
        """
        if not self.param_transforms:
            return ParameterType.CONTINUOUS
        out_type = self.param_transforms[-1]._output_type
        if isinstance(out_type, dict) and len(out_type) == 1:
            out_type = next(iter(out_type.values()))
        if not isinstance(out_type, ParameterType):
            raise Exception(
                "Output of Parameter transformation sequence (last module) needs to be a single array, but got: "
                f"{out_type}"
            )
        return out_type

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the device on the simulation grid and initialize its parameter transformations.

        After delegating to the parent placement, this does three things:

        * It resolves the per-axis design-voxel size in simulation cells. The size comes either directly from
          ``partial_voxel_grid_shape``, or from ``partial_voxel_real_shape`` divided by the uniform grid spacing
          and rounded.
        * It verifies that the design voxels tile the device grid shape exactly.
        * It initializes every parameter transformation in two passes: first backwards to derive input shapes,
          then forwards to derive parameter types.

        Args:
            grid_slice_tuple: Grid slices occupied by the device.
            config: Simulation configuration.
            key: JAX random key forwarded to the parent placement.

        Returns:
            Self: The placed device with resolved voxel shapes and initialized transformations.

        Raises:
            ValueError: Raised in any of these cases:

                * Physical and grid-count voxel sizes are mixed on a non-uniform grid.
                * A grid-count voxel size is below one cell.
                * A physical voxel size rounds to zero cells.
            NotImplementedError: If physical voxel sizes are requested on a non-uniform grid.
            Exception: Raised in any of these cases:

                * An axis has its voxel size over- or under-specified.
                * The voxel sizes are not divisible by the spacing.
                * The design voxels do not tile the device exactly.
                * A continuous output is used with a number of materials other than two.
        """
        self = super().place_on_grid(grid_slice_tuple=grid_slice_tuple, config=config, key=key)

        # A scalar spacing only exists on uniform grids; it converts physical voxel sizes into cell counts.
        spacing = None
        if not config.has_nonuniform_grid:
            spacing = config.uniform_spacing()

        uses_physical_design_grid = config.has_nonuniform_grid and any(
            shape is not None for shape in self.partial_voxel_real_shape
        )
        if uses_physical_design_grid and any(shape is not None for shape in self.partial_voxel_grid_shape):
            raise ValueError(
                "Non-uniform physical device voxel sizes cannot be mixed with partial_voxel_grid_shape. "
                "Use either a physical design grid or grid-cell-count design voxels."
            )
        if uses_physical_design_grid:
            raise NotImplementedError(
                "Physical-unit design voxels (partial_voxel_real_shape) are not yet supported on "
                "non-uniform grids. Use partial_voxel_grid_shape to specify voxel sizes in "
                "simulation grid-cell counts instead."
            )

        # determine voxel shape per axis
        voxel_grid_shape: list[int] = []
        physical_design_shape: list[float] = []
        matrix_shape_override: list[int] = []
        for axis in range(3):
            partial_grid = self.partial_voxel_grid_shape[axis]
            partial_real = self.partial_voxel_real_shape[axis]
            if partial_grid is not None and partial_real is not None:
                raise Exception(f"Multi-Material voxels overspecified in axis: {axis=}")
            if partial_grid is not None:
                # A zero/negative cell count would otherwise surface later as a division by zero
                # (matrix_voxel_grid_shape) or a negative matrix shape.
                if partial_grid < 1:
                    raise ValueError(
                        f"Device voxel grid shape must be at least one grid cell, but got {partial_grid} "
                        f"on axis {axis}."
                    )
                voxel_grid_shape.append(partial_grid)
            elif partial_real is not None:
                if uses_physical_design_grid:
                    # NOTE: unreachable while the NotImplementedError above is in place; kept for the planned
                    # physical design-grid support (see the TODO on the class).
                    if partial_real <= 0:
                        raise ValueError(f"Physical design voxel size must be positive for {axis=}.")
                    physical_design_shape.append(float(partial_real))
                    matrix_shape_override.append(max(1, math.ceil(self.real_shape[axis] / partial_real)))
                    voxel_grid_shape.append(1)
                else:
                    # Reaching here implies a uniform grid, so ``spacing`` was set above.
                    assert spacing is not None
                    cell_count = round(partial_real / spacing)
                    if cell_count < 1:
                        raise ValueError(
                            f"Device voxel size {partial_real:.3e} m on axis {axis} rounds to 0 cells "
                            f"at spacing {spacing:.3e} m. Increase the voxel size or reduce the grid spacing."
                        )
                    voxel_grid_shape.append(cell_count)
            else:
                raise Exception(f"Multi-Material voxels not specified in axis: {axis=}")

        self = self.aset("_single_voxel_grid_shape", tuple(voxel_grid_shape))
        if uses_physical_design_grid:
            self = self.aset("_physical_design_voxel_shape", tuple(physical_design_shape))
            self = self.aset("_physical_design_domain_shape", self.real_shape)
            self = self.aset("_matrix_voxel_grid_shape_override", tuple(matrix_shape_override))

        # sanity checks on the voxel shape
        for axis in range(3):
            if spacing is not None:
                float_div = is_float_divisible(
                    self.single_voxel_real_shape[axis],
                    spacing,
                    tolerance=max(1e-15, abs(spacing) * 1e-4),
                )
                if not float_div:
                    raise Exception(f"Not divisible: {self.single_voxel_real_shape[axis]=}, {spacing=}")
            # The design voxels must tile the device exactly. Checking ``grid_shape % matrix_shape`` is not enough,
            # because ``matrix_voxel_grid_shape`` is rounded. For example, 10 cells with 4-cell voxels gives a
            # 2-voxel matrix (10 % 2 == 0), but repeat expansion of that matrix only covers 8 cells.
            if not uses_physical_design_grid and (
                self.matrix_voxel_grid_shape[axis] * self.single_voxel_grid_shape[axis] != self.grid_shape[axis]
            ):
                raise Exception(
                    f"Due to discretization, matrix got skewered for {axis=}. "
                    f"{self.grid_shape=}, {self.single_voxel_grid_shape=}, {self.matrix_voxel_grid_shape=}"
                )

        # init parameter transformations
        # Pass 1 (backwards): start from the design-matrix shape and let each transformation derive the shape of
        # its input from the shape of its output.
        new_t_list: list[ParameterTransformation] = []
        cur_shape = {"params": self.matrix_voxel_grid_shape}
        for transform in self.param_transforms[::-1]:
            t_new = transform.init_module(
                config=config,
                materials=self.materials,
                matrix_voxel_grid_shape=self.matrix_voxel_grid_shape,
                single_voxel_size=self.single_voxel_real_shape,
                output_shape=cast(dict[str, tuple[int, ...]], cur_shape),
            )
            new_t_list.append(t_new)
            cur_shape = t_new._input_shape

        # Pass 2 (forwards): ``new_t_list`` is in reverse order, so iterate it reversed to propagate parameter types
        # from the continuous latent input to the final output.
        module_list: list[ParameterTransformation] = []
        cur_input_type = {"params": ParameterType.CONTINUOUS}
        for transform in new_t_list[::-1]:
            t_new = transform.init_type(
                input_type=cur_input_type,
            )
            module_list.append(t_new)
            cur_input_type = t_new._output_type

        self = self.aset("param_transforms", module_list)

        if self.output_type == ParameterType.CONTINUOUS and len(self.materials) != 2:
            raise Exception(
                f"Need exactly two materials in device when parameter mapping outputs continuous permittivity indices, "
                f"but got {self.materials}"
            )
        return self

    @staticmethod
    def _overlap_weights_1d(sim_edges: jax.Array, design_edges: jax.Array) -> jax.Array:
        """Return design-voxel overlap fractions for each simulation cell.

        Rows correspond to simulation cells and columns correspond to design voxels. Each row sums to one for cells
        contained inside the local design domain. Using overlap fractions instead of center sampling keeps
        physical-size design grids conservative on stretched meshes: a large simulation cell that straddles multiple
        design voxels receives the volume-weighted average of those parameters.
        """
        sim_lower = sim_edges[:-1, None]
        sim_upper = sim_edges[1:, None]
        design_lower = design_edges[None, :-1]
        design_upper = design_edges[None, 1:]
        overlap = jnp.maximum(0.0, jnp.minimum(sim_upper, design_upper) - jnp.maximum(sim_lower, design_lower))
        widths = sim_upper - sim_lower
        return overlap / widths

    def _resample_design_params_to_sim_grid(self, params: jax.Array) -> jax.Array:
        """Map design-grid parameters to simulation cells.

        Grid-count design voxels use the legacy repeat expansion. Physical design voxels on non-uniform grids use
        separable volume-overlap weights so the expanded simulation grid represents the average design parameter
        over each rectilinear simulation cell.

        Raises:
            RuntimeError: If physical design voxels are configured but the design-domain extent was never stored.
        """
        grid = self._config.resolved_grid
        if self._physical_design_voxel_shape is None or grid is None:
            return expand_matrix(
                matrix=params,
                grid_points_per_voxel=self.single_voxel_grid_shape,
            )
        if self._physical_design_domain_shape is None:
            raise RuntimeError("Physical design-grid devices must be placed before expansion.")

        overlap_weights = []
        for axis in range(3):
            lower, upper = self.grid_slice_tuple[axis]
            # ``upper - lower`` cells are bounded by ``upper - lower + 1`` edges.
            sim_edges = grid.edges(axis)[lower : upper + 1]
            design_edges = jnp.linspace(
                0.0,
                self._physical_design_domain_shape[axis],
                self.matrix_voxel_grid_shape[axis] + 1,
                dtype=sim_edges.dtype,
            )
            # Shift simulation edges so the device's lower corner sits at the design-domain origin.
            local_sim_edges = sim_edges - sim_edges[0]
            overlap_weights.append(self._overlap_weights_1d(local_sim_edges, design_edges))

        return jnp.einsum("ia,jb,kc,abc->ijk", overlap_weights[0], overlap_weights[1], overlap_weights[2], params)

    def init_params(
        self,
        key: jax.Array,
    ) -> dict[str, jax.Array] | jax.Array:
        """Randomly initialize the latent device parameters.

        One array is drawn for every input of the first transformation, or for the design matrix itself when
        there are no transformations. Each array is sampled uniformly in ``[0, 1)`` as ``float32``.

        Args:
            key: JAX random key.

        Returns:
            A single array if there is exactly one input, otherwise a dict mapping input names to arrays.
        """
        if len(self.param_transforms) > 0:
            shapes = self.param_transforms[0]._input_shape
        else:
            shapes = self.matrix_voxel_grid_shape
        if not isinstance(shapes, dict):
            shapes = {"params": shapes}
        params = {}
        for k, cur_shape in shapes.items():
            key, subkey = jax.random.split(key)
            p = jax.random.uniform(
                key=subkey,
                shape=cur_shape,
                minval=0,  # parameter always live between 0 and 1
                maxval=1,
                dtype=jnp.float32,
            )
            params[k] = p
        if len(params) == 1:
            params = next(iter(params.values()))
        return params

    def __call__(
        self,
        params: dict[str, jax.Array] | jax.Array,
        expand_to_sim_grid: bool = False,
        **transform_kwargs,
    ) -> jax.Array:
        """Map latent parameters through the transformation chain.

        Args:
            params: Latent parameters, either a single array or a dict of named arrays.
            expand_to_sim_grid: If True, resample the design-grid result onto the device's simulation cells.
            **transform_kwargs: Extra keyword arguments forwarded to every transformation.

        Returns:
            jax.Array: The single output array of the transformation chain.

        Raises:
            Exception: If the chain does not end in exactly one array.
        """
        if not isinstance(params, dict):
            params = {"params": params}
        # walk through modules, validating shapes before and after each one
        for transform in self.param_transforms:
            check_specs(params, transform._input_shape)
            params_dict = cast(dict[str, jax.Array], params)
            params = transform(params_dict, **transform_kwargs)
            check_specs(params, transform._output_shape)

        if len(params) == 1:
            single_val = next(iter(params.values()))
            assert isinstance(single_val, jax.Array)
            params = single_val
        else:
            raise Exception(
                "The parameter mapping should return a single array of indices. If using a continuous device, please"
                " make sure that the latent transformations abide to this rule."
            )
        if expand_to_sim_grid:
            params = self._resample_design_params_to_sim_grid(params)
        return params