# Source code for fdtdx.fdtd.wrapper
from __future__ import annotations
from collections.abc import Callable
import jax
from fdtdx.config import SimulationConfig
from fdtdx.fdtd.container import ArrayContainer, ObjectContainer, SimulationState
from fdtdx.fdtd.fdtd import checkpointed_fdtd, reversible_fdtd
from fdtdx.fdtd.stop_conditions import StoppingCondition
# [docs]
def run_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
    stopping_condition: StoppingCondition | None = None,
    show_progress: bool = True,
    progress_callback: Callable[[int, int], None] | None = None,
) -> SimulationState:
    """Run an FDTD simulation by dispatching to the configured implementation.

    The implementation is selected from ``config.gradient_config``:

    * ``None`` or ``method == "checkpointed"``: ``checkpointed_fdtd`` is called.
    * ``method == "reversible"``: ``reversible_fdtd`` is called.

    Args:
        arrays: Array container forwarded unchanged to the selected
            implementation.
        objects: Object container forwarded unchanged to the selected
            implementation.
        config: Simulation configuration. Only ``config.gradient_config`` (and
            its ``method`` attribute) is inspected here; the whole object is
            forwarded to the selected implementation.
        key: JAX array forwarded unchanged to the selected implementation
            (presumably a PRNG key -- confirm against the implementations).
        stopping_condition: Optional custom stopping condition. Only accepted
            when ``config.gradient_config`` is ``None``; it is forwarded to
            ``checkpointed_fdtd`` and never to ``reversible_fdtd``.
        show_progress: Forwarded unchanged to the selected implementation.
        progress_callback: Optional callable taking two ints, forwarded
            unchanged to the selected implementation.
            NOTE(review): the meaning of the two ints is defined by the
            implementations, not by this wrapper.

    Returns:
        The value returned by the selected implementation.

    Raises:
        NotImplementedError: If ``stopping_condition`` is given while
            ``config.gradient_config`` is not ``None``.
        ValueError: If ``config.gradient_config.method`` is neither
            ``"reversible"`` nor ``"checkpointed"``.
    """
    gradient_config = config.gradient_config

    # Custom stopping conditions are only supported for forward-only runs.
    if stopping_condition is not None and gradient_config is not None:
        raise NotImplementedError(
            "Custom stopping conditions are not yet compatible with gradient computation. "
            "Set config.gradient_config to None or use default time-based stopping by "
            "setting stopping_condition=None."
        )

    # Forward-only runs and checkpointed gradients share the same entry point.
    # When gradients are enabled, stopping_condition is necessarily None here
    # because of the guard above.
    if gradient_config is None or gradient_config.method == "checkpointed":
        return checkpointed_fdtd(
            arrays=arrays,
            objects=objects,
            config=config,
            key=key,
            stopping_condition=stopping_condition,
            show_progress=show_progress,
            progress_callback=progress_callback,
        )

    if gradient_config.method == "reversible":
        return reversible_fdtd(
            arrays=arrays,
            objects=objects,
            config=config,
            key=key,
            show_progress=show_progress,
            progress_callback=progress_callback,
        )

    # ValueError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working.
    raise ValueError(f"Unknown gradient computation method: {gradient_config.method}")