# Source code for fdtdx.objects.static_material.cylinder

import jax
import jax.numpy as jnp

from fdtdx.core.jax.pytrees import autoinit, frozen_field
from fdtdx.materials import compute_ordered_names
from fdtdx.objects.static_material.static import StaticMultiMaterialObject


@autoinit
class Cylinder(StaticMultiMaterialObject):
    """A cylindrical optical fiber with configurable properties.

    This class represents a cylindrical fiber with customizable radius, material, and orientation.
    The fiber can be positioned along any of the three principal axes.

    The cross-section size (diameter = 2 * radius) is automatically inferred for the two axes
    perpendicular to ``axis``, so ``partial_real_shape`` does not need to be specified for those
    axes. The extrusion axis size must still be determined by a constraint or an explicit
    ``partial_real_shape`` entry.
    """

    #: The radius of the fiber in meter.
    radius: float = frozen_field()

    #: The principal axis along which the fiber extends (0=x, 1=y, 2=z).
    axis: int = frozen_field()

    #: Name of the material in the materials dictionary to be used for the object.
    material_name: str = frozen_field()

    def __post_init__(self):
        """Fill in the cross-section extent of ``partial_real_shape`` from ``radius``.

        Both axes perpendicular to ``axis`` get a real size of ``2 * radius``. The size along
        ``axis`` itself is left untouched.

        Raises:
            ValueError: If ``axis`` is not 0, 1 or 2, or if ``partial_real_shape`` or
                ``partial_grid_shape`` is already set for one of the two perpendicular axes
                (those sizes are derived from the radius and must not be given explicitly).
        """
        # horizontal_axis/vertical_axis and the expand_dims in get_voxel_mask_for_shape only make
        # sense for the three principal axes; reject anything else up front.
        if self.axis not in (0, 1, 2):
            raise ValueError(f"Cylinder {self.name}: axis must be 0, 1 or 2, got {self.axis}.")
        diameter = 2.0 * self.radius
        real_shape = list(self.partial_real_shape)
        grid_shape = list(self.partial_grid_shape)
        for ax in (self.horizontal_axis, self.vertical_axis):
            if real_shape[ax] is not None:
                raise ValueError(
                    f"Cylinder {self.name}: partial_real_shape for axis {ax} is derived from the radius "
                    f"({diameter:.3e} m). Do not specify it explicitly."
                )
            if grid_shape[ax] is not None:
                raise ValueError(
                    f"Cylinder {self.name}: partial_grid_shape for axis {ax} is derived from the radius. "
                    f"Do not specify it explicitly."
                )
            real_shape[ax] = diameter
        # Fields are declared with frozen_field, so the derived shape is stored via
        # object.__setattr__ rather than normal attribute assignment.
        object.__setattr__(self, "partial_real_shape", tuple(real_shape))

    @property
    def horizontal_axis(self) -> int:
        """Gets the horizontal axis perpendicular to the fiber axis.

        Returns:
            int: The index of the horizontal axis (0=x or 1=y).
        """
        if self.axis == 0:
            return 1
        return 0

    @property
    def vertical_axis(self) -> int:
        """Gets the vertical axis perpendicular to the fiber axis.

        Returns:
            int: The index of the vertical axis (1=y or 2=z).
        """
        if self.axis == 2:
            return 1
        return 2

    def get_voxel_mask_for_shape(self) -> jax.Array:
        """Return a boolean mask of the voxels whose cell centre lies inside the circular cross-section.

        The disc is centred in the plane spanned by ``horizontal_axis`` and ``vertical_axis``.
        Its radius in grid cells is ``radius / self._config.resolution``; resolution is presumably
        the grid spacing in meters, so confirm against the config. A singleton dimension is then
        inserted at ``axis``.

        Returns:
            jax.Array: Boolean array shaped like ``grid_shape`` but with size 1 along ``axis``.
        """
        num_h = self.grid_shape[self.horizontal_axis]
        num_v = self.grid_shape[self.vertical_axis]
        grid_radius_exact = self.radius / self._config.resolution
        # Offset of every cell centre (+0.5) from the cross-section centre, in units of the radius.
        # Each axis is centred on its OWN extent. The previous meshgrid version subtracted the
        # centres in swapped order, which shifted the disc whenever num_h != num_v.
        coord_h = (jnp.arange(num_h) - num_h / 2 + 0.5) / grid_radius_exact
        coord_v = (jnp.arange(num_v) - num_v / 2 + 0.5) / grid_radius_exact
        # Shape (num_h, num_v); strict "< 1" keeps cells exactly on the rim outside.
        mask = coord_h[:, None] ** 2 + coord_v[None, :] ** 2 < 1
        # horizontal_axis < vertical_axis for every valid axis, so inserting the singleton
        # dimension at ``axis`` yields the object's (x, y, z) dimension order.
        mask = jnp.expand_dims(mask, axis=self.axis)
        return mask

    def get_material_mapping(
        self,
    ) -> jax.Array:
        """Return an int32 array, shaped like ``grid_shape``, filled with this object's material index.

        The index is the position of ``material_name`` in ``compute_ordered_names(self.materials)``.

        Raises:
            ValueError: If ``material_name`` is not one of the ordered material names.
        """
        all_names = compute_ordered_names(self.materials)
        try:
            idx = all_names.index(self.material_name)
        except ValueError as err:
            raise ValueError(
                f"Cylinder {self.name}: material '{self.material_name}' not found in materials {all_names}."
            ) from err
        return jnp.full(self.grid_shape, idx, dtype=jnp.int32)