fdtdx.DtypeConversion

class fdtdx.DtypeConversion(*, dtype=null, exclude_filter=())

Bases: CompressionModule

Compression module that converts data types of field values.

This module changes the data type of field values while preserving their shape. This is useful for reducing memory usage (for example, storing float32 fields as float16 halves their size) or for meeting precision requirements.
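As a rough illustration of the memory argument, the following sketch uses plain JAX, independently of fdtdx. It casts a float32 array to float16 and compares byte counts. The field name and the grid shape are arbitrary assumptions chosen for illustration.

```python
import jax.numpy as jnp

# Hypothetical field snapshot: a 3-component field on a 64^3 grid (shape chosen for illustration).
e_field = jnp.ones((3, 64, 64, 64), dtype=jnp.float32)

# Casting keeps the shape but stores 2 bytes per element instead of 4.
e_half = e_field.astype(jnp.float16)

print(e_field.shape == e_half.shape)  # True
print(e_field.nbytes, e_half.nbytes)  # 3145728 1572864
```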

Attributes

DtypeConversion.dtype: dtype

Target data type for conversion.

DtypeConversion.exclude_filter: Sequence[str]

Field names to exclude from conversion. These fields keep their original data type. Defaults to an empty tuple.

Methods

DtypeConversion.aset(attr_name, val, create_new_ok=False)

Sets an attribute of this class. In contrast to the classical .at[].set(), this method updates the class attribute directly rather than operating only on JAX pytree leaf nodes: it replaces the full attribute with the new value.

The attribute can either be an attribute of this class or, for nested classes, an attribute of a class that is itself an attribute of this class. A nested path could look like this: "a->b->[0]->['name']". Here, the current class has an attribute a, which has an attribute b, which is a list; the element at index 0 of that list is a dictionary, which is indexed with the key 'name'.

Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
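To make the path syntax concrete, here is a minimal, hypothetical parser that splits a path such as "a->b->[0]->['name']" into attribute names, list indices and dictionary keys. It only illustrates the syntax described above and is not fdtdx's implementation; the name parse_attr_path is invented for this sketch.

```python
import re
from typing import Union

Step = tuple[str, Union[str, int]]

_INDEX = re.compile(r"^\[(\d+)\]$")        # list index, e.g. [0]
_KEY = re.compile(r"^\['([^'\[\]]*)'\]$")  # dict key, e.g. ['name']; no brackets or quotes inside


def parse_attr_path(path: str) -> list[Step]:
    """Split an aset-style path into (kind, value) steps (hypothetical helper)."""
    steps: list[Step] = []
    for part in path.split("->"):
        if m := _INDEX.match(part):
            steps.append(("index", int(m.group(1))))
        elif m := _KEY.match(part):
            steps.append(("key", m.group(1)))
        elif part.isidentifier():
            steps.append(("attr", part))
        else:
            raise ValueError(f"Cannot parse path segment: {part!r}")
    return steps


print(parse_attr_path("a->b->[0]->['name']"))
# [('attr', 'a'), ('attr', 'b'), ('index', 0), ('key', 'name')]
```

The key pattern deliberately rejects square brackets and single quotes, matching the restriction on dictionary keys stated above.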

Parameters:
  • attr_name (str) – Name of attribute to set

  • val (Any) – Value to set the attribute to

  • create_new_ok (bool, optional) – If False (default), raise an error if the attribute does not exist. If True, create a new attribute if the attribute name does not exist yet.

Returns:

Updated instance with new attribute value

Return type:

Self

DtypeConversion.compress(values, state, key)

Compress field values at the current time step by converting each field not listed in exclude_filter to dtype.

Parameters:
  • values (dict[str, jax.Array]) – Dictionary mapping field names to their values.

  • state (RecordingState) – Current recording state.

  • key (jax.Array) – Random key for stochastic operations.

Returns:

Tuple containing:
  • Dictionary of compressed field values

  • Updated recording state

Return type:

tuple[dict[str, jax.Array], RecordingState]
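A minimal sketch of what compression amounts to for this module, assuming it casts every field not listed in exclude_filter to the target dtype and passes the recording state through unchanged. The function compress_fields, the field names "E" and "H", and the plain dict standing in for RecordingState are all hypothetical; the sketch also drops the key argument, since a plain dtype cast involves no randomness.

```python
from typing import Any, Sequence

import jax
import jax.numpy as jnp
from numpy.typing import DTypeLike


def compress_fields(
    values: dict[str, jax.Array],
    state: Any,
    dtype: DTypeLike,
    exclude_filter: Sequence[str] = (),
) -> tuple[dict[str, jax.Array], Any]:
    """Hypothetical stand-in for DtypeConversion.compress."""
    out = {
        name: arr if name in exclude_filter else arr.astype(dtype)
        for name, arr in values.items()
    }
    return out, state  # assumption: the state is returned unchanged


values = {
    "E": jnp.ones((3, 8, 8, 8), dtype=jnp.float32),
    "H": jnp.ones((3, 8, 8, 8), dtype=jnp.float32),
}
compressed, _ = compress_fields(values, state={}, dtype=jnp.bfloat16, exclude_filter=["H"])
# "E" is now bfloat16, while the excluded "H" keeps float32.
print({k: v.dtype for k, v in compressed.items()})
```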

DtypeConversion.decompress(values, state, key)

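Decompress field values by converting them back to their original data types. Precision discarded by a conversion to a lower-precision type is not recovered.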

Parameters:
  • values (dict[str, jax.Array]) – Dictionary mapping field names to their compressed values.

  • state (RecordingState) – Current recording state.

  • key (jax.Array) – Random key for stochastic operations.

Returns:

Dictionary mapping field names to their decompressed values.

Return type:

dict[str, jax.Array]
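Because decompression can only cast values back to their original dtype, precision discarded during compression stays lost. The following sketch assumes float32 inputs compressed to float16 and shows that a round trip restores dtype and shape but not the exact values. The names decompress_fields and original_dtypes are hypothetical, not part of fdtdx.

```python
import jax
import jax.numpy as jnp


def decompress_fields(
    values: dict[str, jax.Array],
    original_dtypes: dict[str, jnp.dtype],
) -> dict[str, jax.Array]:
    """Hypothetical stand-in for DtypeConversion.decompress: cast each field back."""
    return {name: arr.astype(original_dtypes[name]) for name, arr in values.items()}


original = {"E": jnp.array([0.1, 1.0 / 3.0, 12345.678], dtype=jnp.float32)}
original_dtypes = {k: v.dtype for k, v in original.items()}

# Compress to float16, then restore the original dtype.
compressed = {k: v.astype(jnp.float16) for k, v in original.items()}
restored = decompress_fields(compressed, original_dtypes)

print(restored["E"].dtype)                     # float32
print(jnp.abs(restored["E"] - original["E"]))  # nonzero rounding errors
```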

DtypeConversion.get_class_fields()
Return type:

list[TreeClassField]

DtypeConversion.get_public_fields()
Return type:

list[TreeClassField]

DtypeConversion.init_shapes(input_shape_dtypes)

Initialize shapes and sizes for the compression module.

Parameters:

input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]) – Dictionary mapping field names to their input shapes/dtypes.

Returns:

Tuple containing:
  • Self: Updated instance of the compression module

  • Dictionary mapping field names to their output shapes/dtypes

  • Dictionary mapping field names to their state shapes/dtypes

Return type:

tuple[Self, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]
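Shape initialization can work on abstract jax.ShapeDtypeStruct descriptions before any field data exists. This hypothetical sketch derives the output descriptions by keeping each shape and swapping in the target dtype for non-excluded fields. Returning an empty state dictionary is an assumption (a pure dtype cast needs no per-step state) and is not confirmed by this page; init_output_shapes and the field names are invented for illustration.

```python
from typing import Sequence

import jax
import jax.numpy as jnp
from numpy.typing import DTypeLike


def init_output_shapes(
    input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    dtype: DTypeLike,
    exclude_filter: Sequence[str] = (),
) -> tuple[dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]:
    """Hypothetical helper mirroring the output and state parts of init_shapes."""
    out_shapes = {
        name: jax.ShapeDtypeStruct(
            shape=sd.shape,  # shape is preserved
            dtype=sd.dtype if name in exclude_filter else dtype,
        )
        for name, sd in input_shape_dtypes.items()
    }
    state_shapes: dict[str, jax.ShapeDtypeStruct] = {}  # assumption: no extra state
    return out_shapes, state_shapes


inputs = {
    "E": jax.ShapeDtypeStruct((3, 16, 16, 16), jnp.float32),
    "H": jax.ShapeDtypeStruct((3, 16, 16, 16), jnp.float32),
}
out, state = init_output_shapes(inputs, jnp.float16, exclude_filter=["H"])
print(out["E"], out["H"])
```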
