# Source code for fdtdx.interfaces.modules

from abc import ABC, abstractmethod
from typing import Self, Sequence

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import TreeClass, autoinit, frozen_field, frozen_private_field
from fdtdx.interfaces.state import RecordingState


@autoinit
class CompressionModule(TreeClass, ABC):
    """Interface for modules that compress and decompress recorded field data.

    A concrete subclass implements one compression scheme, for example
    quantization, dimensionality reduction, or another transform. The scheme is
    applied to field values during an FDTD simulation through ``compress`` and
    undone through ``decompress``. Before either is used, ``init_shapes`` tells
    the module which fields (and which shapes/dtypes) it will receive.
    """

    # Shape/dtype of each input field. Defaults to ``None``; subclasses are
    # expected to fill it in from ``init_shapes``.
    _input_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type: ignore
    # Shape/dtype of each output field. Defaults to ``None``; subclasses are
    # expected to fill it in from ``init_shapes``.
    _output_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type: ignore

    @abstractmethod
    def init_shapes(
        self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    ) -> tuple[
        Self,
        dict[str, jax.ShapeDtypeStruct],  # shapes/dtypes of the compressed data
        dict[str, jax.ShapeDtypeStruct],  # shapes/dtypes of the module's state
    ]:
        """Prepare the module for a particular set of input fields.

        Args:
            input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Shape and dtype of every
                input field, keyed by field name.

        Returns:
            tuple[Self, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]:
                - the updated module instance,
                - the shape/dtype of every compressed output field, keyed by name,
                - the shape/dtype of every state entry the module needs, keyed by name.
        """
        del input_shape_dtypes  # unused by the abstract interface
        raise NotImplementedError

    @abstractmethod
    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],  # compressed field values
        RecordingState,  # recording state after this step
    ]:
        """Compress the field values of the current time step.

        Args:
            values (dict[str, jax.Array]): Field values keyed by field name.
            state (RecordingState): The recording state before this step.
            key (jax.Array): PRNG key available to stochastic compression schemes.

        Returns:
            tuple[dict[str, jax.Array], RecordingState]:
                - the compressed field values,
                - the (possibly updated) recording state.
        """
        del values, state, key  # unused by the abstract interface
        raise NotImplementedError

    @abstractmethod
    def decompress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> dict[str, jax.Array]:
        """Undo ``compress`` and recover the original field values.

        Args:
            values (dict[str, jax.Array]): Compressed field values keyed by field name.
            state (RecordingState): The current recording state.
            key (jax.Array): PRNG key available to stochastic decompression schemes.

        Returns:
            dict[str, jax.Array]: The decompressed field values keyed by field name.
        """
        del values, state, key  # unused by the abstract interface
        raise NotImplementedError


@autoinit
class DtypeConversion(CompressionModule):
    """Compression module that converts the data type of field values.

    Every field keeps its shape; only its dtype changes. This helps reduce memory
    usage or meet a precision requirement. ``decompress`` casts values back to the
    dtypes recorded by ``init_shapes``.

    Fields listed in ``exclude_filter`` are passed through unchanged. Matching is
    by substring, not exact name: a field is excluded when any entry of
    ``exclude_filter`` occurs anywhere in its name (e.g. ``"E"`` excludes both
    ``"E"`` and ``"Ex"``).
    """

    #: Target data type for conversion.
    dtype: jnp.dtype = frozen_field(kind="KW_ONLY")
    #: Substrings of field names to exclude from conversion. ``None`` is treated
    #: like an empty sequence.
    exclude_filter: Sequence[str] = frozen_field(default=(), kind="KW_ONLY")

    def _is_excluded(self, name: str) -> bool:
        """Return True if any entry of ``exclude_filter`` is a substring of ``name``."""
        # Normalize ``None`` in one place. Previously only ``init_shapes`` did
        # this, so ``compress`` raised TypeError when exclude_filter was None.
        patterns = () if self.exclude_filter is None else self.exclude_filter
        return any(pattern in name for pattern in patterns)

    def init_shapes(
        self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    ) -> tuple[
        Self,
        dict[str, jax.ShapeDtypeStruct],  # data
        dict[str, jax.ShapeDtypeStruct],  # state shapes/dtypes
    ]:
        """Record the input shapes/dtypes and derive the converted output shapes/dtypes.

        Args:
            input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Shape and dtype of every
                input field, keyed by field name.

        Returns:
            tuple[Self, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]:
                - the updated module, with ``_input_shape_dtypes`` and
                  ``_output_shape_dtypes`` set,
                - the output shape/dtype per field: excluded fields keep their input
                  ``ShapeDtypeStruct``, all others keep their shape with ``self.dtype``,
                - an empty state dict (this module keeps no state).

        Raises:
            ValueError: If a non-excluded input has a complex dtype while ``self.dtype``
                is real, because the cast would silently drop the imaginary part.
        """
        self = self.aset("_input_shape_dtypes", input_shape_dtypes)
        target_is_complex = jnp.issubdtype(self.dtype, jnp.complexfloating)

        out_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = {}
        for k, v in input_shape_dtypes.items():
            if self._is_excluded(k):
                out_shape_dtypes[k] = v
                continue
            if jnp.issubdtype(v.dtype, jnp.complexfloating) and not target_is_complex:
                raise ValueError(
                    f"DtypeConversion target dtype {self.dtype} is real but input '{k}' "
                    f"has complex dtype {v.dtype}. This would silently discard the imaginary "
                    f"component. Use a complex target dtype or add '{k}' to exclude_filter."
                )
            out_shape_dtypes[k] = jax.ShapeDtypeStruct(v.shape, self.dtype)

        self = self.aset("_output_shape_dtypes", out_shape_dtypes)
        return self, self._output_shape_dtypes, {}

    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],
        RecordingState,
    ]:
        """Cast every non-excluded field to ``self.dtype``.

        Args:
            values (dict[str, jax.Array]): Field values keyed by field name.
            state (RecordingState): Recording state; returned unchanged.
            key (jax.Array): Unused; the conversion is deterministic.

        Returns:
            tuple[dict[str, jax.Array], RecordingState]: The converted values
            (excluded fields passed through as-is) and the unchanged state.
        """
        del key
        out_vals = {k: (v if self._is_excluded(k) else v.astype(self.dtype)) for k, v in values.items()}
        return out_vals, state

    def decompress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> dict[str, jax.Array]:
        """Cast every field back to the dtype it had when ``init_shapes`` was called.

        Requires ``init_shapes`` to have been called first, and every key in
        ``values`` must have been among its inputs.

        Args:
            values (dict[str, jax.Array]): Compressed field values keyed by field name.
            state (RecordingState): Unused.
            key (jax.Array): Unused.

        Returns:
            dict[str, jax.Array]: Values cast to their recorded input dtypes.
        """
        del key, state
        # Excluded fields already carry their input dtype, so casting them is a no-op.
        out_vals = {k: v.astype(self._input_shape_dtypes[k].dtype) for k, v in values.items()}
        return out_vals