Source code for fdtdx.objects.static_material.cylinder
import jax
import jax.numpy as jnp
from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.materials import compute_ordered_names
from fdtdx.objects.static_material.static import StaticMultiMaterialObject
[docs]
@autoinit
class Cylinder(StaticMultiMaterialObject):
    """A cylindrical optical fiber with configurable properties.

    This class represents a cylindrical fiber with customizable radius, material,
    and orientation. The fiber can be positioned along any of the three principal axes.

    The cross-section size (diameter = 2 * radius) is automatically inferred for the
    two axes perpendicular to ``axis``, so ``partial_real_shape`` does not need to be
    specified for those axes. The extrusion axis size must still be determined by a
    constraint or an explicit ``partial_real_shape`` entry.

    Raises:
        ValueError: At construction time if ``axis`` is not 0, 1 or 2, if ``radius``
            is not positive, or if ``partial_real_shape`` / ``partial_grid_shape`` is
            given for one of the two cross-section axes.
    """

    #: The radius of the fiber in meter.
    radius: float = frozen_field()

    #: The principal axis along which the fiber extends (0=x, 1=y, 2=z).
    axis: int = frozen_field()

    #: Name of the material in the materials dictionary to be used for the object.
    material_name: str = frozen_field()

    def __post_init__(self) -> None:
        """Validate the parameters and fill in the cross-section real shape.

        Writes ``2 * radius`` into ``partial_real_shape`` for both axes perpendicular
        to ``axis``. The entry for ``axis`` itself is left untouched.

        Raises:
            ValueError: If ``axis`` is not 0, 1 or 2, if ``radius`` is not a positive
                number, or if a cross-section axis already has an explicit entry in
                ``partial_real_shape`` or ``partial_grid_shape``.
        """
        # NOTE(review): this override does not call a parent ``__post_init__``;
        # confirm StaticMultiMaterialObject does not rely on one.
        if self.axis not in (0, 1, 2):
            raise ValueError(f"Cylinder {self.name}: axis must be 0, 1 or 2, got {self.axis!r}.")
        # ``not radius > 0`` also rejects NaN, which ``radius <= 0`` would let through.
        if not self.radius > 0:
            raise ValueError(f"Cylinder {self.name}: radius must be positive, got {self.radius!r}.")

        diameter = 2.0 * self.radius
        real_shape = list(self.partial_real_shape)
        grid_shape = list(self.partial_grid_shape)
        for ax in (self.horizontal_axis, self.vertical_axis):
            if real_shape[ax] is not None:
                raise ValueError(
                    f"Cylinder {self.name}: partial_real_shape for axis {ax} is derived from the radius "
                    f"({diameter:.3e} m). Do not specify it explicitly."
                )
            if grid_shape[ax] is not None:
                raise ValueError(
                    f"Cylinder {self.name}: partial_grid_shape for axis {ax} is derived from the radius. "
                    f"Do not specify it explicitly."
                )
            real_shape[ax] = diameter
        # Plain attribute assignment is presumably blocked on this frozen pytree,
        # hence the explicit object.__setattr__.
        object.__setattr__(self, "partial_real_shape", tuple(real_shape))

    @property
    def horizontal_axis(self) -> int:
        """Gets the horizontal axis perpendicular to the fiber axis.

        Returns:
            int: The index of the horizontal axis (0=x or 1=y). It is 1 when the
            fiber runs along x, otherwise 0.
        """
        if self.axis == 0:
            return 1
        return 0

    @property
    def vertical_axis(self) -> int:
        """Gets the vertical axis perpendicular to the fiber axis.

        Returns:
            int: The index of the vertical axis (1=y or 2=z). It is 1 when the
            fiber runs along z, otherwise 2.
        """
        if self.axis == 2:
            return 1
        return 2

    def get_voxel_mask_for_shape(self) -> jax.Array:
        """Return the boolean voxel mask of the circular cross-section.

        A voxel is inside when the physical center of its cell lies strictly within
        ``radius`` of the center of the object's cross-section. The cross-section is
        identical along ``axis``, so the 2D disk is broadcast along that axis.

        Returns:
            jax.Array: Boolean array of shape ``self.grid_shape``.
        """

        def local_centers(ax: int) -> jax.Array:
            """Return physical cell centers along ``ax`` relative to this object's lower edge."""
            resolved_grid = self._config.resolved_grid
            if resolved_grid is None:
                spacing = self._config.uniform_spacing()
                return (jnp.arange(self.grid_shape[ax]) + 0.5) * spacing
            # Non-uniform grid: average adjacent edges of the object's slice to get
            # cell centers, then shift so the object's first edge sits at 0.
            lower, upper = self.grid_slice_tuple[ax]
            edges = resolved_grid.edges(ax)
            return 0.5 * (edges[lower:upper] + edges[lower + 1 : upper + 1]) - edges[lower]

        h_ax = self.horizontal_axis
        v_ax = self.vertical_axis
        h_coords, v_coords = jnp.meshgrid(local_centers(h_ax), local_centers(v_ax), indexing="ij")

        # Offsets from the cross-section center, expressed in units of the radius.
        dh = (h_coords - 0.5 * self.real_shape[h_ax]) / self.radius
        dv = (v_coords - 0.5 * self.real_shape[v_ax]) / self.radius
        disk = dh**2 + dv**2 < 1

        # horizontal_axis < vertical_axis holds for every valid ``axis``, so
        # inserting a length-1 dimension at ``axis`` puts the disk's two dimensions
        # on the correct object axes.
        mask = jnp.expand_dims(disk, axis=self.axis)
        # Return a full ``grid_shape`` mask, consistent with get_material_mapping,
        # instead of relying on callers to broadcast the length-1 extrusion axis.
        return jnp.broadcast_to(mask, self.grid_shape)

    def get_material_mapping(
        self,
    ) -> jax.Array:
        """Return the material index for every voxel of the object.

        Returns:
            jax.Array: int32 array of shape ``self.grid_shape``. Every entry is the
            position of ``material_name`` in ``compute_ordered_names(self.materials)``.

        Raises:
            ValueError: If ``material_name`` is not among the object's materials.
        """
        all_names = compute_ordered_names(self.materials)
        try:
            idx = all_names.index(self.material_name)
        except ValueError:
            raise ValueError(
                f"Cylinder {self.name}: material {self.material_name!r} not found in materials {list(all_names)}."
            ) from None
        return jnp.full(self.grid_shape, idx, dtype=jnp.int32)