Source code for fdtdx.fdtd.initialization

import math
import warnings
from typing import Any, Sequence

import jax
import jax.numpy as jnp

from fdtdx import constants
from fdtdx.config import SimulationConfig
from fdtdx.core.grid import RectilinearGrid
from fdtdx.core.jax.guards import check_not_tracing
from fdtdx.core.jax.sharding import create_named_sharded_matrix
from fdtdx.core.jax.ste import straight_through_estimator
from fdtdx.fdtd.container import ArrayContainer, FieldState, ObjectContainer, ParameterContainer
from fdtdx.materials import (
    compute_allowed_dispersive_coefficients,
    compute_allowed_electric_conductivities,
    compute_allowed_magnetic_conductivities,
    compute_allowed_permeabilities,
    compute_allowed_permittivities,
    compute_pole_coefficients,
)
from fdtdx.objects.boundaries.bloch import BlochBoundary
from fdtdx.objects.device.parameters.transform import ParameterType
from fdtdx.objects.object import (
    GridCoordinateConstraint,
    PositionConstraint,
    RealCoordinateConstraint,
    SimulationObject,
    SizeConstraint,
    SizeExtensionConstraint,
)
from fdtdx.objects.static_material.static import SimulationVolume, StaticMultiMaterialObject, UniformMaterialObject

# Default cap on the number of passes ``resolve_object_constraints`` forwards
# to the iterative constraint solver (its ``max_iter`` default).
DEFAULT_MAX_ITER: int = 1_000


def _warn_if_simulation_volume_too_large(grid_shape: tuple[int, int, int]) -> None:
    """Emit a ``UserWarning`` when the grid holds more cells than recommended.

    The threshold is ``constants.MAX_SIMULATION_VOLUME_CELLS``. Nothing is
    raised; the warning only signals that allocating the field arrays may
    need excessive memory.

    Args:
        grid_shape: Number of grid cells along each axis.
    """
    total_cells = math.prod(grid_shape)
    limit = constants.MAX_SIMULATION_VOLUME_CELLS
    if not total_cells > limit:
        return

    message = (
        f"Simulation volume has {total_cells:,} cells (grid shape {grid_shape}), "
        f"which exceeds the recommended limit of {limit:,}. "
        "Allocating FDTD field arrays may require excessive memory and fail."
    )
    # stacklevel=3 attributes the warning two frames above this helper.
    warnings.warn(message, UserWarning, stacklevel=3)


# Union of every constraint type accepted in the ``constraints`` argument of
# ``place_objects`` and ``resolve_object_constraints``.
AnyConstraint = (
    PositionConstraint | SizeConstraint | SizeExtensionConstraint | GridCoordinateConstraint | RealCoordinateConstraint
)


def place_objects(
    object_list: Sequence[SimulationObject],
    config: SimulationConfig,
    constraints: Sequence[AnyConstraint],
    key: jax.Array,
) -> tuple[
    ObjectContainer,
    ArrayContainer,
    ParameterContainer,
    SimulationConfig,
    dict[str, Any],
]:
    """Resolve placement constraints, place every object and build the simulation containers.

    Must not be called inside a JIT trace (enforced by ``check_not_tracing``).

    Args:
        object_list (Sequence[SimulationObject]): All simulation objects, including
            the simulation volume. Object names must be unique.
        config (SimulationConfig): Simulation configuration.
        constraints (Sequence[AnyConstraint]): Positioning/sizing constraints that
            reference objects by name.
        key (jax.Array): JAX random key. It is split once per placed object; the
            remaining key seeds device-parameter initialization.

    Returns:
        tuple[ObjectContainer, ArrayContainer, ParameterContainer, SimulationConfig, dict[str, Any]]:
            - ObjectContainer of placed objects (volume at index 0), each carrying
              the final configuration
            - ArrayContainer with initialized field arrays
            - ParameterContainer with device parameters
            - Updated SimulationConfig
            - Dictionary with additional initialization info

    Raises:
        ValueError: If constraint resolution fails for one or more objects, or if
            the resolved grid shape disagrees with the simulation volume shape.
    """
    check_not_tracing("fdtdx.place_objects")

    # Turn the constraints into per-object ((start, stop), ...) grid slices.
    resolved_slices, errors = resolve_object_constraints(
        objects=object_list,
        constraints=constraints,
        config=config,
    )
    problem_lines = [f" - {name}: {msg}" for name, msg in errors.items() if msg]
    if problem_lines:
        details = "\n".join(problem_lines)
        raise ValueError(f"Failed to resolve object constraints:\n{details}")

    by_name = {obj.name: obj for obj in object_list}
    volume = by_name[_resolve_volume_name(by_name)]
    volume_slices = resolved_slices[volume.name]
    volume_shape = tuple(stop - start for start, stop in volume_slices)

    # The grid policy is settled before any object sees the configuration.
    grid = config.resolve_grid(volume_shape)
    if grid.shape != volume_shape:
        raise ValueError(f"Configured grid shape {grid.shape} does not match simulation volume shape {volume_shape}.")
    if not isinstance(config.grid, RectilinearGrid):
        config = config.aset("grid", grid)

    # Every non-volume object draws a fresh subkey, in resolution order.
    others = []
    for name, slice_tuple in resolved_slices.items():
        if name == volume.name:
            continue
        target = by_name[name]
        key, subkey = jax.random.split(key)
        others.append(
            target.place_on_grid(
                grid_slice_tuple=slice_tuple,
                config=config,
                key=subkey,
            )
        )

    # The volume draws its subkey last but is stored first (volume_idx=0).
    key, subkey = jax.random.split(key)
    placed_volume = volume.place_on_grid(
        grid_slice_tuple=volume_slices,
        config=config,
        key=subkey,
    )
    container = ObjectContainer(
        object_list=[placed_volume, *others],
        volume_idx=0,
    )

    params = _init_params(objects=container, key=key)
    arrays, config, info = _init_arrays(objects=container, config=config)

    # Re-wrap the objects so each one carries the final configuration.
    configured = [obj.aset("_config", config) for obj in container.objects]
    final_container = ObjectContainer(
        object_list=configured,
        volume_idx=0,
    )
    return final_container, arrays, params, config, info
[docs] def apply_params( arrays: ArrayContainer, objects: ObjectContainer, params: ParameterContainer, key: jax.Array, **transform_kwargs, ) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]: """Applies parameters to devices and updates source states. Args: arrays (ArrayContainer): Container with field arrays objects (ObjectContainer): Container with simulation objects params (ParameterContainer): Container with device parameters key (jax.Array): JAX random key for source updates **transform_kwargs: Keyword arguments passed to the parameter transformation. Returns: tuple[ArrayContainer, ObjectContainer, dict[str, Any]]: A tuple containing: - Updated ArrayContainer with applied device parameters - Updated ObjectContainer with new source states - Dictionary with parameter application info """ info = {} # Determine number of components from existing array shape num_perm_components = arrays.inv_permittivities.shape[0] isotropic = num_perm_components == 1 diagonally_anisotropic = num_perm_components == 3 num_dispersive_poles = arrays.dispersive_c1.shape[0] if arrays.dispersive_c1 is not None else 0 # apply parameter to devices for device in objects.devices: cur_material_indices = device(params[device.name], expand_to_sim_grid=True, **transform_kwargs) # allowed_perm_list is list of tuples with length 1 (isotropic) or 3 (diagonally anisotropic) or 9 (fully anisotropic) allowed_perm_array = jnp.asarray( compute_allowed_permittivities( device.materials, isotropic=isotropic, diagonally_anisotropic=diagonally_anisotropic, ) ) # shape: (num_materials, num_components) if isotropic or diagonally_anisotropic: inv_allowed = 1.0 / allowed_perm_array # (num_materials, num_components) else: # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements inv_allowed = jnp.array([jnp.linalg.inv(perm.reshape(3, 3)).flatten() for perm in allowed_perm_array]) # When any object in the sim is dispersive (num_dispersive_poles > 0) we # always write the coefficient 
stack into the device's grid_slice — even # when none of the device's materials are dispersive themselves. Otherwise # stale coefficients from an underlying dispersive region would survive # and keep evolving polarization in the device's voxels. # compute_allowed_dispersive_coefficients zero-pads non-dispersive materials. write_dispersive = num_dispersive_poles > 0 # Initialise dispersive slots; populated below when write_dispersive is True. allowed_c1_arr = allowed_c2_arr = allowed_c3_arr = None new_c1_slice = new_c2_slice = new_c3_slice = None if write_dispersive: assert ( arrays.dispersive_c1 is not None and arrays.dispersive_c2 is not None and arrays.dispersive_c3 is not None ) dt = device._config.time_step_duration allowed_c1_np, allowed_c2_np, allowed_c3_np = compute_allowed_dispersive_coefficients( device.materials, dt=dt, max_num_poles=num_dispersive_poles, ) allowed_c1_arr = jnp.asarray(allowed_c1_np, dtype=arrays.dispersive_c1.dtype) allowed_c2_arr = jnp.asarray(allowed_c2_np, dtype=arrays.dispersive_c2.dtype) allowed_c3_arr = jnp.asarray(allowed_c3_np, dtype=arrays.dispersive_c3.dtype) if device.output_type == ParameterType.CONTINUOUS: # Linear interpolation between two materials # Add spatial broadcast dims for element-wise multiplication inv_allowed_bc = inv_allowed[:, :, None, None, None] # cur_material_indices: (*grid_shape) broadcasts with (num_components, 1, 1, 1) new_perm_slice = (1 - cur_material_indices) * inv_allowed_bc[0] + cur_material_indices * inv_allowed_bc[1] if write_dispersive: assert allowed_c1_arr is not None and allowed_c2_arr is not None and allowed_c3_arr is not None # Linear interpolation of dispersive coefficients between the two bracketing materials. 
# Note: this follows the same straight-through-estimator convention as the # permittivity path above — it is *not* equivalent to a material whose # epsilon and poles are linearly interpolated, but it is the same # continuous relaxation used for inv_permittivities, so gradients still # flow smoothly through the device parameters during inverse design. # allowed_cN_arr: (num_materials, num_poles) — here num_materials == 2. # reshape to (num_poles, 1, 1, 1, 1) for broadcast over (num_poles, 1, Nx, Ny, Nz) w0 = (1 - cur_material_indices)[None, None, ...] # (1, 1, Nx, Ny, Nz) w1 = cur_material_indices[None, None, ...] c1_0 = allowed_c1_arr[0][:, None, None, None, None] # (num_poles, 1, 1, 1, 1) c1_1 = allowed_c1_arr[1][:, None, None, None, None] c2_0 = allowed_c2_arr[0][:, None, None, None, None] c2_1 = allowed_c2_arr[1][:, None, None, None, None] c3_0 = allowed_c3_arr[0][:, None, None, None, None] c3_1 = allowed_c3_arr[1][:, None, None, None, None] new_c1_slice = w0 * c1_0 + w1 * c1_1 new_c2_slice = w0 * c2_0 + w1 * c2_1 new_c3_slice = w0 * c3_0 + w1 * c3_1 else: # Discrete material selection # inv_allowed[indices] -> (*grid_shape, num_components), then moveaxis -> (num_components, *grid_shape) component_values = jnp.moveaxis(inv_allowed[cur_material_indices.astype(jnp.int32)], -1, 0) component_values = straight_through_estimator(cur_material_indices, component_values) new_perm_slice = component_values if write_dispersive: assert allowed_c1_arr is not None and allowed_c2_arr is not None and allowed_c3_arr is not None int_idx = cur_material_indices.astype(jnp.int32) # allowed_cN_arr[int_idx]: (Nx, Ny, Nz, num_poles) -> moveaxis -> (num_poles, Nx, Ny, Nz) new_c1_slice = jnp.moveaxis(allowed_c1_arr[int_idx], -1, 0)[:, None, ...] new_c2_slice = jnp.moveaxis(allowed_c2_arr[int_idx], -1, 0)[:, None, ...] new_c3_slice = jnp.moveaxis(allowed_c3_arr[int_idx], -1, 0)[:, None, ...] 
# Update all components of inv_permittivities array at once new_perm = arrays.inv_permittivities.at[:, *device.grid_slice].set(new_perm_slice) arrays = arrays.at["inv_permittivities"].set(new_perm) if write_dispersive: assert ( arrays.dispersive_c1 is not None and arrays.dispersive_c2 is not None and arrays.dispersive_c3 is not None ) new_c1 = arrays.dispersive_c1.at[:, :, *device.grid_slice].set(new_c1_slice) new_c2 = arrays.dispersive_c2.at[:, :, *device.grid_slice].set(new_c2_slice) new_c3 = arrays.dispersive_c3.at[:, :, *device.grid_slice].set(new_c3_slice) # Recompute inv_c2 from the post-interpolation c2. Do NOT interpolate # inv_c2 directly: 1/avg(c2) != avg(1/c2), and the reverse-time ADE # relies on inv_c2 being the exact reciprocal of the stored c2. new_inv_c2 = jnp.where(new_c2 == 0, 0.0, 1.0 / new_c2) arrays = arrays.at["dispersive_c1"].set(new_c1) arrays = arrays.at["dispersive_c2"].set(new_c2) arrays = arrays.at["dispersive_c3"].set(new_c3) arrays = arrays.at["dispersive_inv_c2"].set(new_inv_c2) # apply random key to sources. Source-side sampling of the dispersion # coefficients (used only for carrier-frequency impedance / energy # normalization) is stop_gradient'd to match the treatment of # ``inv_permittivities`` — the FDTD VJP itself still propagates gradient # through the coefficients, so this only avoids noise from the source # amplitude path. 
disp_c1 = None if arrays.dispersive_c1 is None else jax.lax.stop_gradient(arrays.dispersive_c1) disp_c2 = None if arrays.dispersive_c2 is None else jax.lax.stop_gradient(arrays.dispersive_c2) disp_c3 = None if arrays.dispersive_c3 is None else jax.lax.stop_gradient(arrays.dispersive_c3) new_objects = [] for obj in objects.object_list: key, subkey = jax.random.split(key) new_obj = obj.apply( key=subkey, inv_permittivities=jax.lax.stop_gradient(arrays.inv_permittivities), inv_permeabilities=jax.lax.stop_gradient(arrays.inv_permeabilities), dispersive_c1=disp_c1, dispersive_c2=disp_c2, dispersive_c3=disp_c3, ) new_objects.append(new_obj) new_objects = ObjectContainer( object_list=new_objects, volume_idx=objects.volume_idx, ) return arrays, new_objects, info
def _init_arrays( objects: ObjectContainer, config: SimulationConfig, ) -> tuple[ArrayContainer, SimulationConfig, dict[str, Any]]: """Initializes field arrays and material properties for the simulation. Creates and initializes the E/H fields, permittivity/permeability arrays, detector states, boundary states and recording states based on the simulation objects and configuration. Args: objects (ObjectContainer): Container with simulation objects config (SimulationConfig): The simulation configuration Returns: tuple[ArrayContainer, SimulationConfig, dict[str, Any]]: A tuple containing: - ArrayContainer with initialized arrays and states - Updated SimulationConfig - Dictionary with initialization info """ # create E/H fields volume_shape = objects.volume.grid_shape _warn_if_simulation_volume_too_large(volume_shape) grid = config.resolve_grid(volume_shape) if grid.shape != volume_shape: raise ValueError(f"Configured grid shape {grid.shape} does not match simulation volume shape {volume_shape}.") ext_shape = (3, *volume_shape) # Determine whether to use complex-valued fields needs_complex = any(isinstance(o, BlochBoundary) and o.needs_complex_fields for o in objects.boundary_objects) if config.use_complex_fields is None: # Auto-detect: promote to complex if any Bloch boundary has non-zero k use_complex = needs_complex else: use_complex = config.use_complex_fields if needs_complex and not use_complex: raise ValueError( "use_complex_fields=False but Bloch boundaries with non-zero " "wave vector are present. These require complex-valued fields." 
) if use_complex: field_dtype = jnp.complex64 if config.dtype == jnp.float32 else jnp.complex128 else: field_dtype = config.dtype E = create_named_sharded_matrix( ext_shape, sharding_axis=1, value=0.0, dtype=field_dtype, backend=config.backend, ) H = create_named_sharded_matrix( ext_shape, value=0.0, dtype=field_dtype, sharding_axis=1, backend=config.backend, ) # create auxiliary fields psi_E and psi_H for PML boundaries psi_E = create_named_sharded_matrix( (6, *volume_shape), sharding_axis=1, value=0.0, dtype=field_dtype, backend=config.backend, ) psi_H = create_named_sharded_matrix( (6, *volume_shape), value=0.0, dtype=field_dtype, sharding_axis=1, backend=config.backend, ) # create alpha, kappa, and sigma arrays alpha = create_named_sharded_matrix( (6, *volume_shape), sharding_axis=1, value=0.0, dtype=config.dtype, backend=config.backend, ) kappa = create_named_sharded_matrix( (6, *volume_shape), sharding_axis=1, value=1.0, dtype=config.dtype, backend=config.backend, ) sigma = create_named_sharded_matrix( (6, *volume_shape), sharding_axis=1, value=0.0, dtype=config.dtype, backend=config.backend, ) # Determine isotropy flags isotropic_permittivity = objects.all_objects_isotropic_permittivity isotropic_permeability = objects.all_objects_isotropic_permeability isotropic_electric_conductivity = objects.all_objects_isotropic_electric_conductivity isotropic_magnetic_conductivity = objects.all_objects_isotropic_magnetic_conductivity # Determine diagonally anisotropic flags diagonally_anisotropic_permittivity = objects.all_objects_diagonally_anisotropic_permittivity diagonally_anisotropic_permeability = objects.all_objects_diagonally_anisotropic_permeability diagonally_anisotropic_electric_conductivity = objects.all_objects_diagonally_anisotropic_electric_conductivity diagonally_anisotropic_magnetic_conductivity = objects.all_objects_diagonally_anisotropic_magnetic_conductivity # Get component counts for each property if isotropic_permittivity: num_perm_components = 1 
elif diagonally_anisotropic_permittivity: num_perm_components = 3 else: num_perm_components = 9 if isotropic_permeability: num_permeability_components = 1 elif diagonally_anisotropic_permeability: num_permeability_components = 3 else: num_permeability_components = 9 if isotropic_electric_conductivity: num_electric_cond_components = 1 elif diagonally_anisotropic_electric_conductivity: num_electric_cond_components = 3 else: num_electric_cond_components = 9 if isotropic_magnetic_conductivity: num_magnetic_cond_components = 1 elif diagonally_anisotropic_magnetic_conductivity: num_magnetic_cond_components = 3 else: num_magnetic_cond_components = 9 # permittivity - shape (1, Nx, Ny, Nz) for isotropic, (3, Nx, Ny, Nz) for diagonally anisotropic, (9, Nx, Ny, Nz) for fully anisotropic inv_permittivities = create_named_sharded_matrix( (num_perm_components, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=1, backend=config.backend, ) # permeability - scalar 1.0 if non-magnetic, else (1, Nx, Ny, Nz) for isotropic, (3, Nx, Ny, Nz) for diagonally anisotropic, (9, Nx, Ny, Nz) for fully anisotropic if objects.all_objects_non_magnetic: inv_permeabilities = 1.0 else: inv_permeabilities = create_named_sharded_matrix( (num_permeability_components, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=1, backend=config.backend, ) # electric conductivity - None if non-conductive, else (1, Nx, Ny, Nz) for isotropic, (3, Nx, Ny, Nz) for diagonally anisotropic, (9, Nx, Ny, Nz) for fully anisotropic electric_conductivity = None if not objects.all_objects_non_electrically_conductive: electric_conductivity = create_named_sharded_matrix( (num_electric_cond_components, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=1, backend=config.backend, ) # magnetic conductivity - None if non-conductive, else (1, Nx, Ny, Nz) for isotropic, (3, Nx, Ny, Nz) for diagonally anisotropic, (9, Nx, Ny, Nz) for fully anisotropic magnetic_conductivity = None if not 
objects.all_objects_non_magnetically_conductive: magnetic_conductivity = create_named_sharded_matrix( (num_magnetic_cond_components, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=1, backend=config.backend, ) conductivity_spacing = None if electric_conductivity is not None or magnetic_conductivity is not None: conductivity_spacing = constants.c * config.time_step_duration / config.courant_number # dispersive ADE auxiliary arrays - all None unless any material is dispersive. # Per-cell coefficients are broadcast over component via a size-1 axis. num_dispersive_poles = objects.max_num_dispersive_poles dispersive_P_curr = None dispersive_P_prev = None dispersive_c1 = None dispersive_c2 = None dispersive_c3 = None if num_dispersive_poles > 0: if not (isotropic_permittivity or diagonally_anisotropic_permittivity): raise NotImplementedError( "Dispersive materials cannot be combined with fully anisotropic " "(off-diagonal) permittivity tensors in v1." ) dispersive_P_curr = create_named_sharded_matrix( (num_dispersive_poles, 3, *volume_shape), value=0.0, dtype=field_dtype, sharding_axis=2, backend=config.backend, ) dispersive_P_prev = create_named_sharded_matrix( (num_dispersive_poles, 3, *volume_shape), value=0.0, dtype=field_dtype, sharding_axis=2, backend=config.backend, ) dispersive_c1 = create_named_sharded_matrix( (num_dispersive_poles, 1, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=2, backend=config.backend, ) dispersive_c2 = create_named_sharded_matrix( (num_dispersive_poles, 1, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=2, backend=config.backend, ) dispersive_c3 = create_named_sharded_matrix( (num_dispersive_poles, 1, *volume_shape), value=0.0, dtype=config.dtype, sharding_axis=2, backend=config.backend, ) # set permittivity/permeability/conductivity of static objects sorted_obj = sorted( objects.static_material_objects, key=lambda o: o.placement_order, ) info = {} for o in sorted_obj: if isinstance(o, 
UniformMaterialObject): # Material properties are tuples (εxx, εxy, εxz, εyx, εyy, εyz, εzx, εzy, εzz) # Arrays have shape (num_components, Nx, Ny, Nz) where num_components is 1 (isotropic), 3 (diagonally anisotropic), or 9 (fully anisotropic) if num_perm_components == 1: # Isotropic: simple element-wise inversion perm_tuple = (o.material.permittivity[0],) inv_obj_permittivity = (1 / jnp.array(perm_tuple, dtype=config.dtype))[:, None, None, None] inv_permittivities = inv_permittivities.at[:, *o.grid_slice].set(inv_obj_permittivity) elif num_perm_components == 3: # Diagonally anisotropic: simple element-wise inversion perm_tuple = (o.material.permittivity[0], o.material.permittivity[4], o.material.permittivity[8]) inv_obj_permittivity = (1 / jnp.array(perm_tuple, dtype=config.dtype))[:, None, None, None] inv_permittivities = inv_permittivities.at[:, *o.grid_slice].set(inv_obj_permittivity) else: # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements perm_tuple = o.material.permittivity perm_matrix = jnp.array(perm_tuple, dtype=config.dtype).reshape(3, 3) inv_perm_matrix = jnp.linalg.inv(perm_matrix) inv_obj_permittivity = inv_perm_matrix.flatten()[:, None, None, None] inv_permittivities = inv_permittivities.at[:, *o.grid_slice].set(inv_obj_permittivity) if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0: if num_permeability_components == 1: # Isotropic: simple element-wise inversion perm_tuple = (o.material.permeability[0],) inv_obj_permeability = (1 / jnp.array(perm_tuple, dtype=config.dtype))[:, None, None, None] inv_permeabilities = inv_permeabilities.at[:, *o.grid_slice].set(inv_obj_permeability) elif num_permeability_components == 3: # Diagonally anisotropic: simple element-wise inversion perm_tuple = (o.material.permeability[0], o.material.permeability[4], o.material.permeability[8]) inv_obj_permeability = (1 / jnp.array(perm_tuple, dtype=config.dtype))[:, None, None, None] inv_permeabilities = 
inv_permeabilities.at[:, *o.grid_slice].set(inv_obj_permeability) else: # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements perm_tuple = o.material.permeability perm_matrix = jnp.array(perm_tuple, dtype=config.dtype).reshape(3, 3) inv_perm_matrix = jnp.linalg.inv(perm_matrix) inv_obj_permeability = inv_perm_matrix.flatten()[:, None, None, None] inv_permeabilities = inv_permeabilities.at[:, *o.grid_slice].set(inv_obj_permeability) if electric_conductivity is not None: if num_electric_cond_components == 1: # Isotropic cond_tuple = (o.material.electric_conductivity[0],) elif num_electric_cond_components == 3: # Diagonally anisotropic cond_tuple = ( o.material.electric_conductivity[0], o.material.electric_conductivity[4], o.material.electric_conductivity[8], ) else: # Fully anisotropic cond_tuple = o.material.electric_conductivity # Scale physical conductivity into the dimensionless update coefficient. # On uniform grids this equals the scalar grid spacing. On stretched # grids it is the reference spacing implied by ``c0 * dt / courant``. assert conductivity_spacing is not None obj_electric_conductivity = (jnp.array(cond_tuple, dtype=config.dtype) * conductivity_spacing)[ :, None, None, None ] electric_conductivity = electric_conductivity.at[:, *o.grid_slice].set(obj_electric_conductivity) if magnetic_conductivity is not None: if num_magnetic_cond_components == 1: # Isotropic cond_tuple = (o.material.magnetic_conductivity[0],) elif num_magnetic_cond_components == 3: # Diagonally anisotropic cond_tuple = ( o.material.magnetic_conductivity[0], o.material.magnetic_conductivity[4], o.material.magnetic_conductivity[8], ) else: # Fully anisotropic cond_tuple = o.material.magnetic_conductivity # Scale physical conductivity into the dimensionless update coefficient. 
assert conductivity_spacing is not None obj_magnetic_conductivity = (jnp.array(cond_tuple, dtype=config.dtype) * conductivity_spacing)[ :, None, None, None ] magnetic_conductivity = magnetic_conductivity.at[:, *o.grid_slice].set(obj_magnetic_conductivity) if num_dispersive_poles > 0: # Always write the full pole-coefficient stack — zero-padded for # non-dispersive materials — so later placements deterministically # overwrite earlier coefficients across the object's grid_slice. # Without this, a non-dispersive UniformMaterialObject stacked over # a dispersive one would leave stale pole coefficients in the overlap # and drive an ADE update on cells that shouldn't have one. assert dispersive_c1 is not None and dispersive_c2 is not None and dispersive_c3 is not None poles = o.material.dispersion.poles if o.material.dispersion is not None else () c1_vals, c2_vals, c3_vals = compute_pole_coefficients(poles, config.time_step_duration) n = len(poles) c1_padded = jnp.zeros(num_dispersive_poles, dtype=config.dtype) c2_padded = jnp.zeros(num_dispersive_poles, dtype=config.dtype) c3_padded = jnp.zeros(num_dispersive_poles, dtype=config.dtype) if n > 0: c1_padded = c1_padded.at[:n].set(jnp.asarray(c1_vals, dtype=config.dtype)) c2_padded = c2_padded.at[:n].set(jnp.asarray(c2_vals, dtype=config.dtype)) c3_padded = c3_padded.at[:n].set(jnp.asarray(c3_vals, dtype=config.dtype)) # Broadcast (num_poles,) → (num_poles, 1, Nx, Ny, Nz) over grid_slice slice_shape = dispersive_c1[:, :, *o.grid_slice].shape c1_block = jnp.broadcast_to(c1_padded[:, None, None, None, None], slice_shape) c2_block = jnp.broadcast_to(c2_padded[:, None, None, None, None], slice_shape) c3_block = jnp.broadcast_to(c3_padded[:, None, None, None, None], slice_shape) dispersive_c1 = dispersive_c1.at[:, :, *o.grid_slice].set(c1_block) dispersive_c2 = dispersive_c2.at[:, :, *o.grid_slice].set(c2_block) dispersive_c3 = dispersive_c3.at[:, :, *o.grid_slice].set(c3_block) elif isinstance(o, (StaticMultiMaterialObject)): 
indices = o.get_material_mapping() mask = o.get_voxel_mask_for_shape() # compute_allowed_permittivities returns list of tuples with length 1 (isotropic), 3 (diagonally anisotropic), or 9 (fully anisotropic) allowed_perms = jnp.asarray( compute_allowed_permittivities( o.materials, isotropic=isotropic_permittivity, diagonally_anisotropic=diagonally_anisotropic_permittivity, ) ) if num_perm_components == 1 or num_perm_components == 3: allowed_inv_perms = 1 / allowed_perms # shape: (num_materials, num_components) else: # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements allowed_inv_perms = jnp.array([jnp.linalg.inv(perm.reshape(3, 3)).flatten() for perm in allowed_perms]) # allowed_inv_perms[indices] -> (*grid_shape, num_components) # After moveaxis -> (num_components, *grid_shape) component_values = jnp.moveaxis(allowed_inv_perms[indices], -1, 0) diff = component_values - inv_permittivities[:, *o.grid_slice] inv_permittivities = inv_permittivities.at[:, *o.grid_slice].add(mask * diff) if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0: allowed_perms = jnp.asarray( compute_allowed_permeabilities( o.materials, isotropic=isotropic_permeability, diagonally_anisotropic=diagonally_anisotropic_permeability, ) ) if num_permeability_components == 1 or num_permeability_components == 3: allowed_inv_perms = 1 / allowed_perms else: # Fully anisotropic: reshape to 3x3 matrix, invert, and flatten back to 9 elements allowed_inv_perms = jnp.array( [jnp.linalg.inv(perm.reshape(3, 3)).flatten() for perm in allowed_perms] ) component_values = jnp.moveaxis(allowed_inv_perms[indices], -1, 0) diff = component_values - inv_permeabilities[:, *o.grid_slice] inv_permeabilities = inv_permeabilities.at[:, *o.grid_slice].add(mask * diff) if electric_conductivity is not None: allowed_conds = jnp.asarray( compute_allowed_electric_conductivities( o.materials, isotropic=isotropic_electric_conductivity, 
diagonally_anisotropic=diagonally_anisotropic_electric_conductivity, ) ) assert conductivity_spacing is not None component_values = jnp.moveaxis(allowed_conds[indices], -1, 0) * conductivity_spacing diff = component_values - electric_conductivity[:, *o.grid_slice] electric_conductivity = electric_conductivity.at[:, *o.grid_slice].add(mask * diff) if magnetic_conductivity is not None: allowed_conds = jnp.asarray( compute_allowed_magnetic_conductivities( o.materials, isotropic=isotropic_magnetic_conductivity, diagonally_anisotropic=diagonally_anisotropic_magnetic_conductivity, ) ) assert conductivity_spacing is not None component_values = jnp.moveaxis(allowed_conds[indices], -1, 0) * conductivity_spacing diff = component_values - magnetic_conductivity[:, *o.grid_slice] magnetic_conductivity = magnetic_conductivity.at[:, *o.grid_slice].add(mask * diff) # Always run when dispersive arrays exist in the sim: a non-dispersive # StaticMultiMaterialObject layered over a dispersive region must # zero the inherited coefficients in its mask. compute_allowed_dispersive_coefficients # zero-pads non-dispersive materials, so this still cleanly overwrites. if num_dispersive_poles > 0: assert dispersive_c1 is not None and dispersive_c2 is not None and dispersive_c3 is not None allowed_c1, allowed_c2, allowed_c3 = compute_allowed_dispersive_coefficients( o.materials, dt=config.time_step_duration, max_num_poles=num_dispersive_poles, ) # Shape (num_materials, num_poles) -> index by (Nx, Ny, Nz) -> # (Nx, Ny, Nz, num_poles) -> moveaxis -> (num_poles, Nx, Ny, Nz) c1_voxels = jnp.moveaxis(jnp.asarray(allowed_c1, dtype=config.dtype)[indices], -1, 0) c2_voxels = jnp.moveaxis(jnp.asarray(allowed_c2, dtype=config.dtype)[indices], -1, 0) c3_voxels = jnp.moveaxis(jnp.asarray(allowed_c3, dtype=config.dtype)[indices], -1, 0) # broadcast over component axis c1_voxels = c1_voxels[:, None, ...] c2_voxels = c2_voxels[:, None, ...] c3_voxels = c3_voxels[:, None, ...] mask_bc = mask[None, None, ...] 
diff = c1_voxels - dispersive_c1[:, :, *o.grid_slice] dispersive_c1 = dispersive_c1.at[:, :, *o.grid_slice].add(mask_bc * diff) diff = c2_voxels - dispersive_c2[:, :, *o.grid_slice] dispersive_c2 = dispersive_c2.at[:, :, *o.grid_slice].add(mask_bc * diff) diff = c3_voxels - dispersive_c3[:, :, *o.grid_slice] dispersive_c3 = dispersive_c3.at[:, :, *o.grid_slice].add(mask_bc * diff) else: raise Exception(f"Unknown object type: {o}") # detector states detector_states = {} for d in objects.detectors: detector_states[d.name] = d.init_state() # modify arrays for boundaries for boundary in objects.boundary_objects: if hasattr(boundary, "modify_arrays") and callable(getattr(boundary, "modify_arrays", None)): modify_fn = getattr(boundary, "modify_arrays") result = modify_fn( alpha=alpha, kappa=kappa, sigma=sigma, electric_conductivity=electric_conductivity, magnetic_conductivity=magnetic_conductivity, ) if result is not None: alpha = result.get("alpha", alpha) kappa = result.get("kappa", kappa) sigma = result.get("sigma", sigma) electric_conductivity = result.get("electric_conductivity", electric_conductivity) magnetic_conductivity = result.get("magnetic_conductivity", magnetic_conductivity) # interfaces recording_state = None if config.gradient_config is not None and config.gradient_config.recorder is not None: input_shape_dtypes = {} for boundary in objects.pml_objects: cur_shape = boundary.interface_grid_shape() extended_shape = (3, *cur_shape) input_shape_dtypes[f"{boundary.name}_E"] = jax.ShapeDtypeStruct(shape=extended_shape, dtype=field_dtype) input_shape_dtypes[f"{boundary.name}_H"] = jax.ShapeDtypeStruct(shape=extended_shape, dtype=field_dtype) recorder = config.gradient_config.recorder recorder, recording_state = recorder.init_state( input_shape_dtypes=input_shape_dtypes, max_time_steps=config.time_steps_total, backend=config.backend, ) grad_cfg = config.gradient_config.aset( "recorder", recorder, ) config = config.aset("gradient_config", grad_cfg) # Cache 1/c2 
with non-dispersive cells zeroed so update_E_reverse can replace # its ``jnp.where(c2 == 0, ..., / c2)`` pair with a single multiply. dispersive_inv_c2 = None if dispersive_c2 is not None: dispersive_inv_c2 = jnp.where(dispersive_c2 == 0, 0.0, 1.0 / dispersive_c2) arrays = ArrayContainer( fields=FieldState(E=E, H=H, psi_E=psi_E, psi_H=psi_H), alpha=alpha, kappa=kappa, sigma=sigma, inv_permittivities=inv_permittivities, inv_permeabilities=inv_permeabilities, detector_states=detector_states, recording_state=recording_state, electric_conductivity=electric_conductivity, magnetic_conductivity=magnetic_conductivity, dispersive_P_curr=dispersive_P_curr, dispersive_P_prev=dispersive_P_prev, dispersive_c1=dispersive_c1, dispersive_c2=dispersive_c2, dispersive_c3=dispersive_c3, dispersive_inv_c2=dispersive_inv_c2, ) return arrays, config, info def _init_params( objects: ObjectContainer, key: jax.Array, ) -> ParameterContainer: """Initializes parameters for simulation devices. Args: objects (ObjectContainer): Container with simulation objects key (jax.Array): JAX random key for parameter initialization Returns: ParameterContainer: ParameterContainer with initialized device parameters """ params = {} for d in objects.devices: key, subkey = jax.random.split(key) cur_dict = d.init_params(key=subkey) params[d.name] = cur_dict return params
def resolve_object_constraints(
    objects: Sequence[SimulationObject],
    constraints: Sequence[AnyConstraint],
    config: SimulationConfig,
    max_iter: int = DEFAULT_MAX_ITER,
) -> tuple[dict, dict]:
    """Resolve object constraints into grid slices and shapes.

    Args:
        objects: All simulation objects, including exactly one SimulationVolume.
        constraints: Positioning/sizing constraints that reference objects by name.
        config: Simulation configuration used to convert real lengths/coordinates
            into grid indices.
        max_iter: Maximum number of constraint-propagation iterations.

    Returns:
        tuple[dict, dict]: ``(resolved_slices, errors)``. ``resolved_slices`` maps each
        object name to a 3-tuple of ``(lower, upper)`` grid bounds (entries may be
        ``None`` when unresolved). ``errors`` maps each object name to an error
        message, or ``None`` when the object resolved cleanly.

    Raises:
        ValueError: If an entry of ``objects`` is not a SimulationObject, if object
            names are not unique, or if a constraint references an unknown name.
    """
    # Validate types before touching ``.name``: a foreign object might not have one,
    # which would otherwise surface as an unrelated AttributeError.
    invalid_objects = [obj for obj in objects if not isinstance(obj, SimulationObject)]
    if invalid_objects:
        # Report type names; the objects themselves are neither strings nor sortable.
        invalid_types = sorted({type(obj).__name__ for obj in invalid_objects})
        raise ValueError(
            f"Invalid object types detected: {', '.join(invalid_types)}. "
            "All objects must be instances or subclasses of SimulationObject."
        )

    # Sanity check: ensure all objects have unique names (single linear pass).
    object_names = [obj.name for obj in objects]
    seen_names: set[str] = set()
    duplicates: set[str] = set()
    for name in object_names:
        if name in seen_names:
            duplicates.add(name)
        seen_names.add(name)
    if duplicates:
        raise ValueError(
            f"Duplicate object names detected: {', '.join(sorted(duplicates))}. "
            "Each object must have a unique name before resolving constraints into grid slices."
        )

    _check_objects_names_from_constraints(
        constraints=constraints,
        object_names=object_names,
    )

    # Apply constraints iteratively
    resolved, errors = _apply_constraints_iteratively(
        objects=list(objects),
        constraints=constraints,
        config=config,
        max_iter=max_iter,
    )

    # Freeze the mutable [[lower, upper], ...] lists into tuples of (lower, upper).
    resolved_slices = {
        obj_name: tuple((axis_bounds[0], axis_bounds[1]) for axis_bounds in slice_list)
        for obj_name, slice_list in resolved.items()
    }

    volume_name = _resolve_volume_name({obj.name: obj for obj in objects})
    volume_slice = resolved_slices.get(volume_name)

    # The volume always has an entry, but its bounds may still be None when its size
    # could not be resolved. Comparing ints against None would raise TypeError, so the
    # containment check only runs once every volume bound is known (an unresolved
    # volume is already reported by the iterative resolver).
    if volume_slice is None or any(bound is None for axis_bounds in volume_slice for bound in axis_bounds):
        return resolved_slices, errors

    volume_bounds = tuple((s1, s2) for s1, s2 in volume_slice)

    # Validate all non-volume objects are within simulation volume bounds
    for obj_name, slice_tuple in resolved_slices.items():
        if obj_name == volume_name:
            continue  # Skip the volume itself

        unresolved_axes = [
            axis for axis in range(3) if slice_tuple[axis][0] is None or slice_tuple[axis][1] is None
        ]
        if unresolved_axes:
            # Ensure unresolved objects are flagged without clobbering a more specific error.
            if not errors.get(obj_name):
                errors[obj_name] = (
                    f"Object '{obj_name}' has unresolved bounds on axes {unresolved_axes}. Slice: {slice_tuple}"
                )
            continue

        msgs = []
        for axis in range(3):
            s1, s2 = slice_tuple[axis]
            vol_s1, vol_s2 = volume_bounds[axis]
            if s1 < vol_s1:
                msgs.append(f"axis {axis}: lower bound {s1} < volume lower bound {vol_s1}")
            if s2 > vol_s2:
                msgs.append(f"axis {axis}: upper bound {s2} > volume upper bound {vol_s2}")
            if s2 <= s1:
                msgs.append(f"axis {axis}: invalid size (lower bound {s1} >= upper bound {s2})")

        if msgs:
            prev = errors.get(obj_name) or ""
            errors[obj_name] = (
                (prev + "; " if prev else "")
                + f"Object '{obj_name}' out of bounds ({slice_tuple} vs volume {volume_bounds}): "
                + "; ".join(msgs)
            )

    return resolved_slices, errors
def _center_to_bounds(
    real_pos: float,
    resolution: float,
    size: int,
    volume_size: int,
) -> tuple[int, int]:
    """Convert a center-relative real-space position into grid bounds.

    The coordinate origin (0,0,0) is interpreted as the center of the simulation
    volume, not the lower-left simulation corner. This helper assumes a uniform cell
    size ``resolution``; constraint resolution in this module uses the grid-aware
    ``_center_to_bounds_for_grid`` instead.

    Args:
        real_pos: Object center along the axis, relative to the volume center.
        resolution: Uniform cell size along the axis.
        size: Object extent along the axis, in grid cells.
        volume_size: Simulation-volume extent along the axis, in grid cells.

    Returns:
        tuple[int, int]: ``(lower, upper)`` with ``upper - lower == size``.
    """
    # convert physical coordinate to grid coordinate relative to volume center
    volume_center = volume_size / 2
    grid_center = round(real_pos / resolution + volume_center)
    lower = round(grid_center - size / 2)
    upper = lower + size
    return lower, upper


def _real_length_to_grid_size(config: SimulationConfig, axis: int, length: float) -> int:
    """Convert a physical length to a grid-cell count.

    Non-uniform grids snap upward so objects always cover at least the requested
    metric length. Uniform grids use the historical round-to-nearest rule for exact
    backwards compatibility.

    Limitation: on a non-uniform (stretched) grid this helper measures from the lower
    domain edge, so the returned count reflects the local cell density at the origin
    rather than at the object's actual placement position. Objects placed in coarser
    regions may therefore span more physical length than requested. A fully
    location-aware conversion requires knowing the object's anchor position before
    sizing it, which is not available at this call site. Use ``partial_grid_shape``
    to specify sizes in cell counts when exact metric sizing is required on
    non-uniform grids.
    """
    snap = "upper" if config.has_nonuniform_grid else "nearest"
    return config.grid.length_to_cell_count(axis, length, snap=snap)


def _real_coord_to_edge_index(config: SimulationConfig, axis: int, coord: float) -> int:
    """Snap a physical coordinate to the nearest grid edge index along ``axis``."""
    return config.grid.coord_to_index(axis, coord, snap="nearest")


def _center_to_bounds_for_grid(
    config: SimulationConfig, axis: int, real_pos: float, size: int, volume_size: int
) -> tuple[int, int]:
    """Convert a center-relative position to edge bounds, accounting for grid geometry.

    real_pos is interpreted relative to the simulation volume center (0 = center of
    domain). Works for both uniform and non-uniform grids.
    """
    grid = config.grid
    if isinstance(grid, RectilinearGrid):
        # Non-uniform grid: the metric center is the midpoint of the outermost edges.
        edges = grid.edges(axis)
        volume_center_coord = float(edges[0] + edges[-1]) / 2
    else:
        volume_center_coord = grid.origin[axis] + volume_size * grid.spacing / 2
    return grid.bounds_for_center(axis, real_pos + volume_center_coord, size)


def _raise_for_nonuniform_grid_offsets(config: SimulationConfig, values: Sequence[int | None], name: str):
    """Reject index-space distance offsets when a grid is non-uniform.

    Zero and ``None`` are accepted as no-ops for backwards-compatible helper defaults.
    Non-zero grid distances do not have a metric meaning on stretched grids and must
    be expressed in metres instead.

    Raises:
        ValueError: If the grid is non-uniform and any value is neither ``None`` nor 0.
    """
    if not config.has_nonuniform_grid:
        return
    if any(v not in (None, 0) for v in values):
        raise ValueError(f"{name} are index-space distances and are not supported on non-uniform grids.")


def _resolve_static_positions_initial(
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    shape_dict: dict[str, list[int | None]],
    config: SimulationConfig,
):
    """Fill in static or directly defined positions from partial_real_position during initial setup.

    The partial_real_position represents the center position of the object.
    Coordinates are interpreted relative to the center of the simulation volume, i.e.
    partial_real_position=(0,0,0) places an object at the geometric center of the
    simulation domain. This function converts center-relative real coordinates into
    positive grid coordinates and computes slice boundaries if the object's size is
    known.

    Raises:
        ValueError: If an object needs centering on an axis whose volume size is unknown.
    """
    volume_name = _resolve_volume_name(object_map)
    for obj_name, obj in object_map.items():
        if not (hasattr(obj, "partial_real_position") and obj.partial_real_position is not None):
            continue
        for axis in range(3):
            real_position = obj.partial_real_position[axis]
            if real_position is None:
                continue
            size = shape_dict[obj_name][axis]
            # Need object size to compute centered bounds
            if size is None:
                continue
            volume_size = shape_dict[volume_name][axis]
            if volume_size is None:
                raise ValueError(f"Simulation volume size for axis {axis} is unresolved.")
            lower, upper = _center_to_bounds_for_grid(
                config=config,
                axis=axis,
                real_pos=real_position,
                size=size,
                volume_size=volume_size,
            )
            slice_dict[obj_name][axis][0] = lower
            slice_dict[obj_name][axis][1] = upper
    return slice_dict


def _resolve_static_positions_iterative(
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    shape_dict: dict[str, list[int | None]],
    config: SimulationConfig,
    errors: dict[str, str | None],
):
    """Iteratively resolve positions from partial_real_position when size becomes known.

    The partial_real_position represents the center position of the object.
    Coordinates are interpreted relative to the center of the simulation volume, i.e.
    partial_real_position=(0,0,0) places an object at the geometric center of the
    simulation domain. This function is called in each iteration of constraint
    resolution so that positions can be computed as soon as the object size becomes
    known. A bound that was already set by another constraint is validated against
    the centered position and a mismatch is recorded in ``errors``.

    Returns:
        tuple:
            - resolved_something: Whether new positions were resolved
            - updated slice_dict
            - updated errors
    """
    resolved_something = False
    volume_name = _resolve_volume_name(object_map)
    for obj_name, obj in object_map.items():
        if not (hasattr(obj, "partial_real_position") and obj.partial_real_position is not None):
            continue
        for axis in range(3):
            real_position = obj.partial_real_position[axis]
            if real_position is None:
                continue
            # Current bounds
            b0, b1 = slice_dict[obj_name][axis]
            # Already fully resolved
            if b0 is not None and b1 is not None:
                continue
            # Need object size to compute centered bounds
            size = shape_dict[obj_name][axis]
            if size is None:
                continue
            volume_size = shape_dict[volume_name][axis]
            if volume_size is None:
                raise ValueError(f"Simulation volume size for axis {axis} is unresolved.")
            lower, upper = _center_to_bounds_for_grid(
                config=config,
                axis=axis,
                real_pos=real_position,
                size=size,
                volume_size=volume_size,
            )
            # Set or validate lower bound
            if b0 is None:
                slice_dict[obj_name][axis][0] = lower
                resolved_something = True
            elif b0 != lower:
                errors[obj_name] = (
                    f"Inconsistent position for {obj_name} "
                    f"axis {axis}: partial_real_position implies "
                    f"lower bound {lower}, but constraint set it "
                    f"to {b0}"
                )
            # Set or validate upper bound
            if b1 is None:
                slice_dict[obj_name][axis][1] = upper
                resolved_something = True
            elif b1 != upper:
                errors[obj_name] = (
                    f"Inconsistent position for {obj_name} "
                    f"axis {axis}: partial_real_position implies "
                    f"upper bound {upper}, but constraint set it "
                    f"to {b1}"
                )
    return resolved_something, slice_dict, errors


def _check_objects_names_from_constraints(
    constraints: Sequence[AnyConstraint],
    object_names: list[str],
):
    """Collect object names mentioned in constraints and verify they exist.

    Returns:
        list[str]: The distinct object names referenced by the constraints.

    Raises:
        ValueError: If a constraint references a name not present in ``object_names``.
    """
    known_names = set(object_names)  # O(1) membership instead of scanning the list
    all_names = set()
    for c in constraints:
        for name in (getattr(c, "object", None), getattr(c, "other_object", None)):
            if not name:
                continue
            if name not in known_names:
                raise ValueError(f"Unknown object name in constraint: {name}")
            all_names.add(name)
    return list(all_names)


def _apply_constraints_iteratively(
    objects: list[SimulationObject],
    constraints: Sequence[AnyConstraint],
    config: SimulationConfig,
    max_iter: int = DEFAULT_MAX_ITER,
) -> tuple[dict, dict]:
    """Iteratively apply all constraints until shapes and positions converge.

    Each iteration propagates information between known grid shapes and grid slices,
    applies every constraint once and, if nothing changed, tries to extend free bounds
    to the simulation-volume edges. Iteration stops once every object is fully
    resolved, when no further progress is possible, or after ``max_iter`` iterations.
    Exceptions raised while applying a constraint are recorded in ``errors`` under the
    constraint's object instead of propagating.

    Returns:
        tuple[dict, dict]: ``(slice_dict, errors)``. ``slice_dict`` maps object names to
        ``[[lower, upper], [lower, upper], [lower, upper]]`` (``None`` where
        unresolved); ``errors`` maps object names to a message or ``None``.
    """
    object_map = {obj.name: obj for obj in objects}
    volume_name = _resolve_volume_name(object_map)

    # Per-object shapes (grid cells) and [lower, upper] bounds, keyed by object name.
    shape_dict = {obj.name: [None, None, None] for obj in objects}
    slice_dict = {obj.name: [[None, None], [None, None], [None, None]] for obj in objects}
    # The simulation volume is anchored at grid index 0 on every axis.
    for axis in range(3):
        slice_dict[volume_name][axis][0] = 0

    errors: dict[str, str | None] = {obj.name: None for obj in objects}

    # handle static shapes
    shape_dict = _resolve_static_shapes(
        object_map=object_map,
        shape_dict=shape_dict,
        config=config,
    )
    slice_dict = _resolve_static_positions_initial(
        object_map=object_map,
        slice_dict=slice_dict,
        shape_dict=shape_dict,
        config=config,
    )

    for _ in range(max_iter):
        changed = False

        # check if we already resolved everything
        if all(
            all(shape_dict[o][i] is not None for i in range(3))
            and all(slice_dict[o][i][s] is not None for i in range(3) for s in range(2))
            for o in object_map
        ):
            break

        # Try to resolve positions from partial_real_position if size is now known
        resolved, slice_dict, errors = _resolve_static_positions_iterative(
            object_map=object_map,
            slice_dict=slice_dict,
            shape_dict=shape_dict,
            config=config,
            errors=errors,
        )
        changed = changed or resolved

        # Slices-from-shapes: propagate a known shape to an open bound.
        # Shapes-from-slices: lock the shape once both bounds are known.
        resolved, slice_dict, errors = _update_grid_slices_from_shapes(
            object_map=object_map,
            shape_dict=shape_dict,
            slice_dict=slice_dict,
            errors=errors,
        )
        changed = changed or resolved

        # update grid shapes based on grid slices
        resolved, shape_dict, errors = _update_grid_shapes_from_slices(
            object_map=object_map,
            shape_dict=shape_dict,
            slice_dict=slice_dict,
            errors=errors,
        )
        changed = changed or resolved

        # go through all constraints
        for c in constraints:
            try:
                if isinstance(c, GridCoordinateConstraint):
                    resolved, slice_dict = _apply_grid_coordinate_constraint(
                        constraint=c,
                        object_map=object_map,
                        slice_dict=slice_dict,
                        config=config,
                    )
                elif isinstance(c, RealCoordinateConstraint):
                    resolved, slice_dict = _apply_real_coordinate_constraint(
                        constraint=c,
                        object_map=object_map,
                        slice_dict=slice_dict,
                        config=config,
                    )
                elif isinstance(c, PositionConstraint):
                    resolved, slice_dict = _apply_position_constraint(
                        constraint=c,
                        object_map=object_map,
                        config=config,
                        shape_dict=shape_dict,
                        slice_dict=slice_dict,
                    )
                elif isinstance(c, SizeConstraint):
                    resolved, shape_dict = _apply_size_constraint(
                        constraint=c,
                        object_map=object_map,
                        config=config,
                        shape_dict=shape_dict,
                        slice_dict=slice_dict,
                    )
                elif isinstance(c, SizeExtensionConstraint):
                    resolved, slice_dict = _apply_size_extension_constraint(
                        constraint=c,
                        object_map=object_map,
                        config=config,
                        slice_dict=slice_dict,
                        volume_name=volume_name,
                    )
                else:
                    raise ValueError(f"Unknown constraint type: {type(c).__name__}")
            except Exception as e:
                errors[c.object] = f"Error applying {type(c).__name__}: {e}"
                # A failed constraint made no progress; don't reuse a stale flag.
                resolved = False
            changed = changed or resolved

        # Extend objects to infinity if possible
        if not changed:
            changed, slice_dict = _extend_to_inf_if_possible(
                constraints=constraints,
                object_map=object_map,
                slice_dict=slice_dict,
                shape_dict=shape_dict,
                volume_name=volume_name,
            )

        # check for misspecification: no progress at all means a fixpoint was reached
        if not changed:
            errors = _handle_unresolved_objects(object_map=object_map, slice_dict=slice_dict, errors=errors)
            break
    else:
        # max_iter reached without convergence
        # Ensure all unresolved objects are flagged
        errors = _handle_unresolved_objects(object_map=object_map, slice_dict=slice_dict, errors=errors)

    return slice_dict, errors


def _resolve_volume_name(
    object_map: dict[str, SimulationObject],
) -> str:
    """Return the name of the single SimulationVolume in ``object_map``.

    Raises:
        ValueError: If there is no SimulationVolume or more than one.
    """
    volume_objects = [o for o in object_map.values() if isinstance(o, SimulationVolume)]
    if not volume_objects:
        raise ValueError("No SimulationVolume object found in the provided objects list.")
    if len(volume_objects) > 1:
        raise ValueError(
            f"Multiple SimulationVolume objects found ({[o.name for o in volume_objects]}). "
            "There must be exactly one simulation volume."
        )
    return volume_objects[0].name


def _resolve_static_shapes(
    object_map: dict[str, SimulationObject],
    shape_dict: dict[str, list[int | None]],
    config: SimulationConfig,
) -> dict[str, list[int | None]]:
    """Fill in shapes from each object's partial_real_shape and partial_grid_shape.

    When both are given for an axis, the value derived from partial_real_shape is
    written last and therefore wins.
    """
    for obj_name, obj in object_map.items():
        for axis in range(3):
            if obj.partial_grid_shape[axis] is not None:
                shape_dict[obj_name][axis] = obj.partial_grid_shape[axis]
            if obj.partial_real_shape[axis] is not None:
                cur_grid_shape = _real_length_to_grid_size(config, axis, obj.partial_real_shape[axis])  # type: ignore
                shape_dict[obj_name][axis] = cur_grid_shape
    return shape_dict


def _record_shape_bound_conflict(
    obj_name: str,
    axis: int,
    bound_size: int,
    obj: SimulationObject,
    shape_dict: dict[str, list[int | None]],
    errors: dict[str, str | None],
) -> bool:
    """Record a conflict where shape_dict and bound-derived size disagree. Always an error.

    Returns:
        bool: Always False, so callers can ``|=`` it into their progress flag.
    """
    errors[obj_name] = (
        f"Inconsistent grid shape for object: {shape_dict[obj_name][axis]} != {bound_size} "
        f"for axis={axis}, {obj.name} ({obj.__class__.__name__}). "
        f"Check partial_real_shape, partial_grid_shape, and any SizeConstraints for this object. "
        f"If the shape is derived from geometry (e.g. radius), a conflicting constraint was applied."
    )
    return False


def _update_grid_slices_from_shapes(
    object_map: dict[str, SimulationObject],
    shape_dict: dict[str, list[int | None]],
    slice_dict: dict[str, list[list[int | None]]],
    errors: dict[str, str | None],
):
    """Derive a missing bound from a known shape and the opposite known bound.

    Axes with both bounds known are checked against the shape; a mismatch is
    recorded via ``_record_shape_bound_conflict``.

    Returns:
        tuple: ``(resolved_something, slice_dict, errors)``.
    """
    resolved_something = False
    for obj_name, s in shape_dict.items():
        obj = object_map[obj_name]
        for axis in range(3):
            s_axis = s[axis]
            if s_axis is None:
                continue
            b0, b1 = slice_dict[obj_name][axis]
            if b0 is None and b1 is None:
                continue
            if b0 is not None and b1 is not None:
                if s_axis != b1 - b0:
                    resolved_something |= _record_shape_bound_conflict(obj_name, axis, b1 - b0, obj, shape_dict, errors)
            elif b0 is not None:
                slice_dict[obj_name][axis][1] = b0 + s_axis
                resolved_something = True
            else:
                slice_dict[obj_name][axis][0] = b1 - s_axis
                resolved_something = True
    return resolved_something, slice_dict, errors


def _update_grid_shapes_from_slices(
    object_map: dict[str, SimulationObject],
    shape_dict: dict[str, list[int | None]],
    slice_dict: dict[str, list[list[int | None]]],
    errors: dict[str, str | None],
):
    """Set a missing shape from two known bounds, or record a shape/bound conflict.

    Returns:
        tuple: ``(resolved_something, shape_dict, errors)``.
    """
    resolved_something = False
    for obj_name, b in slice_dict.items():
        obj = object_map[obj_name]
        s = shape_dict[obj_name]
        for axis in range(3):
            b0, b1 = b[axis]
            s_axis = s[axis]
            if b0 is None or b1 is None:
                continue
            if s_axis is None:
                shape_dict[obj_name][axis] = b1 - b0
                resolved_something = True
            elif b1 - b0 != s_axis:
                resolved_something |= _record_shape_bound_conflict(obj_name, axis, b1 - b0, obj, shape_dict, errors)
    return resolved_something, shape_dict, errors


def _apply_grid_coordinate_constraint(
    constraint: GridCoordinateConstraint,
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    config: SimulationConfig | None = None,
):
    """Pin an object's lower ("-") or upper ("+") bound to explicit grid indices.

    Returns:
        tuple: ``(resolved_something, slice_dict)``.

    Raises:
        ValueError: On non-uniform grids (index-space placement is not supported).
        Exception: If a bound was already set to a different index.
    """
    if config is not None and config.has_nonuniform_grid:
        raise ValueError(
            "GridCoordinateConstraint is an index-space placement API and is not supported on non-uniform grids."
        )
    obj_name = constraint.object
    obj = object_map[obj_name]
    resolved_something = False
    for axis_idx, axis in enumerate(constraint.axes):
        coord_index = constraint.coordinates[axis_idx]
        b_idx = 0 if constraint.sides[axis_idx] == "-" else 1
        if slice_dict[obj_name][axis][b_idx] is None:
            slice_dict[obj_name][axis][b_idx] = coord_index
            resolved_something = True
        elif slice_dict[obj_name][axis][b_idx] != coord_index:
            raise Exception(
                f"Inconsistent grid coordinates for object: "
                f"{slice_dict[obj_name][axis][b_idx]} != {coord_index} for {axis=} {obj.name} ({obj.__class__}). "
            )
    return resolved_something, slice_dict


def _apply_real_coordinate_constraint(
    constraint: RealCoordinateConstraint,
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    config: SimulationConfig,
):
    """Pin an object's lower ("-") or upper ("+") bound to physical coordinates.

    Each coordinate is snapped to the nearest grid edge index.

    Returns:
        tuple: ``(resolved_something, slice_dict)``.

    Raises:
        Exception: If a bound was already set to a different index.
    """
    obj_name = constraint.object
    obj = object_map[obj_name]
    resolved_something = False
    for axis_idx, axis in enumerate(constraint.axes):
        coord_index = _real_coord_to_edge_index(config, axis, constraint.coordinates[axis_idx])
        b_idx = 0 if constraint.sides[axis_idx] == "-" else 1
        if slice_dict[obj_name][axis][b_idx] is None:
            slice_dict[obj_name][axis][b_idx] = coord_index
            resolved_something = True
        elif slice_dict[obj_name][axis][b_idx] != coord_index:
            raise Exception(
                f"Inconsistent grid coordinates for object: "
                f"{slice_dict[obj_name][axis][b_idx]} != {coord_index} for {axis=} {obj.name} ({obj.__class__}). "
            )
    return resolved_something, slice_dict


def _apply_position_constraint(
    constraint: PositionConstraint,
    object_map: dict[str, SimulationObject],
    config: SimulationConfig,
    shape_dict: dict[str, list[int | None]],
    slice_dict: dict[str, list[list[int | None]]],
):
    """Apply a position constraint between two objects.

    For each constrained axis, once the other object's bounds and this object's
    size are known, the bounds are computed from the other object's anchor
    coordinate plus optional real/grid margins.

    Returns:
        tuple: ``(resolved_something, slice_dict)``.

    Raises:
        Exception: If a computed bound conflicts with an already-known bound.
    """
    obj_name, other_name = constraint.object, constraint.other_object
    obj = object_map[obj_name]
    resolved_something = False

    for axis_idx, axis in enumerate(constraint.axes):
        grid_margin = constraint.grid_margins[axis_idx]
        real_margin = constraint.margins[axis_idx]
        _raise_for_nonuniform_grid_offsets(config, (grid_margin,), "grid_margins")

        # check if other knows their position
        other_b0, other_b1 = slice_dict[other_name][axis]
        if other_b0 is None or other_b1 is None:
            continue
        # check if object knows their size
        object_size = shape_dict[obj_name][axis]
        if object_size is None:
            continue

        other_anchor = config.grid.anchor_coordinate(
            axis,
            (other_b0, other_b1),
            constraint.other_object_positions[axis_idx],
        )
        if real_margin is not None:
            other_anchor += real_margin
        if grid_margin is not None:
            # grid_margin is in cell units; rejected for non-uniform grids above
            other_anchor += grid_margin * config.uniform_spacing()

        b0, b1 = config.grid.bounds_for_anchor(
            axis,
            object_size,
            other_anchor,
            constraint.object_positions[axis_idx],
        )

        # update position or check consistency
        old_b0, old_b1 = slice_dict[obj_name][axis]
        if old_b0 is None:
            slice_dict[obj_name][axis][0] = b0
            resolved_something = True
        elif old_b0 != b0:
            raise Exception(
                f"Inconsistent grid shape (may be due to extension to infinity) at lower bound: "
                f"{old_b0} != {b0} for {axis=}, {obj.name} ({obj.__class__}). "
                f"Object has a position constraint that puts the lower boundary at {b0}, "
                f"but the lower bound was already computed to be at {old_b0}. "
                f"This could be due to a missing size constraint/specification, "
                f"or another constraint on this object."
            )
        if old_b1 is None:
            slice_dict[obj_name][axis][1] = b1
            resolved_something = True
        elif old_b1 != b1:
            raise Exception(
                f"Inconsistent grid shape (may be due to extension to infinity) at upper bound: "
                f"{old_b1} != {b1} for {axis=}, {obj.name} ({obj.__class__}). "
                f"Object has a position constraint that puts the upper boundary at {b1}, "
                f"but the upper bound was already computed to be at {old_b1}. "
                f"This could be due to a missing size constraint/specification, "
                f"or another constraint on this object."
            )
    return resolved_something, slice_dict


def _apply_size_constraint(
    constraint: SizeConstraint,
    object_map: dict[str, SimulationObject],
    config: SimulationConfig,
    shape_dict: dict[str, list[int | None]],
    slice_dict: dict[str, list[list[int | None]]] | None = None,
):
    """Resolve a size relationship between objects.

    The target length is ``proportion * other_length + offset + grid_offset * spacing``,
    converted to a grid-cell count with ``_real_length_to_grid_size``.

    Returns:
        tuple: ``(resolved_something, shape_dict)``.

    Raises:
        Exception: If the computed shape conflicts with an already-known shape.
    """
    obj_name, other_name = constraint.object, constraint.other_object
    obj = object_map[obj_name]
    resolved_something = False

    for axis_idx, axis in enumerate(constraint.axes):
        _raise_for_nonuniform_grid_offsets(config, (constraint.grid_offsets[axis_idx],), "grid_offsets")
        other_axes = constraint.other_axes[axis_idx]

        # check if other object knows their shape
        other_shape = shape_dict[other_name][other_axes]
        if other_shape is None:
            continue

        proportion = constraint.proportions[axis_idx]
        assert slice_dict is not None, "_apply_size_constraint requires slice_dict"
        other_b0, other_b1 = slice_dict[other_name][other_axes]
        if other_b0 is None or other_b1 is None:
            continue

        other_length = config.grid.axis_extent(other_axes, (other_b0, other_b1))
        target_length = other_length * proportion
        if constraint.offsets[axis_idx] is not None:
            target_length += constraint.offsets[axis_idx]
        if constraint.grid_offsets[axis_idx] is not None:
            # grid_offsets are in cell units; rejected for non-uniform grids above
            target_length += constraint.grid_offsets[axis_idx] * config.uniform_spacing()
        object_shape = _real_length_to_grid_size(config, axis, target_length)

        # update or check consistency
        if shape_dict[obj_name][axis] is None:
            shape_dict[obj_name][axis] = object_shape
            resolved_something = True
        elif shape_dict[obj_name][axis] != object_shape:
            raise Exception(
                f"Inconsistent grid shape for object: "
                f"{shape_dict[obj_name][axis]} != {object_shape} for axis={axis}, "
                f"{obj.name} ({obj.__class__.__name__}). "
                f"Check partial_real_shape, partial_grid_shape, and any SizeConstraints for this object. "
                f"If the shape is derived from geometry (e.g. radius), a conflicting SizeConstraint was applied."
            )
    return resolved_something, shape_dict


def _apply_size_extension_constraint(
    constraint: SizeExtensionConstraint,
    object_map: dict[str, SimulationObject],
    config: SimulationConfig,
    slice_dict: dict[str, list[list[int | None]]],
    volume_name: str,
):
    """Extend one bound of an object up to another object's anchor or the volume edge.

    Returns:
        tuple: ``(resolved_something, slice_dict)``; ``(False, slice_dict)`` while the
        other object's bounds are still unknown.

    Raises:
        Exception: If the volume bound is unknown (no other object given) or the bound
            was already set to a different index.
    """
    obj_name, other_name = constraint.object, constraint.other_object
    obj = object_map[obj_name]
    dir_idx = 0 if constraint.direction == "-" else 1
    resolved_something = False
    _raise_for_nonuniform_grid_offsets(config, (constraint.grid_offset,), "grid_offset")

    # calculate anchor point
    if other_name is not None:
        # check if other knows their position
        other_b0, other_b1 = slice_dict[other_name][constraint.axis]
        if other_b0 is None or other_b1 is None:
            return False, slice_dict
        other_anchor_coord = config.grid.anchor_coordinate(
            constraint.axis,
            (other_b0, other_b1),
            constraint.other_position,
        )
        if constraint.offset is not None:
            other_anchor_coord += constraint.offset
        if constraint.grid_offset is not None:
            # grid_offset is in cell units; rejected for non-uniform grids above
            other_anchor_coord += constraint.grid_offset * config.uniform_spacing()
        other_anchor = config.grid.coord_to_index(constraint.axis, other_anchor_coord, snap="nearest")
    else:
        # if other is not specified, extend to boundary of simulation volume
        other_anchor = slice_dict[volume_name][constraint.axis][dir_idx]
        if other_anchor is None:
            raise Exception(f"This should never happen: Simulation volume not specified: {volume_name}")

    # update position or check consistency
    old_val = slice_dict[obj_name][constraint.axis][dir_idx]
    if old_val is None:
        slice_dict[obj_name][constraint.axis][dir_idx] = other_anchor
        resolved_something = True
    elif old_val != other_anchor:
        raise Exception(
            f"Inconsistent grid shape at bound {constraint.direction}: "
            f"{old_val} != {other_anchor} for {constraint.axis=}, "
            f"{obj.name} ({obj.__class__})."
        )
    return resolved_something, slice_dict


def _extend_to_inf_if_possible(
    constraints: Sequence[AnyConstraint],
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    shape_dict: dict[str, list[int | None]],
    volume_name: str,
):
    """Extend free object bounds to the simulation-volume edges.

    A bound is extended when the object:
      - does not already have both boundaries specified,
      - is not constrained by an extension constraint in that direction,
      - has no pending PositionConstraint on that axis.
    Objects with known size but no position extend from 0 (the upper bound is later
    derived from the size). Size constraints alone don't prevent extension.

    Returns:
        tuple: ``(resolved_something, slice_dict)``. ``resolved_something`` is True only
        if at least one bound was actually set to a concrete index.
    """
    resolved_something = False
    for axis in range(3):
        # Candidate (object_name, side) pairs; side 0 = lower bound, 1 = upper bound.
        extension_obj = {(o, 0) for o in object_map} | {(o, 1) for o in object_map}

        for c in constraints:
            # Remove objects that are in extension constraints (not size constraints!)
            if isinstance(c, SizeExtensionConstraint) and axis == c.axis:
                direction = 0 if c.direction == "-" else 1
                extension_obj.discard((c.object, direction))
            # Do not extend objects that have a pending PositionConstraint on this axis.
            # If the referenced object's bounds are still unknown the constraint cannot
            # resolve yet, and locking position=0 now would conflict later.
            if isinstance(c, PositionConstraint):
                for c_axis in c.axes:
                    if c_axis != axis:
                        continue
                    other_b0, other_b1 = slice_dict[c.other_object][axis]
                    if other_b0 is None or other_b1 is None:
                        extension_obj.discard((c.object, 0))
                        extension_obj.discard((c.object, 1))

        # For each object, determine what can be extended
        for o in object_map:
            b0, b1 = slice_dict[o][axis]
            size = shape_dict[o][axis]
            if b0 is not None and b1 is not None:
                # Both boundaries known - don't extend either
                extension_obj.discard((o, 0))
                extension_obj.discard((o, 1))
            elif b0 is not None and b1 is None and size is not None:
                # Lower bound known, upper computable from size
                extension_obj.discard((o, 1))
            elif b1 is not None and b0 is None and size is not None:
                # Upper bound known, lower computable from size
                extension_obj.discard((o, 0))
            elif b0 is None and b1 is None and size is not None:
                # Keep lower (extends from 0); upper will be computed from the size
                extension_obj.discard((o, 1))

        # Apply extensions. Each pair writes a distinct slot, so order is irrelevant.
        volume_extent = shape_dict[volume_name][axis]
        for o, direction in extension_obj:
            if slice_dict[o][axis][direction] is not None:
                continue
            new_bound = 0 if direction == 0 else volume_extent
            if new_bound is None:
                # Volume extent still unknown: writing None would claim progress without
                # resolving anything and keep the resolver spinning until max_iter.
                continue
            slice_dict[o][axis][direction] = new_bound
            resolved_something = True
    return resolved_something, slice_dict


def _handle_unresolved_objects(
    object_map: dict[str, SimulationObject],
    slice_dict: dict[str, list[list[int | None]]],
    errors: dict[str, str | None],
):
    """Flag every object that still has an unresolved bound on any axis.

    A more specific error already recorded for the object is kept, and the generic
    message is appended to it rather than overwriting it.

    Returns:
        dict[str, str | None]: The updated ``errors`` mapping.
    """
    for obj_name, obj in object_map.items():
        if any(slice_dict[obj_name][a][0] is None or slice_dict[obj_name][a][1] is None for a in range(3)):
            msg = f"Could not resolve position/size of {obj.name} ({obj.__class__})."
            prev = errors.get(obj_name)
            errors[obj_name] = f"{prev}; {msg}" if prev else msg
    return errors