Source code for fdtdx.interfaces.state
import jax
from fdtdx.core.jax.pytrees import TreeClass, autoinit
from fdtdx.core.jax.sharding import create_named_sharded_matrix
from fdtdx.typing import BackendOption
[docs]
@autoinit
class RecordingState(TreeClass):
"""Container for simulation recording state data.
Holds field data and state information for FDTD simulations.
"""
#: Dictionary mapping field names to their array values.
data: dict[str, jax.Array]
#: Dictionary mapping state variable names to their array values.
state: dict[str, jax.Array]
def init_recording_state(
data_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
state_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
backend: BackendOption,
) -> RecordingState:
"""Initialize a new recording state with sharded arrays.
Creates a RecordingState instance with data and state arrays sharded across
available devices based on the provided shapes/dtypes and backend.
Args:
data_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping field names to their shape/dtype specs.
state_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping state names to their shape/dtype specs.
backend (BackendOption): Hardware backend to use ("gpu", "tpu", or "cpu").
Returns:
RecordingState: A new RecordingState instance with initialized sharded arrays.
"""
data = init_sharded_dict(data_shape_dtypes, backend=backend)
state = init_sharded_dict(state_shape_dtypes, backend=backend)
return RecordingState(
data=data,
state=state,
)
def init_sharded_dict(
shape_dtypes: dict[str, jax.ShapeDtypeStruct],
backend: BackendOption,
) -> dict[str, jax.Array]:
"""Initialize a dictionary of sharded arrays.
Creates arrays sharded across available devices based on the provided
shapes/dtypes and backend.
Args:
shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping names to shape/dtype specifications.
backend (BackendOption): Hardware backend to use ("gpu", "tpu", or "cpu").
Returns:
dict[str, jax.Array]: Dictionary mapping names to initialized sharded arrays.
"""
data = {}
for k, v in shape_dtypes.items():
num_devices = len(jax.devices(backend=backend))
shape = v.shape
if v.shape[0] % num_devices != 0:
new_shape = list(v.shape)
new_shape[0] = new_shape[0] + num_devices - (new_shape[0] % num_devices)
shape = tuple(new_shape)
arr = create_named_sharded_matrix(
shape=shape,
value=0,
sharding_axis=0,
dtype=v.dtype,
backend=backend,
)
data[k] = arr
return data