import math
from abc import ABC
from typing import Self, Sequence, cast
import jax
import jax.numpy as jnp
from fdtdx.colors import XKCD_LIGHT_PINK, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, field, frozen_field, frozen_private_field
from fdtdx.core.jax.utils import check_specs
from fdtdx.core.misc import expand_matrix, is_float_divisible
from fdtdx.materials import Material
from fdtdx.objects.device.parameters.transform import ParameterTransformation
from fdtdx.objects.object import OrderableObject
from fdtdx.typing import (
INVALID_SHAPE_3D,
UNDEFINED_SHAPE_3D,
ParameterType,
PartialGridShape3D,
PartialRealShape3D,
SliceTuple3D,
)
@autoinit
class Device(OrderableObject, ABC):
    """Abstract base class for devices with optimizable permittivity distributions.

    This class defines the common interface and functionality for both discrete and
    continuous devices that can be optimized through gradient-based methods.
    """

    #: Dictionary of materials to be used in the device.
    materials: dict[str, Material] = field()
    #: A Sequence of parameter transformation to be applied to the parameters when mapping them to simulation materials.
    param_transforms: Sequence[ParameterTransformation] = field()
    #: Color of the object when plotted. Defaults to XKCD_LIGHT_PINK.
    color: Color | None = frozen_field(default=XKCD_LIGHT_PINK)
    #: Size of the material voxels used within the device in simulation voxels (grid cells). Defaults to undefined
    #: shape. For all three axes, either the voxel grid or real shape needs to be defined.
    partial_voxel_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)
    #: Size of the material voxels used within the device in metrical units (meter). On uniform grids the size is
    #: rounded to a whole number of simulation voxels. Defaults to undefined shape. For all three axes, either the
    #: voxel grid or real shape needs to be defined.
    partial_voxel_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)

    # Size of one design voxel in simulation grid cells; resolved in place_on_grid.
    _single_voxel_grid_shape: tuple[int, int, int] = frozen_private_field(default=INVALID_SHAPE_3D)
    # Explicit design-matrix shape; only set for physical-unit design grids, otherwise INVALID_SHAPE_3D.
    _matrix_voxel_grid_shape_override: tuple[int, int, int] = frozen_private_field(default=INVALID_SHAPE_3D)
    # Physical design-voxel size in metres; only set for physical-unit design grids.
    _physical_design_voxel_shape: tuple[float, float, float] | None = frozen_private_field(default=None)
    # Physical extent of the design domain in metres; only set for physical-unit design grids.
    _physical_design_domain_shape: tuple[float, float, float] | None = frozen_private_field(default=None)

    # TODO(teevee112): support physical-unit design voxels on non-uniform grids — requires a resampling layer
    # or snapping to RectilinearGrid cell boundaries; currently only grid-cell-count voxels are
    # reliable on non-uniform grids (see PR #312)

    @property
    def matrix_voxel_grid_shape(self) -> tuple[int, int, int]:
        """Calculate the shape of the voxel matrix in grid coordinates.

        Returns:
            tuple[int, int, int]: Tuple of (x,y,z) dimensions representing how many voxels fit in each direction
                of the grid shape when divided by the single voxel shape.
        """
        return (
            self._matrix_voxel_grid_shape_override
            if self._matrix_voxel_grid_shape_override != INVALID_SHAPE_3D
            else (
                round(self.grid_shape[0] / self.single_voxel_grid_shape[0]),
                round(self.grid_shape[1] / self.single_voxel_grid_shape[1]),
                round(self.grid_shape[2] / self.single_voxel_grid_shape[2]),
            )
        )

    @property
    def single_voxel_grid_shape(self) -> tuple[int, int, int]:
        """Get the shape of a single voxel in grid coordinates.

        Returns:
            tuple[int, int, int]: Tuple of (x,y,z) dimensions for one voxel.

        Raises:
            Exception: If the device has not been placed on the grid yet.
        """
        if self._single_voxel_grid_shape == INVALID_SHAPE_3D:
            raise Exception(f"{self} is not initialized yet")
        return self._single_voxel_grid_shape

    @property
    def single_voxel_real_shape(self) -> tuple[float, float, float]:
        """Calculate the representative physical size of one design voxel.

        Returns:
            Tuple of ``(x, y, z)`` dimensions in metres.

        Notes:
            On uniform simulation grids this is the exact size of each design
            voxel. On non-uniform grids, devices are currently supported only
            when design voxels are specified by simulation-cell counts. The
            returned physical size is then the average design-voxel extent over
            the placed device, suitable for transforms that need a representative
            scale. True physical-size design voxels still require a resampling
            layer and are rejected during placement.
        """
        if self._physical_design_voxel_shape is not None:
            return self._physical_design_voxel_shape
        grid = self._config.resolved_grid
        if grid is not None and not grid.is_uniform:
            # Average extent: total physical size divided by the number of design voxels per axis.
            return (
                self.real_shape[0] / self.matrix_voxel_grid_shape[0],
                self.real_shape[1] / self.matrix_voxel_grid_shape[1],
                self.real_shape[2] / self.matrix_voxel_grid_shape[2],
            )
        single_voxel_shape = self.single_voxel_grid_shape
        spacing = self._config.uniform_spacing()
        return (
            single_voxel_shape[0] * spacing,
            single_voxel_shape[1] * spacing,
            single_voxel_shape[2] * spacing,
        )

    @property
    def output_type(self) -> ParameterType:
        """Parameter type produced by the last parameter transformation.

        Returns:
            ``ParameterType.CONTINUOUS`` if there are no transformations, otherwise the
            output type of the last transformation (a single-entry dict is unwrapped).

        Raises:
            Exception: If the last transformation does not output a single array type.
        """
        if not self.param_transforms:
            return ParameterType.CONTINUOUS
        out_type = self.param_transforms[-1]._output_type
        if isinstance(out_type, dict) and len(out_type) == 1:
            out_type = next(iter(out_type.values()))
        if not isinstance(out_type, ParameterType):
            raise Exception(
                "Output of Parameter transformation sequence (last module) needs to be a single array, but got: "
                f"{out_type}"
            )
        return out_type

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the device on the simulation grid and initialize its parameter transformations.

        Besides the placement done by the parent class, this resolves the size of one design
        voxel (in simulation cells) from ``partial_voxel_grid_shape`` / ``partial_voxel_real_shape``,
        checks that the design voxels tile the device region exactly, and initializes the
        parameter transformations (shapes backwards, parameter types forwards).

        Args:
            grid_slice_tuple: Grid slices occupied by the device along each axis.
            config: Simulation configuration.
            key: JAX random key, forwarded to the parent placement.

        Returns:
            The placed device.

        Raises:
            ValueError: If voxel sizes are non-positive, or physical and grid-count voxel sizes are mixed
                on a non-uniform grid.
            NotImplementedError: If physical-unit voxel sizes are requested on a non-uniform grid.
            Exception: If a voxel axis is over- or under-specified, the voxels do not tile the device, or a
                continuous output is combined with a material count other than two.
        """
        self = super().place_on_grid(grid_slice_tuple=grid_slice_tuple, config=config, key=key)

        # determine voxel shape
        voxel_grid_shape = []
        spacing = None
        if not config.has_nonuniform_grid:
            spacing = config.uniform_spacing()
        uses_physical_design_grid = config.has_nonuniform_grid and any(
            shape is not None for shape in self.partial_voxel_real_shape
        )
        if uses_physical_design_grid and any(shape is not None for shape in self.partial_voxel_grid_shape):
            raise ValueError(
                "Non-uniform physical device voxel sizes cannot be mixed with partial_voxel_grid_shape. "
                "Use either a physical design grid or grid-cell-count design voxels."
            )
        if uses_physical_design_grid:
            raise NotImplementedError(
                "Physical-unit design voxels (partial_voxel_real_shape) are not yet supported on "
                "non-uniform grids. Use partial_voxel_grid_shape to specify voxel sizes in "
                "simulation grid-cell counts instead."
            )
        # NOTE: because of the guard above, the ``uses_physical_design_grid`` branches below are currently
        # unreachable; they are kept for when physical-unit design voxels on non-uniform grids are supported.

        physical_design_shape = []
        matrix_shape_override = []
        for axis in range(3):
            partial_grid = self.partial_voxel_grid_shape[axis]
            partial_real = self.partial_voxel_real_shape[axis]
            if partial_grid is not None and partial_real is not None:
                raise Exception(f"Multi-Material voxels overspecified in axis: {axis=}")
            if partial_grid is not None:
                # A zero or negative cell count cannot tile the device (zero would later divide by zero).
                if partial_grid < 1:
                    raise ValueError(
                        f"Device voxel grid shape must be a positive number of cells, got {partial_grid} "
                        f"on axis {axis}."
                    )
                voxel_grid_shape.append(partial_grid)
            elif partial_real is not None:
                if uses_physical_design_grid:
                    if partial_real <= 0:
                        raise ValueError(f"Physical design voxel size must be positive for {axis=}.")
                    physical_design_shape.append(float(partial_real))
                    matrix_shape_override.append(max(1, math.ceil(self.real_shape[axis] / partial_real)))
                    voxel_grid_shape.append(1)
                else:
                    # spacing is only None on non-uniform grids, which were rejected above for real shapes.
                    assert spacing is not None
                    cell_count = round(partial_real / spacing)
                    if cell_count < 1:
                        raise ValueError(
                            f"Device voxel size {partial_real:.3e} m on axis {axis} rounds to 0 cells "
                            f"at spacing {spacing:.3e} m. Increase the voxel size or reduce the grid spacing."
                        )
                    voxel_grid_shape.append(cell_count)
            else:
                raise Exception(f"Multi-Material voxels not specified in axis: {axis=}")

        self = self.aset("_single_voxel_grid_shape", tuple(voxel_grid_shape))
        if uses_physical_design_grid:
            self = self.aset("_physical_design_voxel_shape", tuple(physical_design_shape))
            self = self.aset("_physical_design_domain_shape", self.real_shape)
            self = self.aset("_matrix_voxel_grid_shape_override", tuple(matrix_shape_override))

        # sanity checks on the voxel shape
        for axis in range(3):
            if spacing is not None:
                float_div = is_float_divisible(
                    self.single_voxel_real_shape[axis],
                    spacing,
                    tolerance=max(1e-15, abs(spacing) * 1e-4),
                )
                if not float_div:
                    raise Exception(f"Not divisible: {self.single_voxel_real_shape[axis]=}, {spacing=}")
            if uses_physical_design_grid:
                continue
            # Compare the reconstructed extent instead of ``grid_shape % matrix_shape``: the matrix shape is
            # derived with ``round``, so e.g. 10 cells with 4-cell voxels yields a matrix of 2, and
            # ``10 % 2 == 0`` would wrongly pass although 2 * 4 != 10.
            if self.matrix_voxel_grid_shape[axis] * self.single_voxel_grid_shape[axis] != self.grid_shape[axis]:
                raise Exception(
                    f"Due to discretization, matrix got skewered for {axis=}. "
                    f"{self.grid_shape=}, {self.matrix_voxel_grid_shape=}, {self.single_voxel_grid_shape=}"
                )

        # init parameter transformations
        # We need to go once backwards through the transformations to determine the shape of the latent parameters
        # then we need to go forward through the transformations again to determine the parameter type of the
        # output
        new_t_list: list[ParameterTransformation] = []
        cur_shape = {"params": self.matrix_voxel_grid_shape}
        for transform in self.param_transforms[::-1]:
            t_new = transform.init_module(
                config=config,
                materials=self.materials,
                matrix_voxel_grid_shape=self.matrix_voxel_grid_shape,
                single_voxel_size=self.single_voxel_real_shape,
                output_shape=cast(dict[str, tuple[int, ...]], cur_shape),
            )
            new_t_list.append(t_new)
            cur_shape = t_new._input_shape

        # init types of transformations by going forward (reversing the backwards-built list)
        module_list: list[ParameterTransformation] = []
        cur_input_type = {"params": ParameterType.CONTINUOUS}
        for transform in new_t_list[::-1]:
            t_new = transform.init_type(
                input_type=cur_input_type,
            )
            module_list.append(t_new)
            cur_input_type = t_new._output_type

        # set own input shape dtype
        self = self.aset("param_transforms", module_list)

        if self.output_type == ParameterType.CONTINUOUS and len(self.materials) != 2:
            raise Exception(
                f"Need exactly two materials in device when parameter mapping outputs continuous permittivity indices, "
                f"but got {self.materials}"
            )

        return self

    @staticmethod
    def _overlap_weights_1d(sim_edges: jax.Array, design_edges: jax.Array) -> jax.Array:
        """Return design-voxel overlap fractions for each simulation cell.

        Rows correspond to simulation cells and columns correspond to design
        voxels. Each row sums to one for cells contained inside the local
        design domain. Using overlap fractions instead of center sampling keeps
        physical-size design grids conservative on stretched meshes: a large
        simulation cell that straddles multiple design voxels receives the
        volume-weighted average of those parameters.
        """
        sim_lower = sim_edges[:-1, None]
        sim_upper = sim_edges[1:, None]
        design_lower = design_edges[None, :-1]
        design_upper = design_edges[None, 1:]
        overlap = jnp.maximum(0.0, jnp.minimum(sim_upper, design_upper) - jnp.maximum(sim_lower, design_lower))
        widths = sim_upper - sim_lower
        return overlap / widths

    def _resample_design_params_to_sim_grid(self, params: jax.Array) -> jax.Array:
        """Map design-grid parameters to simulation cells.

        Grid-count design voxels use the legacy repeat expansion. Physical
        design voxels on non-uniform grids use separable volume-overlap weights
        so the expanded simulation grid represents the average design parameter
        over each rectilinear simulation cell.

        Raises:
            RuntimeError: If a physical design grid is configured but the design domain shape is missing.
        """
        grid = self._config.resolved_grid
        if self._physical_design_voxel_shape is None or grid is None:
            return expand_matrix(
                matrix=params,
                grid_points_per_voxel=self.single_voxel_grid_shape,
            )
        if self._physical_design_domain_shape is None:
            raise RuntimeError("Physical design-grid devices must be placed before expansion.")
        overlap_weights = []
        for axis in range(3):
            lower, upper = self.grid_slice_tuple[axis]
            sim_edges = grid.edges(axis)[lower : upper + 1]
            design_edges = jnp.linspace(
                0.0,
                self._physical_design_domain_shape[axis],
                self.matrix_voxel_grid_shape[axis] + 1,
                dtype=sim_edges.dtype,
            )
            # Shift simulation edges so the device's lower boundary sits at 0, matching design_edges.
            local_sim_edges = sim_edges - sim_edges[0]
            overlap_weights.append(self._overlap_weights_1d(local_sim_edges, design_edges))
        return jnp.einsum("ia,jb,kc,abc->ijk", overlap_weights[0], overlap_weights[1], overlap_weights[2], params)

    def init_params(
        self,
        key: jax.Array,
    ) -> dict[str, jax.Array] | jax.Array:
        """Sample initial latent parameters uniformly in ``[0, 1)``.

        The shapes are taken from the input shape of the first parameter transformation,
        or from ``matrix_voxel_grid_shape`` when there are no transformations.

        Args:
            key: JAX random key; split once per parameter entry.

        Returns:
            A single float32 array if exactly one parameter entry exists, otherwise a dict
            mapping parameter names to float32 arrays.
        """
        if len(self.param_transforms) > 0:
            shapes = self.param_transforms[0]._input_shape
        else:
            shapes = self.matrix_voxel_grid_shape
        if not isinstance(shapes, dict):
            shapes = {"params": shapes}
        params = {}
        for k, cur_shape in shapes.items():
            key, subkey = jax.random.split(key)
            p = jax.random.uniform(
                key=subkey,
                shape=cur_shape,
                minval=0,  # parameters always live between 0 and 1
                maxval=1,
                dtype=jnp.float32,
            )
            params[k] = p
        if len(params) == 1:
            params = next(iter(params.values()))
        return params

    def __call__(
        self,
        params: dict[str, jax.Array] | jax.Array,
        expand_to_sim_grid: bool = False,
        **transform_kwargs,
    ) -> jax.Array:
        """Map latent parameters through the parameter transformations.

        Args:
            params: Latent parameters, either a single array (treated as ``{"params": params}``)
                or a dict of named arrays.
            expand_to_sim_grid: If True, resample the design-grid result onto the simulation
                cells covered by the device.
            **transform_kwargs: Extra keyword arguments forwarded to every transformation call.

        Returns:
            The single output array of the transformation chain, on the design grid or, when
            ``expand_to_sim_grid`` is True, on the simulation grid.

        Raises:
            Exception: If the transformation chain does not produce exactly one array.
        """
        if not isinstance(params, dict):
            params = {"params": params}
        # walk through modules
        for transform in self.param_transforms:
            check_specs(params, transform._input_shape)
            params_dict = cast(dict[str, jax.Array], params)
            params = transform(params_dict, **transform_kwargs)
            check_specs(params, transform._output_shape)
        if len(params) == 1:
            single_val = next(iter(params.values()))
            assert isinstance(single_val, jax.Array)
            params = single_val
        else:
            raise Exception(
                "The parameter mapping should return a single array of indices. If using a continuous device, please"
                " make sure that the latent transformations abide to this rule."
            )
        if expand_to_sim_grid:
            params = self._resample_design_params_to_sim_grid(params)
        return params