Source code for fdtdx.core.jax.pytrees

from dataclasses import dataclass, fields
from typing import Any, Callable, Literal, Self, Sequence, TypeVar, overload

import pytreeclass as tc
from pytreeclass._src.code_build import (
    ArgKindType,
    Field,
    build_init_method,
    convert_hints_to_fields,
    dataclass_transform,
)
from pytreeclass._src.code_build import (
    field as tc_field,
)
from pytreeclass._src.tree_base import TreeClassIndexer

from fdtdx.core.null import NULL


def safe_hasattr(obj, name) -> bool:
    """Report whether ``obj`` exposes an attribute called ``name``.

    Works like the builtin ``hasattr``, except that a ``KeyError`` raised
    while looking the attribute up is treated as "attribute not present"
    instead of propagating to the caller.

    Args:
        obj: Object to inspect.
        name: Name of the attribute to look up.

    Returns:
        bool: True if the lookup succeeds, False if the attribute is missing
            or the lookup raised ``KeyError``.
    """
    try:
        return hasattr(obj, name)
    except KeyError:
        # hasattr itself only swallows AttributeError; a KeyError escaping
        # from a custom attribute lookup is interpreted as "not there".
        return False


class ExtendedTreeClassIndexer(TreeClassIndexer):
    """Indexer for tree classes with a more precise static return type.

    The runtime behavior is entirely inherited from ``TreeClassIndexer``; this
    subclass only re-declares ``__getitem__`` so that type checkers see the
    result as ``Self``.
    """

    def __getitem__(self, where: Any) -> Self:
        """Delegate the lookup to ``TreeClassIndexer.__getitem__`` unchanged.

        Args:
            where (Any): Selector forwarded as-is to the parent implementation.

        Returns:
            Self: Whatever the parent implementation returns, annotated as
                ``Self`` for the benefit of type checkers.
        """
        selection = super().__getitem__(where)
        return selection


@dataclass(frozen=True)
class TreeClassField:
    """Immutable description of one field of a tree class.

    ``TreeClass.get_class_fields`` builds instances of this class from the
    slots of the underlying pytreeclass field objects, and
    ``TreeClass.get_public_fields`` additionally fills in ``value``.

    Iterating an instance yields ``(field_name, field_value)`` pairs, so
    ``dict(instance)`` produces a plain keyword dictionary that can be fed
    back into the constructor.
    """

    # Name of the attribute on the owning class.
    name: str
    # Declared type annotation of the attribute.
    type: Any
    # Default value; the NULL sentinel means "no default given".
    # (Previously the one-element tuple ``(NULL,)`` because of a stray comma.)
    default: Any = NULL
    # Whether the field is a parameter of the generated __init__.
    init: bool = True
    # Whether the field is shown in the generated __repr__.
    repr: bool = True
    # Argument kind of the field in the generated __init__ signature.
    kind: Literal["POS_ONLY", "POS_OR_KW", "VAR_POS", "KW_ONLY", "VAR_KW"] = "POS_OR_KW"
    # Free-form metadata attached to the field, if any.
    metadata: dict[str, Any] | None = None
    # Callbacks registered for attribute assignment.
    on_setattr: Sequence[Callable] = ()
    # Callbacks registered for attribute access.
    on_getattr: Sequence[Callable] = ()
    # Alternative parameter name used in __init__, if any.
    alias: str | None = None
    # Current value on a concrete instance; NULL when not populated.
    value: Any = NULL

    def __iter__(self):
        """Allow conversion to dict via dict(obj)"""
        for spec in fields(self):
            yield spec.name, getattr(self, spec.name)


class TreeClass(tc.TreeClass):
    """Extended tree class with improved attribute setting functionality.

    Extends TreeClass to provide more flexible attribute setting capabilities,
    particularly for handling non-recursive attribute updates.
    """

    @property
    def at(self) -> ExtendedTreeClassIndexer:
        """Gets the extended indexer for this tree.

        Returns:
            ExtendedTreeClassIndexer: Indexer that preserves type information
        """
        return super().at  # type: ignore

    def get_class_fields(self) -> list[TreeClassField]:
        """Describe every field declared on this tree class.

        Each pytreeclass field object returned by ``tc.fields(self)`` is
        converted into a ``TreeClassField`` by copying all of its slots.

        Returns:
            list[TreeClassField]: One descriptor per declared field, in the
                order reported by ``tc.fields``.
        """
        class_fields = []
        for tc_spec in tc.fields(self):
            spec_kwargs = {slot: getattr(tc_spec, slot) for slot in tc_spec.__slots__}
            # A default whose repr is "NULL" is replaced by this package's own
            # NULL sentinel so that callers can compare against a single object.
            if repr(spec_kwargs["default"]) == "NULL":  # TODO: can we make this more robust? Do we need to?
                spec_kwargs["default"] = NULL
            class_fields.append(TreeClassField(**spec_kwargs))
        return class_fields

    def get_public_fields(self) -> list[TreeClassField]:
        """Describe the ``init=True`` fields together with their current values.

        Returns:
            list[TreeClassField]: Descriptors of all fields that take part in
                ``__init__``, each with ``value`` set to the attribute value
                currently stored on this instance.
        """
        public_fields = []
        for spec in self.get_class_fields():
            if not spec.init:
                continue
            spec_kwargs = dict(spec)
            spec_kwargs["value"] = getattr(self, spec.name)
            public_fields.append(TreeClassField(**spec_kwargs))
        return public_fields

    def _aset(
        self,
        attr_name: str,
        val: Any,
    ):
        """Plain ``setattr`` wrapper; ``aset`` invokes it via ``self.at["_aset"]``."""
        setattr(self, attr_name, val)

    @staticmethod
    def _parse_operations(s: str):
        """Split an ``aset`` path string into a list of access operations.

        The path is a sequence of steps joined by ``->``. Each step is either
        an attribute name (``a``), an integer index (``[0]``, ``[-1]``) or a
        single-quoted string key (``['name']``).

        Args:
            s (str): Path such as ``"a->b->[0]->['name']"``.

        Returns:
            list[tuple]: ``(value, kind)`` pairs in path order, where ``kind``
                is ``"attribute"``, ``"index"`` or ``"key"`` and ``value`` is
                the attribute name, the integer index or the key string.

        Raises:
            ValueError: If the path is empty or malformed.
        """
        if not s:
            raise ValueError("Empty string is not valid")

        operations = []
        i = 0
        while i < len(s):
            if i > 0:
                # Every step after the first must be introduced by "->".
                if not s.startswith("->", i):
                    raise ValueError(f"Expected '->' at position {i}")
                i += 2  # Skip "->"
                if i >= len(s):
                    raise ValueError("String ends with '->'")

            if s[i] == "[":
                # Bracketed step: locate the matching closing bracket.
                j = s.find("]", i + 1)
                if j == -1:
                    raise ValueError(f"Unclosed bracket starting at position {i}")
                bracket_content = s[i + 1 : j].strip()

                # Optional leading minus followed by decimal digits -> integer index.
                # isdecimal() (unlike isdigit()) only accepts characters that int()
                # can convert, so odd digits like "²" get the parser's own error.
                unsigned = bracket_content[1:] if bracket_content.startswith("-") else bracket_content
                if unsigned.isdecimal():
                    operations.append((int(bracket_content), "index"))
                elif bracket_content.startswith("'") and bracket_content.endswith("'"):
                    # A lone "'" both starts and ends with a quote; reject it.
                    if len(bracket_content) < 2:
                        raise ValueError(f"Invalid string format in brackets: [{bracket_content}]")
                    string_content = bracket_content[1:-1]
                    # Check for forbidden characters
                    if "'" in string_content:
                        raise ValueError(f"String keys cannot contain single quotes: '{string_content}'")
                    if "[" in string_content or "]" in string_content:
                        raise ValueError(f"String keys cannot contain square brackets: '{string_content}'")
                    operations.append((string_content, "key"))
                else:
                    raise ValueError(f"Invalid bracket content: [{bracket_content}]")
                i = j + 1
            else:
                # Attribute step: runs up to the next "->" or the end of the string.
                j = s.find("->", i)
                if j == -1:
                    j = len(s)
                attr_name = s[i:j]
                if not attr_name:
                    raise ValueError(f"Empty attribute at position {i}")
                # Check if it's a valid Python identifier
                if not attr_name.isidentifier():
                    raise ValueError(f"Invalid attribute name: '{attr_name}'")
                operations.append((attr_name, "attribute"))
                i = j

        return operations

    def aset(
        self,
        attr_name: str,
        val: Any,
        create_new_ok: bool = False,
    ) -> Self:
        """Sets an attribute of this class.

        In contrast to the classical .at[].set(), this method updates the class
        attribute directly and does not only operate on jax pytree leaf nodes.
        Instead, replaces the full attribute with the new value.

        The attribute can either be the attribute name of this class, or for
        nested classes it can also be the attribute name of a class, which
        itself is an attribute of this class. The syntax for this operation
        could look like this: "a->b->[0]->['name']". Here, the current class
        has an attribute a, which has an attribute b, which is a list, which we
        index at index 0, which is an element of type dictionary, which we
        index using the dictionary key 'name'.

        Note that dictionary keys cannot contain square brackets or single
        quotes (even if they are escaped).

        Args:
            attr_name (str): Name of attribute to set
            val (Any): Value to set the attribute to
            create_new_ok (bool, optional): If false (default), throw an error
                if the attribute does not exist. If true, creates a new
                attribute if the attribute name does not exist yet.

        Returns:
            Self: Updated instance with new attribute value

        Raises:
            ValueError: If ``attr_name`` is not a valid path expression.
            Exception: If a step of the path does not exist (and may not be
                created), or if a visited object does not support the
                requested kind of access or update.
        """
        # parse operations
        ops = self._parse_operations(attr_name)
        last_idx = len(ops) - 1

        # Walk down the path. attr_list[k] is the object that step k is applied
        # to, so it ends up with exactly one entry per operation.
        attr_list = [self]
        current_parent = self
        for idx, (op, op_type) in enumerate(ops):
            # Only the final step may refer to something that does not exist yet.
            may_create = create_new_ok and idx == last_idx
            if op_type == "attribute":
                if not safe_hasattr(current_parent, op):
                    if not may_create:
                        raise Exception(f"Attribute: {op} does not exist for {current_parent.__class__}")
                    current_parent = None
                else:
                    current_parent = getattr(current_parent, op)
            elif op_type == "index":
                if "__getitem__" not in dir(current_parent):
                    raise Exception(f"{current_parent.__class__} does not implement __getitem__")
                current_parent = current_parent[op]  # type: ignore
            elif op_type == "key":
                if "__getitem__" not in dir(current_parent):
                    raise Exception(f"{current_parent.__class__} does not implement __getitem__")
                if op not in current_parent:  # type: ignore
                    if not may_create:
                        raise Exception(f"Key: {op} does not exist for {current_parent}")
                    current_parent = None
                else:
                    current_parent = current_parent[op]  # type: ignore
            else:
                raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")
            if idx != last_idx:
                assert current_parent is not None
                attr_list.append(current_parent)

        # Walk back up: rebuild each parent with the updated child, innermost first.
        cur_attr = val
        for (op, op_type), current_parent in zip(reversed(ops), reversed(attr_list)):
            if op_type == "attribute":
                if not isinstance(current_parent, TreeClass):
                    raise Exception(f"Can only set attribute on ExtendedTreeClass, but got {current_parent.__class__}")
                # The second item of the returned pair is used as the updated parent.
                _, cur_attr = current_parent.at["_aset"](op, cur_attr)
            elif op_type == "index":
                if "__setitem__" not in dir(current_parent):
                    raise Exception(
                        f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
                    )
                # Update a copy so the container held by the original stays untouched.
                cpy = current_parent.copy()  # type: ignore
                cpy[op] = cur_attr
                cur_attr = cpy
            elif op_type == "key":
                if "__setitem__" not in dir(current_parent):
                    raise Exception(
                        f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
                    )
                cpy = current_parent.copy()  # type: ignore
                cpy[op] = cur_attr
                cur_attr = cpy
            else:
                raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")

        assert cur_attr.__class__ == self.__class__
        return cur_attr
T = TypeVar("T")


@overload
def field(
    *,
    default: T,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> T: ...


@overload
def field(
    *,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any: ...


def field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Declare a regular tree-class field.

    Thin keyword-only wrapper that forwards every argument unchanged to the
    pytreeclass ``field`` function.

    Args:
        default (Any, optional): The default value for the field. Defaults to
            the NULL sentinel, i.e. no default.
        init (bool, optional): Whether to include the field in __init__.
            Defaults to True.
        repr (bool, optional): Whether to include the field in __repr__.
            Defaults to True.
        kind (ArgKindType, optional): The argument kind (POS_ONLY, POS_OR_KW,
            etc.). Defaults to KW_ONLY.
        metadata (dict[str, Any] | None, optional): Additional metadata for
            the field. Defaults to None.
        on_setattr (Sequence[Any], optional): Setattr callbacks. Defaults to
            no callbacks.
        on_getattr (Sequence[Any], optional): Getattr callbacks. Defaults to
            no callbacks.
        alias (str | None, optional): Alternative name for the field in
            __init__. Defaults to None.

    Returns:
        Any: The object returned by the pytreeclass ``field`` function.
    """
    options = {
        "default": default,
        "init": init,
        "repr": repr,
        "kind": kind,
        "metadata": metadata,
        "on_setattr": on_setattr,
        "on_getattr": on_getattr,
        "alias": alias,
    }
    return tc_field(**options)
@overload
def private_field(
    *,
    default: T,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> T: ...


@overload
def private_field(
    *,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any: ...


def private_field(
    *,
    default: Any = NULL,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Declare a tree-class field that is excluded from __init__ by default.

    Forwards every argument unchanged to the pytreeclass ``field`` function;
    the only difference from a plain field declaration is that ``init``
    defaults to False.

    Args:
        default (Any, optional): The default value for the field. Defaults to
            the NULL sentinel, i.e. no default.
        init (bool, optional): Whether to include the field in __init__.
            Defaults to False.
        repr (bool, optional): Whether to include the field in __repr__.
            Defaults to True.
        kind (ArgKindType, optional): The argument kind (POS_ONLY, POS_OR_KW,
            etc.). Defaults to KW_ONLY.
        metadata (dict[str, Any] | None, optional): Additional metadata for
            the field. Defaults to None.
        on_setattr (Sequence[Any], optional): Setattr callbacks. Defaults to
            no callbacks.
        on_getattr (Sequence[Any], optional): Getattr callbacks. Defaults to
            no callbacks.
        alias (str | None, optional): Alternative name for the field in
            __init__. Defaults to None.

    Returns:
        Any: The object returned by the pytreeclass ``field`` function.
    """
    options = {
        "default": default,
        "init": init,
        "repr": repr,
        "kind": kind,
        "metadata": metadata,
        "on_setattr": on_setattr,
        "on_getattr": on_getattr,
        "alias": alias,
    }
    return tc_field(**options)
@overload
def frozen_field(
    *,
    default: T,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> T: ...


@overload
def frozen_field(
    *,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any: ...


def frozen_field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Declare a field whose value is frozen on set and unfrozen on get.

    Works like a regular field declaration, but ``tc.freeze`` is appended to
    the setattr callbacks and ``tc.unfreeze`` is prepended to the getattr
    callbacks.

    Args:
        default (Any, optional): The default value for the field. Defaults to
            the NULL sentinel, i.e. no default.
        init (bool, optional): Whether to include the field in __init__.
            Defaults to True.
        repr (bool, optional): Whether to include the field in __repr__.
            Defaults to True.
        kind (ArgKindType, optional): The argument kind (POS_ONLY, POS_OR_KW,
            etc.). Defaults to KW_ONLY.
        metadata (dict[str, Any] | None, optional): Additional metadata for
            the field. Defaults to None.
        on_setattr (Sequence[Any], optional): Additional setattr callbacks,
            listed before ``tc.freeze``. Defaults to no callbacks.
        on_getattr (Sequence[Any], optional): Additional getattr callbacks,
            listed after ``tc.unfreeze``. Defaults to no callbacks.
        alias (str | None, optional): Alternative name for the field in
            __init__. Defaults to None.

    Returns:
        Any: The object returned by the pytreeclass ``field`` function,
            configured with the freeze/unfreeze callbacks.
    """
    # NOTE(review): presumably pytreeclass applies callbacks in list order, so
    # freezing is the last step on set and unfreezing the first step on get.
    setattr_chain = [*on_setattr, tc.freeze]
    getattr_chain = [tc.unfreeze, *on_getattr]
    return tc_field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,
        on_setattr=setattr_chain,
        on_getattr=getattr_chain,
        alias=alias,
    )
@overload
def frozen_private_field(
    *,
    default: T,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> T: ...


@overload
def frozen_private_field(
    *,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any: ...


def frozen_private_field(
    *,
    default: Any = None,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "KW_ONLY",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Declare a frozen field that defaults to None and is excluded from __init__.

    Delegates to ``frozen_field`` with every argument forwarded unchanged, so
    the value is frozen when stored and unfrozen when accessed. Compared to
    ``frozen_field``, ``default`` defaults to None and ``init`` to False.

    Args:
        default (Any, optional): The default value for the field. Defaults to
            None.
        init (bool, optional): Whether to include the field in __init__.
            Defaults to False.
        repr (bool, optional): Whether to include the field in __repr__.
            Defaults to True.
        kind (ArgKindType, optional): The argument kind (POS_ONLY, POS_OR_KW,
            etc.). Defaults to KW_ONLY.
        metadata (dict[str, Any] | None, optional): Additional metadata for
            the field. Defaults to None.
        on_setattr (Sequence[Any], optional): Additional setattr callbacks,
            forwarded to ``frozen_field``. Defaults to no callbacks.
        on_getattr (Sequence[Any], optional): Additional getattr callbacks,
            forwarded to ``frozen_field``. Defaults to no callbacks.
        alias (str | None, optional): Alternative name for the field in
            __init__. Defaults to None.

    Returns:
        Any: The object returned by ``frozen_field``.
    """
    # NOTE(review): the default here is None, whereas private_field defaults to
    # the NULL sentinel - confirm that this difference is intentional.
    options = {
        "default": default,
        "init": init,
        "repr": repr,
        "kind": kind,
        "metadata": metadata,
        "on_setattr": on_setattr,
        "on_getattr": on_getattr,
        "alias": alias,
    }
    return frozen_field(**options)
@dataclass_transform(
    field_specifiers=(Field, tc_field, frozen_field, frozen_private_field, field, private_field),
    kw_only_default=True,
)
def autoinit(klass: type[T]) -> type[T]:
    """Wrapper around tc.autoinit that preserves parameter requirement information.

    Args:
        klass (type[T]): Class to decorate.

    Returns:
        type[T]: ``klass`` itself when it defines its own ``__init__``;
            otherwise the result of ``build_init_method`` applied to the class
            after its type hints were converted to fields.
    """
    # A user-defined __init__ on the class itself takes precedence: leave the
    # class exactly as it is.
    if "__init__" in vars(klass):
        return klass

    # First convert the class hints to fields, then build the __init__ method
    # from the fields of this class and of any base classes that are decorated
    # with `autoinit`.
    with_fields = convert_hints_to_fields(klass)
    return build_init_method(with_fields)