fdtdx.DtypeConversion
- class fdtdx.DtypeConversion(*, dtype=null, exclude_filter=())
Bases: CompressionModule
Compression module that converts data types of field values.
This module changes the data type of field values while preserving their shape. This is useful for reducing memory usage (for example, by storing float32 fields as float16) or for meeting precision requirements.
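As a minimal illustration in plain JAX (not fdtdx's own code), casting an array to a narrower dtype keeps its shape while halving the bytes per element when going from float32 to float16. The array below is a hypothetical stand-in for one recorded field:

```python
import jax.numpy as jnp

# Hypothetical stand-in for a recorded field value (not produced by fdtdx).
field = jnp.ones((4, 8, 8), dtype=jnp.float32)

# Casting to a narrower dtype keeps the shape but halves the bytes per element.
converted = field.astype(jnp.float16)

print(field.shape, field.dtype, field.nbytes)              # (4, 8, 8) float32 1024
print(converted.shape, converted.dtype, converted.nbytes)  # (4, 8, 8) float16 512
```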
Attributes
- DtypeConversion.dtype: dtype
Target data type for conversion.
- DtypeConversion.exclude_filter: Sequence[str]
List of field names to exclude from conversion.
Methods
- DtypeConversion.aset(attr_name, val, create_new_ok=False)
Sets an attribute of this class. In contrast to the classical .at[].set(), which only operates on JAX pytree leaf nodes, this method replaces the full attribute with the new value.
The attribute can be either an attribute name of this class or, for nested classes, a path to an attribute of a class that is itself an attribute of this class. For example, the path "a->b->[0]->['name']" means: the current class has an attribute a, which has an attribute b, which is a list; index 0 of that list is a dictionary, which is indexed with the key 'name'.
Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
- Parameters:
attr_name (str) – Name of attribute to set
val (Any) – Value to set the attribute to
create_new_ok (bool, optional) – If False (default), raise an error if the attribute does not exist. If True, create a new attribute if the attribute name does not exist yet.
- Returns:
Updated instance with new attribute value
- Return type:
Self
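The path syntax accepted by aset can be illustrated with a small parser. This is a hypothetical helper (parse_attr_path is not part of fdtdx) that assumes straight ASCII single quotes around dictionary keys; it only shows how a path decomposes into attribute, index, and key steps, not how fdtdx applies them:

```python
from typing import Union

# Each step is ("attr", name), ("index", int) or ("key", str).
Step = tuple[str, Union[str, int]]


def parse_attr_path(path: str) -> list[Step]:
    """Split an aset-style path such as "a->b->[0]->['name']" into steps.

    Hypothetical illustration only; this function is not part of fdtdx.
    """
    steps: list[Step] = []
    for segment in path.split("->"):
        if segment.startswith("[") and segment.endswith("]"):
            inner = segment[1:-1]
            if inner.startswith("'") and inner.endswith("'"):
                # Dictionary key; brackets and quotes inside keys are not supported.
                steps.append(("key", inner[1:-1]))
            else:
                # List or tuple index.
                steps.append(("index", int(inner)))
        else:
            # Plain attribute name.
            steps.append(("attr", segment))
    return steps


print(parse_attr_path("a->b->[0]->['name']"))
# [('attr', 'a'), ('attr', 'b'), ('index', 0), ('key', 'name')]
```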
- DtypeConversion.compress(values, state, key)
Compress field values at the current time step.
- Parameters:
values (dict[str, jax.Array]) – Dictionary mapping field names to their values.
state (RecordingState) – Current recording state.
key (jax.Array) – Random key for stochastic operations.
- Returns:
- Tuple containing:
Dictionary of compressed field values
Updated recording state
- Return type:
tuple[dict[str, jax.Array], RecordingState]
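A minimal sketch of what the compress step amounts to for a dtype cast, written as a free function rather than the actual method. The function name compress_dtype, the empty dict standing in for RecordingState, and exact-name matching for exclude_filter are all assumptions for illustration:

```python
import jax
import jax.numpy as jnp


def compress_dtype(values, state, key, dtype=jnp.float16, exclude_filter=()):
    """Hypothetical stand-alone version of a dtype-casting compress step.

    Casts every field whose name is not in ``exclude_filter`` to ``dtype``.
    Shapes are unchanged. ``state`` is passed through untouched and ``key``
    is unused, since a plain cast is deterministic.
    """
    out = {
        name: arr if name in exclude_filter else arr.astype(dtype)
        for name, arr in values.items()
    }
    return out, state


values = {
    "E": jnp.zeros((3, 16, 16), dtype=jnp.float32),
    "H": jnp.zeros((3, 16, 16), dtype=jnp.float32),
}
# An empty dict stands in for the real RecordingState (assumption).
compressed, state = compress_dtype(
    values, state={}, key=jax.random.PRNGKey(0), exclude_filter=("H",)
)
print({k: str(v.dtype) for k, v in compressed.items()})
# {'E': 'float16', 'H': 'float32'}
```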
- DtypeConversion.decompress(values, state, key)
Decompress field values back to their original form.
- Parameters:
values (dict[str, jax.Array]) – Dictionary mapping field names to their compressed values.
state (RecordingState) – Current recording state.
key (jax.Array) – Random key for stochastic operations.
- Returns:
Dictionary mapping field names to their decompressed values.
- Return type:
dict[str, jax.Array]
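Decompression for a dtype cast can be sketched as casting back to the original dtype. The helper decompress_dtype and its original_dtypes argument are hypothetical stand-ins for whatever bookkeeping the real module uses. The round trip restores the dtype and shape, but not the precision discarded by the narrower type:

```python
import jax.numpy as jnp


def decompress_dtype(values, original_dtypes):
    """Hypothetical inverse of a dtype cast: restore each field's dtype.

    ``original_dtypes`` maps field names to the dtype they had before
    compression (an assumption made for this sketch).
    """
    return {name: arr.astype(original_dtypes[name]) for name, arr in values.items()}


original = {"E": jnp.array([0.1, 1.0, 3.14159], dtype=jnp.float32)}
compressed = {"E": original["E"].astype(jnp.float16)}
restored = decompress_dtype(compressed, {"E": jnp.float32})

# The dtype is float32 again, but rounding from float16 remains.
print(restored["E"].dtype)  # float32
print(bool(jnp.max(jnp.abs(restored["E"] - original["E"])) > 0))  # True
```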
- DtypeConversion.get_class_fields()
- Return type:
list[TreeClassField]
- DtypeConversion.get_public_fields()
- Return type:
list[TreeClassField]
- DtypeConversion.init_shapes(input_shape_dtypes)
Initialize shapes and sizes for the compression module.
- Parameters:
input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]) – Dictionary mapping field names to their input shapes/dtypes.
- Returns:
- Tuple containing:
Self: Updated instance of the compression module
Dictionary mapping field names to their output shapes/dtypes
Dictionary mapping field names to their state shapes/dtypes
- Return type:
tuple[Self, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]
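A sketch of the shape inference for a dtype cast, using jax.ShapeDtypeStruct. The free function init_shapes_dtype is hypothetical. Unlike the documented method it does not return an updated module instance, and it assumes that a pure cast needs no per-field state:

```python
import jax
import jax.numpy as jnp


def init_shapes_dtype(input_shape_dtypes, dtype=jnp.float16, exclude_filter=()):
    """Hypothetical shape inference for a dtype-casting module.

    Output shapes equal the input shapes; only the dtype changes for fields
    not in ``exclude_filter``. The state dict is assumed to be empty.
    """
    out = {
        name: sd if name in exclude_filter else jax.ShapeDtypeStruct(sd.shape, dtype)
        for name, sd in input_shape_dtypes.items()
    }
    state_shapes = {}  # assumption: no state needed for a plain cast
    return out, state_shapes


inputs = {"E": jax.ShapeDtypeStruct((3, 32, 32, 32), jnp.float32)}
out_shapes, state_shapes = init_shapes_dtype(inputs)
print(out_shapes["E"].shape, out_shapes["E"].dtype)  # (3, 32, 32, 32) float16
```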