Source code for fdtdx.objects.static_material.sphere

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.materials import compute_ordered_names
from fdtdx.objects.static_material.static import StaticMultiMaterialObject
from fdtdx.typing import UNDEFINED_SHAPE_3D, PartialGridShape3D, PartialRealShape3D


@autoinit
class Sphere(StaticMultiMaterialObject):
    """A sphere or ellipsoid object with configurable properties.

    This class represents a sphere or ellipsoid with customizable radius/radii and
    material. When all three radii are equal, the shape is a perfect sphere.

    The bounding-box size (diameter = 2 * radius per axis) is inferred for all three
    axes from the radius parameters and stored in ``partial_real_shape`` by
    ``__post_init__``. ``partial_real_shape`` and ``partial_grid_shape`` are declared
    with ``init=False`` and are therefore not constructor parameters.
    Per-axis radii (``radius_x``, ``radius_y``, ``radius_z``) take precedence over the
    default ``radius`` when present.
    """

    #: The default radius of the sphere in meter (used if specific axis radii are not provided).
    radius: float = frozen_field()

    #: Name of the sphere material in the materials dictionary to be used for the object.
    material_name: str = frozen_field()

    #: The radius along the x-axis in meter. If none, use radius. Defaults to None.
    radius_x: float | None = frozen_field(default=None)

    #: The radius along the y-axis in meter. If none, use radius. Defaults to None.
    radius_y: float | None = frozen_field(default=None)

    #: The radius along the z-axis in meter. If none, use radius. Defaults to None.
    radius_z: float | None = frozen_field(default=None)

    # Derived from radius — not constructor parameters.
    partial_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D, init=False)
    partial_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D, init=False)

    def _resolved_radii(self) -> tuple[float, float, float]:
        """Return the effective ``(x, y, z)`` radii.

        Each per-axis radius falls back to the default ``radius`` when it is ``None``.
        Keeping this in one place guarantees that the bounding box computed in
        ``__post_init__`` and the voxel mask use the same radii.
        """
        fallback = self.radius
        return (
            self.radius_x if self.radius_x is not None else fallback,
            self.radius_y if self.radius_y is not None else fallback,
            self.radius_z if self.radius_z is not None else fallback,
        )

    def __post_init__(self):
        """Derive the bounding-box size (one diameter per axis) from the radii."""
        rx, ry, rz = self._resolved_radii()
        # NOTE(review): object.__setattr__ is kept from the original; presumably plain
        # attribute assignment is blocked on these frozen instances — confirm.
        object.__setattr__(self, "partial_real_shape", (2.0 * rx, 2.0 * ry, 2.0 * rz))

    def get_voxel_mask_for_shape(self) -> jax.Array:
        """Generates a voxel mask for a sphere or ellipsoid shape.

        Cell centers are measured relative to the object's lower edge and compared
        against the ellipsoid centered in the middle of ``self.real_shape``.

        Returns:
            jax.Array: Boolean mask where True indicates voxels inside the sphere/ellipsoid.
        """
        radius_x, radius_y, radius_z = self._resolved_radii()
        grid = self._config.resolved_grid

        def local_centers(axis: int) -> jax.Array:
            """Return physical cell centers relative to this object's lower edge."""
            if grid is None:
                # No resolved grid: every cell has the same spacing.
                spacing = self._config.uniform_spacing()
                return (jnp.arange(self.grid_shape[axis]) + 0.5) * spacing
            # Resolved grid: a cell center is the midpoint of its two bounding edges,
            # shifted so that the object's first edge sits at zero.
            lower, upper = self.grid_slice_tuple[axis]
            edges = grid.edges(axis)
            return 0.5 * (edges[lower:upper] + edges[lower + 1 : upper + 1]) - edges[lower]

        center_x, center_y, center_z = (0.5 * axis_size for axis_size in self.real_shape)

        # The ellipsoid equation is separable per axis, so keep each coordinate 1D and
        # let broadcasting build the 3D result instead of materialising a dense meshgrid.
        x = local_centers(0)[:, None, None]
        y = local_centers(1)[None, :, None]
        z = local_centers(2)[None, None, :]

        # Normalized squared distance along each axis.
        x_term = ((x - center_x) / radius_x) ** 2
        y_term = ((y - center_y) / radius_y) ** 2
        z_term = ((z - center_z) / radius_z) ** 2

        # Points are inside the ellipsoid if x^2/a^2 + y^2/b^2 + z^2/c^2 < 1.
        return (x_term + y_term + z_term) < 1

    def get_material_mapping(
        self,
    ) -> jax.Array:
        """Return the material index for every voxel of the object.

        Returns:
            jax.Array: int32 array of shape ``self.grid_shape`` filled with the position
            of ``material_name`` in ``compute_ordered_names(self.materials)``.

        Raises:
            ValueError: If ``material_name`` is not among the ordered material names.
        """
        all_names = compute_ordered_names(self.materials)
        try:
            idx = all_names.index(self.material_name)
        except ValueError as err:
            raise ValueError(
                f"Sphere material_name {self.material_name!r} is not one of the "
                f"available materials: {list(all_names)}"
            ) from err
        return jnp.full(self.grid_shape, idx, dtype=jnp.int32)