from pathlib import Path
from typing import Any, Literal
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import Patch, Rectangle
from fdtdx.config import SimulationConfig
from fdtdx.core.grid import RectilinearGrid
from fdtdx.fdtd.container import ObjectContainer
from fdtdx.objects.boundaries.bloch import BlochBoundary
from fdtdx.objects.boundaries.pec import PerfectElectricConductor
from fdtdx.objects.boundaries.perfectly_matched_layer import PerfectlyMatchedLayer
from fdtdx.objects.boundaries.pmc import PerfectMagneticConductor
from fdtdx.objects.object import SimulationObject
def _get_full_coverage_objects(
objects: list[SimulationObject],
axis_indices: tuple[int, int],
plane_size: tuple[int, int],
volume: SimulationObject,
) -> list[SimulationObject]:
"""Detect objects that cover 100% of the viewing plane.
Args:
objects: List of simulation objects to check
axis_indices: Tuple of two axis indices defining the viewing plane
plane_size: Tuple of (width, height) of the viewing plane in grid cells
volume: The simulation volume object
Returns:
List of objects that cover 100% of the viewing plane
"""
full_coverage_objects = []
total_area = plane_size[0] * plane_size[1]
# Guard against degenerate simulation volumes (zero area planes)
if total_area <= 0:
return []
for obj in objects:
if obj is volume:
continue
slices = obj.grid_slice_tuple
obj_width = slices[axis_indices[0]][1] - slices[axis_indices[0]][0]
obj_height = slices[axis_indices[1]][1] - slices[axis_indices[1]][0]
obj_area = obj_width * obj_height
# Check if object covers the entire plane (allowing for small floating point errors)
if obj_area >= total_area * 0.999: # 99.9% threshold to account for numerical issues
full_coverage_objects.append(obj)
return full_coverage_objects
def _axis_edges_um(config: SimulationConfig, axis: int, bounds: tuple[int, int]) -> tuple[float, float]:
    """Convert a (start, stop) grid-index pair along one axis into local positions in micrometres.

    If ``config`` exposes a ``grid`` attribute that is a ``RectilinearGrid``, the grid's own
    edge coordinates along ``axis`` are used, shifted so the first edge sits at 0. Otherwise
    the indices are scaled by ``config.uniform_spacing()``.

    Args:
        config: Simulation configuration.
        axis: Axis index forwarded to ``grid.edges``.
        bounds: (start, stop) index pair.

    Returns:
        (start, stop) positions in micrometres.
    """
    metres_per_um = 1.0e-6
    grid = getattr(config, "grid", None)

    if not isinstance(grid, RectilinearGrid):
        step = config.uniform_spacing()
        return (bounds[0] * step / metres_per_um, bounds[1] * step / metres_per_um)

    axis_edges = grid.edges(axis)
    origin = float(axis_edges[0])
    start_um, stop_um = (float(axis_edges[idx] - origin) / metres_per_um for idx in (bounds[0], bounds[1]))
    return (start_um, stop_um)
[docs]
def plot_setup_from_side(
    config: SimulationConfig,
    objects: ObjectContainer,
    viewing_side: Literal["x", "y", "z"],
    exclude_object_list: list[SimulationObject] | None = None,
    filename: str | Path | None = None,
    ax: Any | None = None,
    plot_legend: bool = True,
    exclude_xy_plane_object_list: list[SimulationObject] | None = None,
    exclude_yz_plane_object_list: list[SimulationObject] | None = None,
    exclude_xz_plane_object_list: list[SimulationObject] | None = None,
    exclude_large_object_ratio: float | None = None,
    auto_exclude_full_coverage: bool = True,
) -> Figure:
    """Creates a visualization of the simulation setup from a single viewing side.

    Generates a single subplot showing a cross-section of the simulation volume and the objects
    within it from the specified viewing side. Objects are drawn as semi-transparent colored
    rectangles (dashed outline for Bloch/PEC/PMC boundaries) with an optional legend. Objects
    whose ``color`` is None are not drawn.

    Args:
        config (SimulationConfig): Configuration used to convert grid indices to micrometres
        objects (ObjectContainer): Container holding all simulation objects to be plotted
        viewing_side (Literal['x', 'y', 'z']): Which plane to view ('x' for YZ, 'y' for XZ, 'z' for XY)
        exclude_object_list (list[SimulationObject] | None, optional): Objects to exclude from all plots
        filename (str | Path | None, optional): If provided, the current figure is saved to this path
            (300 dpi, tight bounding box) and then closed
        ax (Any | None, optional): Matplotlib axis to plot on. If None, a new 5x5 inch figure is created
        plot_legend (bool, optional): Whether to add a legend showing object names/types
        exclude_xy_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from XY plane plot
        exclude_yz_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from YZ plane plot
        exclude_xz_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from XZ plane plot
        exclude_large_object_ratio (float | None, optional): If provided, excludes objects whose in-plane
            area is more than this fraction of the plane area (e.g. 0.5 excludes objects covering more
            than half of the image). Ignored when the plane has zero area.
        auto_exclude_full_coverage (bool, optional): Automatically exclude objects that cover
            (essentially) 100% of the viewing plane

    Returns:
        Figure: The newly created figure when ``ax`` is None; otherwise the current pyplot figure,
        captured before any close triggered by ``filename``.

    Raises:
        ValueError: If ``viewing_side`` is not 'x', 'y' or 'z'.

    Note:
        The plots show object positions in micrometers, converting from simulation units.
        Boundary objects (PML, Bloch, PEC, PMC) are excluded from the plane perpendicular to their
        axis, and the simulation volume itself is never drawn. The caller's exclude lists are not
        modified.
    """
    # Work on private copies: boundaries and the volume are appended below, and those additions
    # must not leak into the caller's lists (which may be reused across several calls).
    exclude_object_list = list(exclude_object_list or [])
    exclude_xy_plane_object_list = list(exclude_xy_plane_object_list or [])
    exclude_yz_plane_object_list = list(exclude_yz_plane_object_list or [])
    exclude_xz_plane_object_list = list(exclude_xz_plane_object_list or [])

    # A boundary along axis k is hidden in the plane perpendicular to k.
    boundary_types = (PerfectlyMatchedLayer, BlochBoundary, PerfectElectricConductor, PerfectMagneticConductor)
    for o in objects.objects:
        if not isinstance(o, boundary_types):
            continue
        if o.axis == 0:
            exclude_yz_plane_object_list.append(o)
        elif o.axis == 1:
            exclude_xz_plane_object_list.append(o)
        elif o.axis == 2:
            exclude_xy_plane_object_list.append(o)

    # The simulation volume spans the whole plot, so it is never drawn.
    volume = objects.volume
    exclude_object_list.append(volume)
    object_list = [o for o in objects.objects if o not in exclude_object_list]

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    else:
        fig = None

    # Select the plane-specific exclude list, in-plane axes and plane extent.
    if viewing_side == "z":
        plane_exclude_list = list(exclude_xy_plane_object_list)
        axis_indices = (0, 1)  # X, Y
        axis_labels = ("x (µm)", "y (µm)")
        title = "XY plane"
        plane_size = (volume.grid_shape[0], volume.grid_shape[1])
    elif viewing_side == "y":
        plane_exclude_list = list(exclude_xz_plane_object_list)
        axis_indices = (0, 2)  # X, Z
        axis_labels = ("x (µm)", "z (µm)")
        title = "XZ plane"
        plane_size = (volume.grid_shape[0], volume.grid_shape[2])
    elif viewing_side == "x":
        plane_exclude_list = list(exclude_yz_plane_object_list)
        axis_indices = (1, 2)  # Y, Z
        axis_labels = ("y (µm)", "z (µm)")
        title = "YZ plane"
        plane_size = (volume.grid_shape[1], volume.grid_shape[2])
    else:
        raise ValueError(f"Invalid viewing_side: {viewing_side}. Must be 'x', 'y', or 'z'")

    if auto_exclude_full_coverage:
        plane_exclude_list.extend(_get_full_coverage_objects(object_list, axis_indices, plane_size, volume))

    colored_objects: list[SimulationObject] = [
        o for o in object_list if o.color is not None and o not in plane_exclude_list
    ]

    if exclude_large_object_ratio is not None:
        total_area = plane_size[0] * plane_size[1]
        # A zero-area plane has no meaningful coverage ratio; skip the filter instead of
        # dividing by zero (mirrors the guard in _get_full_coverage_objects).
        if total_area > 0:
            filtered_objects = []
            for obj in colored_objects:
                slices = obj.grid_slice_tuple
                obj_area = (slices[axis_indices[0]][1] - slices[axis_indices[0]][0]) * (
                    slices[axis_indices[1]][1] - slices[axis_indices[1]][0]
                )
                if obj_area / total_area <= exclude_large_object_ratio:
                    filtered_objects.append(obj)
            colored_objects = filtered_objects

    if plot_legend:
        handles = []
        seen_class_names: list[str] = []
        for o in colored_objects:
            has_custom_name = not o.name.startswith("Object")
            # One legend entry per object when it has a custom name or when objects of its class
            # appear in different colors; otherwise one shared entry per class.
            print_single = has_custom_name or any(
                o.__class__ == o2.__class__ and o.color != o2.color for o2 in colored_objects
            )
            label = o.name if has_custom_name else o.__class__.__name__
            color_val = o.color.to_mpl() if o.color is not None else "gray"
            patch = Patch(color=color_val, label=label)
            if print_single:
                handles.append(patch)
            elif o.__class__.__name__ not in seen_class_names:
                seen_class_names.append(o.__class__.__name__)
                handles.append(patch)
        ax.legend(
            handles=handles,
            loc="upper right",
            bbox_to_anchor=(1.75, 0.75),
            frameon=False,
        )

    dashed_types = (BlochBoundary, PerfectElectricConductor, PerfectMagneticConductor)
    for obj in colored_objects:
        slices = obj.grid_slice_tuple
        x0, x1 = _axis_edges_um(config, axis_indices[0], slices[axis_indices[0]])
        y0, y1 = _axis_edges_um(config, axis_indices[1], slices[axis_indices[1]])
        color = obj.color
        ax.add_patch(
            Rectangle(
                (x0, y0),
                x1 - x0,
                y1 - y0,
                color=color.to_mpl() if color is not None else "gray",
                alpha=0.5,
                linestyle="--" if isinstance(obj, dashed_types) else "-",
            )
        )

    ax.set_xlabel(axis_labels[0])
    ax.set_ylabel(axis_labels[1])
    ax.set_title(title)
    ax.set_xlim(_axis_edges_um(config, axis_indices[0], (0, plane_size[0])))
    ax.set_ylim(_axis_edges_um(config, axis_indices[1], (0, plane_size[1])))
    ax.set_aspect("equal")
    ax.grid(True)

    # Capture the figure BEFORE any close: calling plt.gcf() after plt.close() would create
    # (and leak) a brand-new empty figure instead of returning the one that was drawn on.
    if fig is None:
        fig = plt.gcf()
    if filename is not None:
        plt.savefig(filename, bbox_inches="tight", dpi=300)
        plt.close()
    return fig
[docs]
def plot_setup(
    config: SimulationConfig,
    objects: ObjectContainer,
    exclude_object_list: list[SimulationObject] | None = None,
    filename: str | Path | None = None,
    axs: Any | None = None,
    plot_legend: bool = True,
    exclude_xy_plane_object_list: list[SimulationObject] | None = None,
    exclude_yz_plane_object_list: list[SimulationObject] | None = None,
    exclude_xz_plane_object_list: list[SimulationObject] | None = None,
    exclude_large_object_ratio: float | None = None,
    auto_exclude_full_coverage: bool = True,
) -> Figure:
    """Creates a visualization of the simulation setup showing objects in XY, XZ and YZ planes.

    Generates three subplots showing cross-sections of the simulation volume and the objects
    within it. Objects are drawn as colored rectangles with optional legends. The visualization
    helps verify the correct positioning and sizing of objects in the simulation setup.

    Args:
        config (SimulationConfig): Configuration used to convert grid indices to micrometres
        objects (ObjectContainer): Container holding all simulation objects to be plotted
        exclude_object_list (list[SimulationObject] | None, optional): Objects to exclude from all plots
        filename (str | Path | None, optional): If provided, the current figure is saved to this path
            (300 dpi, tight bounding box) and then closed
        axs (Any | None, optional): Sequence of (at least) three matplotlib axes to plot on, used in the
            order XY, XZ, YZ. If None, a new 1x3 figure is created
        plot_legend (bool, optional): Whether to add a legend (drawn on the YZ panel only)
        exclude_xy_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from XY plane plot
        exclude_yz_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from YZ plane plot
        exclude_xz_plane_object_list (list[SimulationObject] | None, optional): Objects to exclude from XZ plane plot
        exclude_large_object_ratio (float | None, optional): If provided, excludes objects whose in-plane
            area is more than this fraction of the plane area (e.g. 0.5 excludes objects covering more
            than half of the image)
        auto_exclude_full_coverage (bool, optional): Automatically exclude objects that cover
            (essentially) 100% of the viewing plane

    Returns:
        Figure: The newly created figure when ``axs`` is None; otherwise the current pyplot figure,
        captured before any close triggered by ``filename``.

    Note:
        The plots show object positions in micrometers, converting from simulation units.
        Boundary objects are automatically excluded from their respective boundary planes.
        Objects covering 100% of a viewing plane are automatically excluded by default.
        Each panel receives its own copy of the exclude lists, so the caller's lists are not modified.
    """
    if axs is None:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    else:
        fig = None

    def _copy(items: list[SimulationObject] | None) -> list[SimulationObject] | None:
        # Fresh list per panel so per-panel additions never reach the caller or the next panel.
        return None if items is None else list(items)

    # (viewing side, legend flag) for XY, XZ and YZ; only the last panel carries the legend
    # so it is drawn once.
    panels = (("z", False), ("y", False), ("x", plot_legend))
    for idx, (side, legend) in enumerate(panels):
        plot_setup_from_side(
            config=config,
            objects=objects,
            viewing_side=side,
            exclude_object_list=_copy(exclude_object_list),
            filename=None,
            ax=axs[idx],
            plot_legend=legend,
            exclude_xy_plane_object_list=_copy(exclude_xy_plane_object_list),
            exclude_yz_plane_object_list=_copy(exclude_yz_plane_object_list),
            exclude_xz_plane_object_list=_copy(exclude_xz_plane_object_list),
            exclude_large_object_ratio=exclude_large_object_ratio,
            auto_exclude_full_coverage=auto_exclude_full_coverage,
        )

    # Capture the figure BEFORE any close: calling plt.gcf() after plt.close() would create
    # (and leak) a brand-new empty figure instead of returning the one that was drawn on.
    if fig is None:
        fig = plt.gcf()
    if filename is not None:
        plt.savefig(filename, bbox_inches="tight", dpi=300)
        plt.close()
    return fig