from abc import ABC
from typing import Self, Sequence, cast
import jax
import jax.numpy as jnp
from fdtdx.colors import XKCD_LIGHT_PINK, Color
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, field, frozen_field, frozen_private_field
from fdtdx.core.jax.utils import check_specs
from fdtdx.core.misc import expand_matrix, is_float_divisible
from fdtdx.materials import Material
from fdtdx.objects.device.parameters.transform import ParameterTransformation
from fdtdx.objects.object import OrderableObject
from fdtdx.typing import (
INVALID_SHAPE_3D,
UNDEFINED_SHAPE_3D,
ParameterType,
PartialGridShape3D,
PartialRealShape3D,
SliceTuple3D,
)
[docs]
@autoinit
class Device(OrderableObject, ABC):
"""Abstract base class for devices with optimizable permittivity distributions.
This class defines the common interface and functionality for both discrete and
continuous devices that can be optimized through gradient-based methods.
"""
#: Dictionary of materials to be used in the device.
materials: dict[str, Material] = field()
#: A Sequence of parameter transformation to be applied to the parameters when mapping them to simulation materials.
param_transforms: Sequence[ParameterTransformation] = field()
#: Color of the object when plotted. Defaults to XKCD_LIGHT_PINK.
color: Color | None = frozen_field(default=XKCD_LIGHT_PINK)
#: Size of the material voxels used within the device in metrical units (meter). Note that this is independent of the simulation voxel size.
#: Defaults to undefined shape. For all three axes, either the voxel grid or real shape needs to be defined.
partial_voxel_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)
#: Size of the material voxels used within the device in simulation voxels. Defaults to undefined shape.
#: For all three axes, either the voxel grid or real shape needs to be defined.
partial_voxel_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)
_single_voxel_grid_shape: tuple[int, int, int] = frozen_private_field(default=INVALID_SHAPE_3D)
@property
def matrix_voxel_grid_shape(self) -> tuple[int, int, int]:
"""Calculate the shape of the voxel matrix in grid coordinates.
Returns:
tuple[int, int, int]: Tuple of (x,y,z) dimensions representing how many voxels fit in each direction
of the grid shape when divided by the single voxel shape.
"""
return (
round(self.grid_shape[0] / self.single_voxel_grid_shape[0]),
round(self.grid_shape[1] / self.single_voxel_grid_shape[1]),
round(self.grid_shape[2] / self.single_voxel_grid_shape[2]),
)
@property
def single_voxel_grid_shape(self) -> tuple[int, int, int]:
"""Get the shape of a single voxel in grid coordinates.
Returns:
tuple[int, int, int]: Tuple of (x,y,z) dimensions for one voxel.
"""
if self._single_voxel_grid_shape == INVALID_SHAPE_3D:
raise Exception(f"{self} is not initialized yet")
return self._single_voxel_grid_shape
@property
def single_voxel_real_shape(self) -> tuple[float, float, float]:
"""Calculate the shape of a single voxel in real (physical) coordinates.
Returns:
tuple[float, float, float]: Tuple of (x,y,z) dimensions in real units, computed by multiplying
the grid shape by the simulation resolution.
"""
grid_shape = self.single_voxel_grid_shape
return (
grid_shape[0] * self._config.resolution,
grid_shape[1] * self._config.resolution,
grid_shape[2] * self._config.resolution,
)
@property
def output_type(self) -> ParameterType:
if not self.param_transforms:
return ParameterType.CONTINUOUS
out_type = self.param_transforms[-1]._output_type
if isinstance(out_type, dict) and len(out_type) == 1:
out_type = next(iter(out_type.values()))
if not isinstance(out_type, ParameterType):
raise Exception(
"Output of Parameter transformation sequence (last module) needs to be a single array, but got: "
f"{out_type}"
)
return out_type
[docs]
def place_on_grid(
self: Self,
grid_slice_tuple: SliceTuple3D,
config: SimulationConfig,
key: jax.Array,
) -> Self:
self = super().place_on_grid(grid_slice_tuple=grid_slice_tuple, config=config, key=key)
# determine voxel shape
voxel_grid_shape = []
for axis in range(3):
partial_grid = self.partial_voxel_grid_shape[axis]
partial_real = self.partial_voxel_real_shape[axis]
if partial_grid is not None and partial_real is not None:
raise Exception(f"Multi-Material voxels overspecified in axis: {axis=}")
if partial_grid is not None:
voxel_grid_shape.append(partial_grid)
elif partial_real is not None:
voxel_grid_shape.append(round(partial_real / config.resolution))
else:
raise Exception(f"Multi-Material voxels not specified in axis: {axis=}")
self = self.aset("_single_voxel_grid_shape", tuple(voxel_grid_shape))
# sanity checks on the voxel shape
for axis in range(3):
float_div = is_float_divisible(
self.single_voxel_real_shape[axis],
self._config.resolution,
)
if not float_div:
raise Exception(f"Not divisible: {self.single_voxel_real_shape[axis]=}, {self._config.resolution=}")
if self.grid_shape[axis] % self.matrix_voxel_grid_shape[axis] != 0:
raise Exception(
f"Due to discretization, matrix got skewered for {axis=}. "
f"{self.grid_shape=}, {self.matrix_voxel_grid_shape=}"
)
# init parameter transformations
# We need to go once backwards through the transformations to determine the shape of the latent parameters
# then we need to go forward through the transformations again to determine the parameter type of the
# output
new_t_list: list[ParameterTransformation] = []
cur_shape = {"params": self.matrix_voxel_grid_shape}
for transform in self.param_transforms[::-1]:
t_new = transform.init_module(
config=config,
materials=self.materials,
matrix_voxel_grid_shape=self.matrix_voxel_grid_shape,
single_voxel_size=self.single_voxel_real_shape,
output_shape=cast(dict[str, tuple[int, ...]], cur_shape),
)
new_t_list.append(t_new)
cur_shape = t_new._input_shape
# init shape of transformations by going backwards through new list
module_list: list[ParameterTransformation] = []
cur_input_type = {"params": ParameterType.CONTINUOUS}
for transform in new_t_list[::-1]:
t_new = transform.init_type(
input_type=cur_input_type,
)
module_list.append(t_new)
cur_input_type = t_new._output_type
# set own input shape dtype
self = self.aset("param_transforms", module_list)
if self.output_type == ParameterType.CONTINUOUS and len(self.materials) != 2:
raise Exception(
f"Need exactly two materials in device when parameter mapping outputs continuous permittivity indices, "
f"but got {self.materials}"
)
return self
[docs]
def init_params(
self,
key: jax.Array,
) -> dict[str, jax.Array] | jax.Array:
if len(self.param_transforms) > 0:
shapes = self.param_transforms[0]._input_shape
else:
shapes = self.matrix_voxel_grid_shape
if not isinstance(shapes, dict):
shapes = {"params": shapes}
params = {}
for k, cur_shape in shapes.items():
key, subkey = jax.random.split(key)
p = jax.random.uniform(
key=subkey,
shape=cur_shape,
minval=0, # parameter always live between 0 and 1
maxval=1,
dtype=jnp.float32,
)
params[k] = p
if len(params) == 1:
params = next(iter(params.values()))
return params
def __call__(
self,
params: dict[str, jax.Array] | jax.Array,
expand_to_sim_grid: bool = False,
**transform_kwargs,
) -> jax.Array:
if not isinstance(params, dict):
params = {"params": params}
# walk through modules
for transform in self.param_transforms:
check_specs(params, transform._input_shape)
params_dict = cast(dict[str, jax.Array], params)
params = transform(params_dict, **transform_kwargs)
check_specs(params, transform._output_shape)
if len(params) == 1:
single_val = next(iter(params.values()))
assert isinstance(single_val, jax.Array)
params = single_val
else:
raise Exception(
"The parameter mapping should return a single array of indices. If using a continuous device, please"
" make sure that the latent transformations abide to this rule."
)
if expand_to_sim_grid:
params = expand_matrix(
matrix=params,
grid_points_per_voxel=self.single_voxel_grid_shape,
)
return params