import math
from typing import Literal, Self, Sequence
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
from loguru import logger
from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import autoinit, frozen_field, frozen_private_field
from fdtdx.core.jax.ste import straight_through_estimator
from fdtdx.core.misc import get_background_material_name
from fdtdx.materials import Material, compute_allowed_permittivities, compute_ordered_names
from fdtdx.objects.device.parameters.binary_transform import dilate_jax
from fdtdx.objects.device.parameters.transform import ParameterTransformation
from fdtdx.objects.device.parameters.utils import compute_allowed_indices, nearest_index
from fdtdx.typing import ParameterType
@autoinit
class ClosestIndex(ParameterTransformation):
    """Maps continuous latent values to nearest allowed material indices.

    If ``mapping_from_inverse_permittivities`` is True, every input value is
    compared against the allowed inverse permittivities of the device materials
    and replaced by the index of the closest one. If it is False (default), the
    latent values are only rounded to the closest integer and clipped to the
    valid index range ``[0, number of materials - 1]``.

    In both modes the discrete result is combined with the continuous input via
    a straight-through estimator to maintain differentiability.
    """

    #: If True, latent values are treated as inverse permittivities and mapped to the
    #: index of the closest allowed inverse permittivity. If False (default), latent
    #: values are rounded to the nearest integer index and clipped to the valid range.
    mapping_from_inverse_permittivities: bool = frozen_field(default=False)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        # The transformation is element-wise, so the shapes pass through unchanged.
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return BINARY for exactly two materials, DISCRETE for more.

        Raises:
            Exception: If fewer than two materials are available.
        """
        if len(self._materials) <= 1:
            raise Exception(f"Invalid materials (need two or more): {self._materials}")
        elif len(self._materials) == 2:
            output_type = ParameterType.BINARY
        else:
            output_type = ParameterType.DISCRETE
        return {k: output_type for k in input_type.keys()}

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Quantize every array in ``params`` to material indices.

        Args:
            params (dict[str, jax.Array]): Continuous latent arrays keyed by name.
            **kwargs: Ignored.

        Returns:
            dict[str, jax.Array]: Arrays with the same keys, holding the discrete
            indices combined with the inputs through the straight-through estimator.
        """
        del kwargs

        # The allowed values only depend on the materials, so compute them once
        # per call instead of once per parameter array.
        allowed_inv_perms = None
        if self.mapping_from_inverse_permittivities:
            is_isotropic = all(mat.is_isotropic_permittivity for mat in self._materials.values())
            is_diagonally_anisotropic = all(
                mat.is_diagonally_anisotropic_permittivity for mat in self._materials.values()
            )
            allowed_perm_array = jnp.asarray(
                compute_allowed_permittivities(
                    self._materials, isotropic=is_isotropic, diagonally_anisotropic=is_diagonally_anisotropic
                )
            )
            if is_isotropic:
                # Flatten to one value per material. A trailing component axis of
                # size one would otherwise be broadcast against the last spatial
                # axis of the input instead of the new material axis below.
                allowed_inv_perms = (1 / allowed_perm_array).reshape(-1)
            elif is_diagonally_anisotropic:
                allowed_inv_perms = 1 / allowed_perm_array
            else:
                # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements
                allowed_inv_perms = jnp.array(
                    [jnp.linalg.inv(perm.reshape(3, 3)).flatten() for perm in allowed_perm_array]
                )
            # NOTE(review): the anisotropic branches are kept as they were; presumably
            # the allowed values then carry a trailing component axis - confirm that the
            # argmin over the last axis below is what is intended for those cases.

        def transform_arr(arr: jax.Array) -> jax.Array:
            if allowed_inv_perms is not None:
                # distance of every entry to every allowed value, closest one wins
                dist = jnp.abs(arr[..., None] - allowed_inv_perms)
                discrete = jnp.argmin(dist, axis=-1)
            else:
                discrete = jnp.clip(jnp.round(arr), 0, len(self._materials) - 1)
            return straight_through_estimator(arr, discrete)

        return {k: transform_arr(v) for k, v in params.items()}
@autoinit
class BrushConstraint2D(ParameterTransformation):
    """Applies 2D brush-based constraints to ensure minimum feature sizes.

    Implements the brush-based constraint method described in:
    https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313

    The design is assembled iteratively from brush "touches" of two phases
    (solid and void) using morphological dilation with the brush kernel. The
    latent values decide which touch is preferred whenever a choice has to be
    made. Only devices with exactly two materials are supported.
    """

    #: Array defining the brush kernel for morphological operations.
    brush: jax.Array = frozen_field()

    #: Axis along which to apply the 2D constraint (perpendicular plane).
    #: The parameter array must have size 1 along this axis.
    axis: int = frozen_field()

    #: Name of the background material in the material dictionary of the device.
    #: If None, the name returned by get_background_material_name is used. Defaults to None.
    background_material: str | None = frozen_field(default=None)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )
    _check_single_array: bool = frozen_private_field(default=True)
    _all_arrays_2d: bool = frozen_private_field(default=True)

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        # Shapes are not changed by this transformation.
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return BINARY for every input; raises if the device does not have exactly two materials."""
        if len(self._materials) != 2:
            raise Exception(
                f"BrushConstraint2D currently only implemented for exactly two materials, but got {self._materials}"
            )
        return {name: ParameterType.BINARY for name in input_type}

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Turn the single latent array into a brush-feasible binary design.

        Args:
            params (dict[str, jax.Array]): Dictionary whose first entry is the
                latent array; it must have size 1 along ``self.axis``.
            **kwargs: Ignored.

        Returns:
            dict[str, jax.Array]: The same key mapped to the generated design,
            combined with the input through the straight-through estimator.

        Raises:
            Exception: If the array does not have size 1 along ``self.axis``.
        """
        del kwargs
        single_key = next(iter(params.keys()))
        param_arr = params[single_key]

        s = param_arr.shape
        if s[self.axis] != 1:
            raise Exception(f"BrushConstraint2D Generator needs array size 1 in axis, but got {s=}")
        # drop the singleton axis to obtain the 2D latent plane
        arr_2d = jnp.take(param_arr, jnp.asarray(0), axis=self.axis)

        background_name = (
            get_background_material_name(self._materials)
            if self.background_material is None
            else self.background_material
        )
        background_idx = compute_ordered_names(self._materials).index(background_name)
        # When the background is not the first material in the ordering, the latent
        # sign is flipped before generation and the result is inverted afterwards.
        background_is_first = background_idx == 0

        if not background_is_first:
            arr_2d = -arr_2d
        cur_result = self._generator(arr_2d)
        if not background_is_first:
            cur_result = 1 - cur_result

        cur_result = jnp.expand_dims(cur_result, axis=self.axis)
        return {single_key: straight_through_estimator(param_arr, cur_result)}

    def _generator(
        self,
        arr: jax.Array,
    ) -> jax.Array:
        """Run the iterative touch selection and return the resulting solid pixel mask.

        Args:
            arr (jax.Array): 2D latent values; larger values favour solid touches,
                smaller values favour void touches.

        Returns:
            jax.Array: Boolean mask obtained by dilating the final solid touches
            with the brush.
        """
        brush = self.brush

        def dilate(mask: jax.Array) -> jax.Array:
            return dilate_jax(mask, brush)

        def mark_best(mask: jax.Array, scores: jax.Array) -> jax.Array:
            # set the entry with the highest score (index into the flattened array)
            best = jnp.argmax(scores)
            return mask.flatten().at[best].set(True).reshape(mask.shape)

        def add_single_touch(void_touch, solid_touch, void_candidates, solid_candidates):
            # Candidates are scored by the latent value (solid) or its negative (void);
            # the single best candidate over both phases is added.
            solid_scores = jnp.where(solid_candidates, arr, -jnp.inf)
            void_scores = jnp.where(void_candidates, -arr, -jnp.inf)

            def pick_solid():
                return void_touch, mark_best(solid_touch, solid_scores)

            def pick_void():
                return mark_best(void_touch, void_scores), solid_touch

            return jax.lax.cond(
                jnp.max(solid_scores) > jnp.max(void_scores),
                pick_solid,
                pick_void,
            )

        def not_fully_assigned(state):
            # keep iterating while some pixel is covered by neither phase
            void_touch, solid_touch = state[0], state[1]
            covered = dilate(solid_touch) | dilate(void_touch)
            return ~jnp.all(covered)

        def step(state: tuple[jax.Array, jax.Array]):
            # see Algorithm 1 in paper
            void_touch, solid_touch = state[0], state[1]

            # pixels covered by the touches placed so far
            solid_px = dilate(solid_touch)
            void_px = dilate(void_touch)

            # a touch is valid if it is not placed yet and would not hit the other phase
            valid_solid = ~dilate(void_px) & ~solid_touch
            valid_void = ~dilate(solid_px) & ~void_touch

            # pixels each phase could still reach with its existing and valid touches
            reachable_solid_px = dilate(solid_touch | valid_solid)
            reachable_void_px = dilate(void_touch | valid_void)

            # pixels that only one phase can still cover
            required_solid_px = ~solid_px & ~reachable_void_px
            required_void_px = ~void_px & ~reachable_solid_px

            # valid touches covering a required pixel
            resolving_solid = dilate(required_solid_px) & valid_solid
            resolving_void = dilate(required_void_px) & valid_void

            # valid touches that do not interfere with anything the other phase has or could get
            free_solid = ~dilate(reachable_void_px | void_px) & valid_solid
            free_void = ~dilate(reachable_solid_px | solid_px) & valid_void

            # case 1: take every free touch at once
            def take_all_free():
                return void_touch | free_void, solid_touch | free_solid

            # case 2: take the best resolving touch
            def take_best_resolving():
                return add_single_touch(void_touch, solid_touch, resolving_void, resolving_solid)

            # case 3: take the best valid touch
            def take_best_valid():
                return add_single_touch(void_touch, solid_touch, valid_void, valid_solid)

            # case 2 and 3
            def take_one_touch():
                return jax.lax.cond(
                    jnp.any(resolving_solid | resolving_void),
                    take_best_resolving,
                    take_best_valid,
                )

            next_void, next_solid = jax.lax.cond(
                jnp.any(free_solid | free_void),
                take_all_free,
                take_one_touch,
            )
            return next_void, next_solid

        # start without any touches; state order is (void touches, solid touches)
        no_void_touches = jnp.zeros_like(arr, dtype=jnp.bool)
        no_solid_touches = jnp.zeros_like(no_void_touches)
        final_state = eqxi.while_loop(
            cond_fun=not_fully_assigned,
            body_fun=step,
            init_val=(no_void_touches, no_solid_touches),
            kind="lax",
        )
        return dilate(final_state[1])
def circular_brush(
diameter: float,
size: int | None = None,
) -> jax.Array:
"""Creates a circular binary mask/brush for morphological operations.
Args:
diameter (float): Diameter of the circle in grid units.
size (int | None, optional): Optional size of the output array. If None, uses ceil(diameter) rounded
up to next odd number.
Returns:
jax.Array: Binary array containing a circular mask where True indicates points
within the circle diameter.
"""
if size is None:
s = math.ceil(diameter)
if s % 2 == 0:
s += 1
size = s
xy = jnp.stack(jnp.meshgrid(*map(jnp.arange, (size, size)), indexing="xy"), axis=-1) - jnp.asarray((size / 2) - 0.5)
euc_dist = jnp.sqrt((xy**2).sum(axis=-1))
# the less EQUAL here is important, because otherwise design may be infeasible due to discretization errors
mask = euc_dist <= (diameter / 2)
return mask
@autoinit
class PillarDiscretization(ParameterTransformation):
    """Constraint module for mapping pillar structures to allowed configurations.

    Maps arbitrary pillar structures to the nearest allowed configurations based on
    material constraints and geometry requirements. Ensures structures meet fabrication
    rules like single polymer columns and no trapped air holes.

    The set of allowed columns is computed once in ``init_module``; at call time
    every column along ``axis`` is replaced by the closest allowed column.
    """

    #: Axis along which to enforce pillar constraints (0=x, 1=y, 2=z).
    axis: int = frozen_field()

    #: If True, restrict to single polymer columns.
    single_polymer_columns: bool = frozen_field()

    #: Method to compute distances between material distributions:
    #:
    #: - "euclidean": Standard Euclidean distance between permittivity values.
    #: - "permittivity_differences_plus_average_permittivity": Weighted combination of permittivity differences and average permittivity values, optimized for material distribution comparisons.
    #:
    #: Defaults to "permittivity_differences_plus_average_permittivity".
    distance_metric: Literal["euclidean", "permittivity_differences_plus_average_permittivity"] = frozen_field(
        default="permittivity_differences_plus_average_permittivity",
    )

    #: Name of the background material in the materials dictionary of the corresponding device.
    #: If None, the name returned by get_background_material_name is used. Defaults to None.
    background_material: str | None = frozen_field(default=None)

    _allowed_indices: jax.Array = frozen_private_field()
    _check_single_array: bool = frozen_private_field(default=True)
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        # Shapes are not changed by this transformation.
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return BINARY for exactly two materials and DISCRETE for more; raises for fewer than two."""
        num_materials = len(self._materials)
        if num_materials <= 1:
            raise Exception(f"Invalid materials (need two or more): {self._materials}")
        output_type = ParameterType.BINARY if num_materials == 2 else ParameterType.DISCRETE
        return {name: output_type for name in input_type}

    def init_module(
        self: Self,
        config: SimulationConfig,
        materials: dict[str, Material],
        matrix_voxel_grid_shape: tuple[int, int, int],
        single_voxel_size: tuple[float, float, float],
        output_shape: dict[str, tuple[int, ...]],
    ) -> Self:
        """Initialize the module and precompute the allowed column configurations.

        All arguments are forwarded unchanged to the parent ``init_module``.
        ``materials`` and ``matrix_voxel_grid_shape`` are additionally used here
        to build the allowed columns along ``axis``.

        Returns:
            Self: The initialized module with ``_allowed_indices`` set.
        """
        module = super().init_module(
            config=config,
            materials=materials,
            matrix_voxel_grid_shape=matrix_voxel_grid_shape,
            single_voxel_size=single_voxel_size,
            output_shape=output_shape,
        )

        background_name = (
            get_background_material_name(module._materials)
            if module.background_material is None
            else module.background_material
        )
        background_idx = compute_ordered_names(module._materials).index(background_name)

        # the background index is passed as the index used to fill holes
        allowed_columns = compute_allowed_indices(
            num_layers=matrix_voxel_grid_shape[module.axis],
            indices=list(range(len(materials))),
            fill_holes_with_index=[background_idx],
            single_polymer_columns=module.single_polymer_columns,
        )
        module = module.aset("_allowed_indices", allowed_columns, create_new_ok=True)

        logger.info(f"{allowed_columns=}")
        logger.info(f"{allowed_columns.shape=}")
        return module

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Replace every column of the single parameter array by the nearest allowed column.

        Args:
            params (dict[str, jax.Array]): Dictionary whose first entry is the
                continuous parameter array.
            **kwargs: Ignored.

        Returns:
            dict[str, jax.Array]: The same key mapped to the discretized indices,
            combined with the input through the straight-through estimator.

        Raises:
            Exception: If ``axis`` is not 0, 1 or 2.
        """
        del kwargs
        single_key = next(iter(params.keys()))
        params_arr = params[single_key]

        material_list = self._materials.values()
        is_isotropic = all(mat.is_isotropic_permittivity for mat in material_list)
        is_diagonally_anisotropic = all(mat.is_diagonally_anisotropic_permittivity for mat in material_list)
        allowed_perm_array = jnp.asarray(
            compute_allowed_permittivities(
                self._materials,
                isotropic=is_isotropic,
                diagonally_anisotropic=is_diagonally_anisotropic,
            )
        )

        if is_isotropic:
            # squeeze component dim (n_materials, 1) → (n_materials,) so
            # nearest_index can broadcast correctly against (nx, ny, nz)
            allowed_inv_perms = (1 / allowed_perm_array).squeeze(-1)
        elif is_diagonally_anisotropic:
            allowed_inv_perms = 1 / allowed_perm_array
        else:
            # Fully anisotropic: invert every tensor as a 3x3 matrix and flatten it back to 9 elements
            inverted_tensors = [jnp.linalg.inv(tensor.reshape(3, 3)).flatten() for tensor in allowed_perm_array]
            allowed_inv_perms = jnp.array(inverted_tensors)

        nearest_allowed_index = nearest_index(
            values=params_arr,
            allowed_values=allowed_inv_perms,
            axis=self.axis,
            distance_metric=self.distance_metric,
            allowed_indices=self._allowed_indices,
            return_distances=False,
        )
        # look up the material indices of the selected allowed columns
        result_index = self._allowed_indices[nearest_allowed_index]

        # Permutation applied to the looked-up columns for each pillar axis;
        # None means the result is used as it is (axis 2).
        permutation_by_axis = {0: (2, 0, 1), 1: (0, 2, 1), 2: None}
        if self.axis not in permutation_by_axis:
            raise Exception(f"invalid axis: {self.axis}")
        permutation = permutation_by_axis[self.axis]
        if permutation is not None:
            result_index = jnp.transpose(result_index, axes=permutation)

        return {single_key: straight_through_estimator(params_arr, result_index)}