Source code for fdtdx.interfaces.recorder

from typing import Self, Sequence

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import TreeClass, autoinit, frozen_private_field
from fdtdx.core.jax.utils import check_shape_dtype
from fdtdx.interfaces.modules import CompressionModule
from fdtdx.interfaces.state import RecordingState, init_recording_state
from fdtdx.interfaces.time_filter import TimeStepFilter
from fdtdx.typing import BackendOption


@autoinit
class Recorder(TreeClass):
    """Pipeline that records simulation data over time and compresses it on the way.

    A Recorder owns an ordered sequence of modules that is applied to the
    simulation data at every time step. Two kinds of modules are supported:
    compression modules, which reduce the size of the data, and time step
    filters, which control when data is recorded. The recorder takes care of
    initializing the pipeline and of compressing and decompressing the
    simulation data through it.
    """

    #: Sequence of processing modules to apply to the simulation data. Can be either CompressionModule for data reduction
    #: or TimeStepFilter for controlling recording frequency.
    modules: Sequence[CompressionModule | TimeStepFilter]
    # The private fields below are filled in by init_state().
    _input_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type:ignore
    _output_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type:ignore
    _max_time_steps: int = frozen_private_field()
    _latent_array_size: int = frozen_private_field()

    def init_state(
        self: Self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
        max_time_steps: int,
        backend: BackendOption,
    ) -> tuple[Self, RecordingState]:
        """Initialize all modules of the pipeline and allocate the recording state.

        The input shapes are threaded through ``init_shapes`` of every module in
        order. Compression modules only transform the shapes, while every other
        module (time step filter) additionally receives and returns the size of
        the latent time axis. The state shapes reported by the modules are
        merged into a single dictionary.

        Args:
            input_shape_dtypes: Shape/dtype description of the values that will
                be passed to ``compress``.
            max_time_steps: Number of time steps; used as the initial size of
                the latent time axis before any time step filter is applied.
            backend: Backend option forwarded to ``init_recording_state``.

        Returns:
            A tuple of the updated recorder (initialized modules and private
            shape fields set) and the freshly initialized recording state.
        """
        recorder = self.aset("_max_time_steps", max_time_steps, create_new_ok=True)
        recorder = recorder.aset("_input_shape_dtypes", input_shape_dtypes, create_new_ok=True)

        current_shapes = input_shape_dtypes
        array_size = max_time_steps
        module_state_shapes: dict[str, jax.ShapeDtypeStruct] = {}
        initialized_modules = []
        for module in recorder.modules:
            if isinstance(module, CompressionModule):
                module, current_shapes, extra_state_shapes = module.init_shapes(current_shapes)
            else:
                module, array_size, current_shapes, extra_state_shapes = module.init_shapes(
                    current_shapes, array_size
                )
            module_state_shapes.update(extra_state_shapes)
            initialized_modules.append(module)

        recorder = recorder.aset("modules", initialized_modules)
        recorder = recorder.aset("_output_shape_dtypes", current_shapes)
        recorder = recorder.aset("_latent_array_size", array_size)

        # Every recorded entry gets a leading axis with one slot per latent index.
        data_shapes = {}
        for name, spec in recorder._output_shape_dtypes.items():
            data_shapes[name] = jax.ShapeDtypeStruct(
                shape=(recorder._latent_array_size, *spec.shape),
                dtype=spec.dtype,
            )

        recording_state = init_recording_state(
            data_shape_dtypes=data_shapes,
            state_shape_dtypes=module_state_shapes,
            backend=backend,
        )
        return recorder, recording_state

    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        time_step: jax.Array,
        key: jax.Array,
    ) -> RecordingState:
        """Run the values of one time step through the pipeline and store the result.

        The latent index starts out as ``time_step`` and is remapped by every
        time step filter via ``time_to_array_index``. As soon as the latent
        index equals -1, the remaining modules are skipped (zero-filled
        placeholders are passed on instead) and the state data is left
        untouched. Presumably a filter returns -1 for time steps that should
        not be recorded; confirm against the TimeStepFilter implementations.

        Args:
            values: Simulation data of the current time step; must match the
                input shapes given to ``init_state``.
            state: Current recording state.
            time_step: Current simulation time step.
            key: Random key; split once per module.

        Returns:
            The updated recording state.
        """
        check_shape_dtype(values, self._input_shape_dtypes)
        latent_idx = time_step

        def apply_module(module, vals, rec_state, arr_idx, rng):
            if isinstance(module, CompressionModule):
                vals, rec_state = module.compress(vals, rec_state, key=rng)
            elif isinstance(module, TimeStepFilter):
                vals, rec_state = module.compress(vals, rec_state, arr_idx, key=rng)
                arr_idx = module.time_to_array_index(arr_idx)
            else:
                raise Exception(f"Invalid module: {module}")
            check_shape_dtype(vals, module._output_shape_dtypes)
            return vals, rec_state, arr_idx

        def skip_module(module, vals, rec_state, arr_idx, rng):
            del rng
            # Zero arrays are created only for keys present in the incoming
            # values, so that the output structure matches the one produced by
            # apply_module (needed for periodic boundaries).
            placeholders = {}
            for name, spec in module._output_shape_dtypes.items():
                if name in vals:
                    placeholders[name] = jnp.zeros(spec.shape, spec.dtype)
            check_shape_dtype(placeholders, module._output_shape_dtypes)
            return placeholders, rec_state, arr_idx

        for module in self.modules:
            key, module_key = jax.random.split(key)
            is_skipped = latent_idx == -1
            values, state, latent_idx = jax.lax.cond(
                is_skipped,
                skip_module,
                apply_module,
                module,
                values,
                state,
                latent_idx,
                module_key,
            )
            check_shape_dtype(values, module._output_shape_dtypes)

        def write_values(rec_state, vals, arr_idx):
            # Only entries that actually exist in the values are written.
            for name in rec_state.data.keys():
                if name not in vals:
                    continue
                rec_state.data[name] = rec_state.data[name].at[arr_idx].set(vals[name])
            return rec_state

        def keep_state(rec_state, vals, arr_idx):
            del vals, arr_idx
            return rec_state

        state = jax.lax.cond(
            latent_idx == -1,
            keep_state,
            write_values,
            state,
            values,
            latent_idx,
        )
        return state

    def decompress(
        self,
        state: RecordingState,
        time_step: jax.Array,
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],
        RecordingState,
    ]:
        """Reconstruct the values of a single time step from the recording state.

        First, the array indices required for the reconstruction are gathered by
        asking every time step filter (in pipeline order) which indices it needs
        via ``indices_to_decompress``. Afterwards the modules are traversed in
        reverse order, starting from the stored latent data, and each module
        undoes its own compression step.

        Args:
            state: Recording state holding the compressed data.
            time_step: Time step to reconstruct.
            key: Random key; split once per module.

        Returns:
            A tuple of the reconstructed values and the recording state.

        Raises:
            Exception: If the reconstruction does not end with exactly one
                value dictionary.
        """
        # gather indices necessary to reconstruct
        time_filters = [mod for mod in self.modules if isinstance(mod, TimeStepFilter)]
        indices: list[jax.Array] = [jnp.asarray([time_step])]
        time_indices: list[jax.Array] = []
        for time_filter in time_filters:
            flat_time_indices = indices[-1].flatten()
            required = [time_filter.indices_to_decompress(t) for t in flat_time_indices]
            time_indices.append(flat_time_indices)
            indices.append(jnp.asarray(required))

        def undo_module(
            module: CompressionModule | TimeStepFilter,
            rec_state: RecordingState,
            rng: jax.Array,
            latent: list[dict[str, jax.Array]],
            level: int,
        ) -> tuple[
            int,
            list[dict[str, jax.Array]],
            RecordingState,
        ]:
            if isinstance(module, CompressionModule):
                latent = [module.decompress(entry, rec_state, key=rng) for entry in latent]
            else:
                level_indices = indices[level]
                num_time_idx = level_indices.shape[0]
                num_arr_idx = level_indices.shape[1]
                merged = []
                for row in range(num_time_idx):
                    # Each time index consumes a consecutive group of latent entries.
                    first = row * num_arr_idx
                    group = [latent[pos] for pos in range(first, first + num_arr_idx)]
                    merged.append(
                        module.decompress(
                            values=group,
                            state=rec_state,
                            arr_indices=level_indices[row],
                            time_idx=time_indices[level - 1][row],
                            key=rng,
                        )
                    )
                latent = merged
                level -= 1
            for entry in latent:
                check_shape_dtype(entry, module._input_shape_dtypes)
            return level, latent, rec_state

        # Bottom-up reconstruction: start with the stored latent entries at the
        # indices requested by the last time step filter.
        level = len(time_filters)
        latent_values: list[dict[str, jax.Array]] = []
        for idx in indices[level].flatten():
            latent_values.append(
                {
                    name: jnp.take(arr, indices=idx.reshape(1), axis=0).squeeze(axis=0)
                    for name, arr in state.data.items()
                }
            )

        for module in self.modules[::-1]:
            key, module_key = jax.random.split(key)
            level, latent_values, state = undo_module(
                module,
                state,
                module_key,
                latent_values,
                level,
            )

        if len(latent_values) != 1:
            raise Exception("This should never happen")
        return latent_values[0], state