"""Parameter transformation base classes (module ``fdtdx.objects.device.parameters.transform``)."""

from abc import ABC, abstractmethod
from typing import Self, Sequence

import jax

from fdtdx.config import SimulationConfig
from fdtdx.core.jax.pytrees import TreeClass, autoinit, frozen_private_field
from fdtdx.materials import Material
from fdtdx.typing import ParameterType


@autoinit
class ParameterTransformation(TreeClass, ABC):
    """Abstract base class for transformations of named device parameter arrays.

    A transformation maps a dict of named input arrays to a dict of named output
    arrays. Subclasses implement ``_get_input_shape_impl``,
    ``_get_output_type_impl`` and ``__call__``. The public ``get_input_shape``
    and ``get_output_type`` methods first run the validation selected by the
    ``_check_single_array``, ``_fixed_input_type`` and ``_all_arrays_2d``
    settings, then delegate to the corresponding ``*_impl`` hook.
    """

    _input_type: dict[str, ParameterType] = frozen_private_field()
    _input_shape: dict[str, tuple[int, ...]] = frozen_private_field()
    _output_type: dict[str, ParameterType] = frozen_private_field()
    _output_shape: dict[str, tuple[int, ...]] = frozen_private_field()
    _materials: dict[str, Material] = frozen_private_field()
    _config: SimulationConfig = frozen_private_field()
    _matrix_voxel_grid_shape: tuple[int, int, int] = frozen_private_field()
    _single_voxel_size: tuple[float, float, float] = frozen_private_field()

    # settings
    # If True, get_output_type requires the input dict to hold exactly one array.
    _check_single_array: bool = frozen_private_field(default=False)
    # If set, every input type must equal this type (or be one of them, if a sequence).
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(default=None)
    # If True, get_input_shape requires every output shape to be 3D with exactly one axis of size 1.
    _all_arrays_2d: bool = frozen_private_field(default=False)

    def init_module(
        self: Self,
        config: SimulationConfig,
        materials: dict[str, Material],
        matrix_voxel_grid_shape: tuple[int, int, int],
        single_voxel_size: tuple[float, float, float],
        output_shape: dict[str, tuple[int, ...]],
    ) -> Self:
        """Attach the simulation context and derive the expected input shapes.

        Args:
            config: Simulation configuration, stored as ``_config``.
            materials: Materials by name, stored as ``_materials``.
            matrix_voxel_grid_shape: Voxel grid shape, stored as ``_matrix_voxel_grid_shape``.
            single_voxel_size: Size of one voxel, stored as ``_single_voxel_size``.
            output_shape: Shape of each output array, stored as ``_output_shape``
                and passed to ``get_input_shape`` to compute ``_input_shape``.

        Returns:
            The updated transformation. ``aset`` is used functionally here (its
            result is reassigned), so callers must use the returned object.

        Raises:
            ValueError: If ``output_shape`` fails the 2D check enabled by
                ``_all_arrays_2d``.
        """
        self = self.aset("_config", config, create_new_ok=True)
        self = self.aset("_materials", materials, create_new_ok=True)
        self = self.aset("_matrix_voxel_grid_shape", matrix_voxel_grid_shape, create_new_ok=True)
        self = self.aset("_single_voxel_size", single_voxel_size, create_new_ok=True)
        self = self.aset("_output_shape", output_shape, create_new_ok=True)
        # Computed only after the context above is attached, because the
        # subclass hook may read it.
        input_shape = self.get_input_shape(output_shape)
        self = self.aset("_input_shape", input_shape, create_new_ok=True)
        return self

    def init_type(
        self,
        input_type: dict[str, ParameterType],
    ) -> Self:
        """Record the input parameter types and derive the output types.

        Args:
            input_type: Parameter type of each input array, stored as ``_input_type``.

        Returns:
            The updated transformation, with ``_output_type`` set from
            ``get_output_type(input_type)``.

        Raises:
            ValueError: If ``input_type`` fails the validation in ``get_output_type``.
        """
        self = self.aset("_input_type", input_type, create_new_ok=True)
        output_type = self.get_output_type(input_type)
        self = self.aset("_output_type", output_type, create_new_ok=True)
        return self

    def get_output_type(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Validate ``input_type`` against the settings and compute the output types.

        Args:
            input_type: Parameter type of each input array.

        Returns:
            The result of ``_get_output_type_impl(input_type)``.

        Raises:
            ValueError: If ``_check_single_array`` is set and ``input_type`` does
                not have exactly one entry, or if ``_fixed_input_type`` is set and
                some input type is not allowed by it.
        """
        if self._check_single_array and len(input_type) != 1:
            raise ValueError(
                f"ParameterTransform {self.__class__} expects input to be a single array, but got: {input_type}"
            )
        fixed = self._fixed_input_type
        if fixed is not None:
            for v in input_type.values():
                # A sequence lists the allowed types; a single type must match exactly.
                mismatch = v not in fixed if isinstance(fixed, Sequence) else v != fixed
                if mismatch:
                    raise ValueError(
                        f"ParameterTransform {self.__class__} expects all input types to be {self._fixed_input_type}"
                        f", but got {input_type}"
                    )
        return self._get_output_type_impl(input_type)

    def get_input_shape(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Validate ``output_shape`` against the settings and compute the input shapes.

        Args:
            output_shape: Shape of each output array.

        Returns:
            The result of ``_get_input_shape_impl(output_shape)``.

        Raises:
            ValueError: If ``_all_arrays_2d`` is set and some shape is not a
                3-tuple with exactly one axis of size 1.
        """
        if self._all_arrays_2d:
            for v in output_shape.values():
                # Exactly one of the three axes must be 1, i.e. exactly two are not.
                if len(v) != 3 or sum(n != 1 for n in v) != 2:
                    raise ValueError(
                        f"ParameterTransform {self.__class__} expects to work with 2d arrays, so exactly one axis of the "
                        f"3d array needs to have size of 1, but got: {output_shape}"
                    )
        return self._get_input_shape_impl(output_shape)

    @abstractmethod
    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Return the input shape per array that produces ``output_shape`` (subclass hook)."""
        raise NotImplementedError()

    @abstractmethod
    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return the output parameter type per array for ``input_type`` (subclass hook)."""
        raise NotImplementedError()

    @abstractmethod
    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Apply the transformation to the named parameter arrays."""
        raise NotImplementedError()
@autoinit
class SameShapeTypeParameterTransform(ParameterTransformation, ABC):
    """Base class for transformations that keep each array's shape and parameter type.

    Shape and type inference are both the identity: the dict passed in is
    returned as-is (the same object, not a copy). Concrete subclasses still
    have to provide ``__call__``.
    """

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        # Parameter types pass through unchanged.
        return input_type

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        # Input shape equals output shape for every array.
        return output_shape