# Source code for fdtdx.objects.static_material.sphere
import jax
import jax.numpy as jnp
from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.materials import compute_ordered_names
from fdtdx.objects.static_material.static import StaticMultiMaterialObject
from fdtdx.typing import UNDEFINED_SHAPE_3D, PartialGridShape3D, PartialRealShape3D
@autoinit
class Sphere(StaticMultiMaterialObject):
    """A sphere or ellipsoid object with configurable properties.

    This class represents a sphere or ellipsoid with customizable radius/radii and material.
    When all three radii are equal, the shape is a perfect sphere.

    The bounding-box size (diameter = 2 * radius per axis) is automatically inferred
    for all three axes from the radius parameters. ``partial_real_shape`` and
    ``partial_grid_shape`` are not constructor parameters — they are set automatically.
    Per-axis radii (``radius_x``, ``radius_y``, ``radius_z``) take precedence over the
    default ``radius`` when present.

    Raises:
        ValueError: On construction, if any resolved radius is not strictly positive
            (zero, negative or NaN).
    """

    #: The default radius of the sphere in meter (used if specific axis radii are not provided).
    radius: float = frozen_field()
    #: Name of the sphere material in the materials dictionary to be used for the object.
    material_name: str = frozen_field()
    #: The radius along the x-axis in meter. If None, use radius. Defaults to None.
    radius_x: float | None = frozen_field(default=None)
    #: The radius along the y-axis in meter. If None, use radius. Defaults to None.
    radius_y: float | None = frozen_field(default=None)
    #: The radius along the z-axis in meter. If None, use radius. Defaults to None.
    radius_z: float | None = frozen_field(default=None)

    # Derived from radius — not constructor parameters.
    partial_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D, init=False)
    partial_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D, init=False)

    def _resolved_radii(self) -> tuple[float, float, float]:
        """Return the effective (x, y, z) radii in meter.

        Each per-axis radius falls back to ``radius`` when it is ``None``.
        """
        rx = self.radius if self.radius_x is None else self.radius_x
        ry = self.radius if self.radius_y is None else self.radius_y
        rz = self.radius if self.radius_z is None else self.radius_z
        return rx, ry, rz

    def __post_init__(self):
        """Validate the radii and derive the bounding box (2 * radius per axis) from them.

        Raises:
            ValueError: If any resolved radius is not strictly positive.
        """
        rx, ry, rz = self._resolved_radii()
        for axis_name, axis_radius in zip("xyz", (rx, ry, rz)):
            # `not r > 0` (rather than `r <= 0`) also rejects NaN.
            if not axis_radius > 0:
                raise ValueError(
                    f"Sphere radius along the {axis_name}-axis must be positive, got {axis_radius}."
                )
        # object.__setattr__ is used as in the original implementation; presumably regular
        # attribute assignment is blocked on this tree class — confirm against the base class.
        object.__setattr__(self, "partial_real_shape", (2.0 * rx, 2.0 * ry, 2.0 * rz))

    def get_voxel_mask_for_shape(self) -> jax.Array:
        """Generate a voxel mask for a sphere or ellipsoid shape.

        A voxel is considered inside when its cell center satisfies
        ``(x - cx)^2 / rx^2 + (y - cy)^2 / ry^2 + (z - cz)^2 / rz^2 < 1``, where
        ``(cx, cy, cz)`` is half of ``self.real_shape`` per axis, measured from the
        object's lower edge. Cell centers come from the resolved (possibly non-uniform)
        grid when one exists, otherwise from the uniform grid spacing.

        Returns:
            jax.Array: Boolean mask where True indicates voxels inside the sphere/ellipsoid.
        """
        radius_x, radius_y, radius_z = self._resolved_radii()

        def local_centers(axis: int) -> jax.Array:
            """Return physical cell centers along ``axis`` relative to this object's lower edge."""
            grid = self._config.resolved_grid
            if grid is None:
                spacing = self._config.uniform_spacing()
                return (jnp.arange(self.grid_shape[axis]) + 0.5) * spacing
            lower, upper = self.grid_slice_tuple[axis]
            edges = grid.edges(axis)
            # Midpoint of each cell, shifted so the object's lower edge is the origin.
            return 0.5 * (edges[lower:upper] + edges[lower + 1 : upper + 1]) - edges[lower]

        # Create 3D grid of cell-center coordinates
        x, y, z = jnp.meshgrid(local_centers(0), local_centers(1), local_centers(2), indexing="ij")
        center_x, center_y, center_z = (0.5 * axis_size for axis_size in self.real_shape)

        # Normalized squared distances per axis (ellipsoid equation terms)
        x_term = ((x - center_x) / radius_x) ** 2
        y_term = ((y - center_y) / radius_y) ** 2
        z_term = ((z - center_z) / radius_z) ** 2

        # Points are inside if x^2/a^2 + y^2/b^2 + z^2/c^2 < 1
        return (x_term + y_term + z_term) < 1

    def get_material_mapping(
        self,
    ) -> jax.Array:
        """Return a per-voxel material index array filled with this sphere's material.

        Returns:
            jax.Array: int32 array of shape ``self.grid_shape`` whose every entry is the
            position of ``material_name`` in ``compute_ordered_names(self.materials)``.

        Raises:
            ValueError: If ``material_name`` is not one of the object's materials.
        """
        all_names = compute_ordered_names(self.materials)
        try:
            idx = all_names.index(self.material_name)
        except ValueError as err:
            raise ValueError(
                f"Material '{self.material_name}' not found in materials; available: {all_names}"
            ) from err
        return jnp.full(self.grid_shape, idx, dtype=jnp.int32)