import atexit
import csv
import shutil
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any, Sequence, cast
import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
from loguru import logger
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from rich.console import Console
from rich.markup import escape
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
from rich.table import Table
from fdtdx.conversion.stl import export_stl as export_stl_fn
from fdtdx.core.misc import cast_floating_to_numpy, get_background_material_name
from fdtdx.core.plotting.device_permittivity_index_utils import device_matrix_index_figure
from fdtdx.fdtd.container import ObjectContainer, ParameterContainer
from fdtdx.materials import compute_ordered_names
from fdtdx.objects.detectors.detector import DetectorState
def init_working_directory(experiment_name: str, wd_name: str | None) -> Path:
"""Initialize working directory for experiment outputs.
Creates a timestamped directory structure for experiment outputs under outputs/nobackup/.
Uses current date/time unless a specific working directory name is provided.
Args:
experiment_name (str): Name of the experiment
wd_name (str | None): Optional specific name for the working directory. If None, uses timestamp.
Returns:
Path: Created working directory path
"""
cur_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%f")
day, daytime = cur_time.split("_")
new_cwd = Path().cwd() / "outputs" / "nobackup" / day / experiment_name / (daytime if wd_name is None else wd_name)
new_cwd.mkdir(parents=True)
return new_cwd
def _log_formatter(record: Any) -> str:
    """Build the format template used by the rich console log sink.

    The returned string is a *template*, not the final text: loguru formats it
    against the record afterwards. Dynamic text (the message and the source
    location) is therefore never pasted into the template, because a literal
    ``{`` or ``}`` in it would be treated as a format field. It is stored under
    ``record["extra"]`` instead and referenced from the template by field name.

    Args:
        record (Any): Log record mapping. Reads ``"level"``, ``"file"``, ``"line"``
            and ``"message"``; writes two helper entries into ``"extra"``.

    Returns:
        str: Template containing rich markup and format fields for the record.
    """
    color_map = {
        "TRACE": "dim blue",
        "DEBUG": "cyan",
        "INFO": "bold",
        "SUCCESS": "bold green",
        "WARNING": "yellow",
        "ERROR": "bold red",
        "CRITICAL": "bold white on red",
    }
    # Unknown (e.g. custom) levels fall back to cyan.
    lvl_color = color_map.get(record["level"].name, "cyan")

    extra = record["extra"]
    extra["_console_location"] = f"{record['file'].path}:{record['line']}"
    # Escape rich markup so brackets in the message are printed literally.
    extra["_console_message"] = escape(record["message"])

    # Only the level colour varies between returned templates.
    return (
        "[not bold green]{time:DD.MM.YYYY HH:mm:ss.SSS}[/not bold green] | "
        "{extra[_console_location]}"
        f" - [{lvl_color}]"
        "{extra[_console_message]}"
        f"[/{lvl_color}]"
    )
def snapshot_python_files(snapshot_dir: Path, save_source: bool = False, save_script: bool = True) -> None:
    """Archive the code behind a run into ``code.zip`` next to ``snapshot_dir``.

    Files are first collected inside ``snapshot_dir`` (created if missing). That
    directory is then zipped to ``snapshot_dir.parent / "code.zip"`` and removed.

    Args:
        snapshot_dir (Path): Staging directory for the copied files. It is deleted
            again once the archive has been written.
        save_source (bool, optional): If True, copy every ``*.py`` file found below
            the directory two levels above this module, keeping each path relative
            to that directory's parent. Defaults to False.
        save_script (bool, optional): If True, copy the file named by
            ``sys.argv[0]``. The copy is skipped when that entry is missing, empty
            or not an existing file (for example under ``python -c``).
            Defaults to True.
    """
    snapshot_dir.mkdir(parents=True, exist_ok=True)

    if save_source:
        root_dir = Path(__file__).parent.parent
        # Materialise the listing before copying so files written during the
        # copy can never be picked up by the directory walk.
        for python_file in list(root_dir.rglob("*.py")):
            destination = snapshot_dir / python_file.relative_to(root_dir.parent)
            destination.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(python_file, destination)

    if save_script:
        # argv[0] is not always a real script path; snapshotting is best-effort
        # and must not abort the experiment in that case.
        script_arg = sys.argv[0] if sys.argv else ""
        if script_arg and Path(script_arg).is_file():
            shutil.copy(script_arg, snapshot_dir / Path(script_arg).name)

    # Zip the staging directory, then drop it so only the archive remains.
    shutil.make_archive(str(snapshot_dir.parent / "code"), "zip", snapshot_dir)
    shutil.rmtree(snapshot_dir)
class Logger:
    """Logger for managing experiment outputs and visualization.

    Handles experiment logging, metrics tracking, and visualization of simulation results.
    Creates a working directory structure, initializes logging, and provides methods for
    saving figures, metrics, and device parameters.

    Args:
        experiment_name (str): Name of the experiment. This is the naming of the parent directory where the experiment
            will be saved.
        name (str | None, optional): Optional specific name for the working directory. If None, uses timestamp.
        save_source (bool, optional): Forwarded to ``snapshot_python_files``; whether to include the package
            sources in the code snapshot. Defaults to False.
        save_script (bool, optional): Forwarded to ``snapshot_python_files``; whether to include the running
            script in the code snapshot. Defaults to True.
    """

    def __init__(
        self, experiment_name: str, name: str | None = None, save_source: bool = False, save_script: bool = True
    ):
        sns.set_theme(context="paper", style="white", palette="colorblind")
        self.cwd = init_working_directory(experiment_name, wd_name=name)
        self.console = Console()
        # The progress display is entered here and stopped at interpreter exit.
        self.progress = Progress(
            SpinnerColumn(),
            *Progress.get_default_columns(),
            TimeElapsedColumn(),
            console=self.console,
        ).__enter__()
        atexit.register(self.progress.stop)

        # Replace every previously configured loguru sink with a console sink
        # and a log file inside the working directory.
        logger.remove()
        logger.add(
            self.console.print,
            level="TRACE",
            format=_log_formatter,
            colorize=True,
        )
        logger.add(
            self.cwd / "logs.log",
            level="TRACE",
            # Millisecond resolution, matching the console formatter.
            format="{time:DD.MM.YYYY HH:mm:ss.SSS} | {level} - {message}",
        )
        logger.info(f"Starting experiment {experiment_name} in {self.cwd}")
        snapshot_python_files(self.cwd / "code", save_source=save_source, save_script=save_script)

        # CSV header and writer are created lazily on the first call to write().
        self.fieldnames: list[str] | None = None
        self.writer: csv.DictWriter | None = None
        self.csvfile = open(self.cwd / "metrics.csv", "w", newline="")
        # Last index matrix per device name; None until a device is first logged.
        self.last_indices: dict[str, jax.Array | None] = defaultdict(lambda: None)
        atexit.register(self.csvfile.close)

    @property
    def stl_dir(self) -> Path:
        """Directory for storing STL files.

        The directory is created on access if it does not exist yet.

        Returns:
            Path: Directory for STL file outputs
        """
        directory = self.cwd / "device" / "stl"
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    @property
    def params_dir(self) -> Path:
        """Directory for storing parameter files.

        The directory is created on access if it does not exist yet.

        Returns:
            Path: Directory for parameter file outputs
        """
        directory = self.cwd / "device" / "params"
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    def savefig(self, directory: Path, filename: str, fig: Figure, dpi: int = 300):
        """Save a matplotlib figure to file.

        Creates a ``figures`` subdirectory if needed, saves the figure there with a tight
        bounding box and closes the figure afterwards.

        Args:
            directory (Path): Base directory to save in
            filename (str): Name for the figure file
            fig (Figure): Matplotlib figure to save
            dpi (int, optional): Resolution in dots per inch. Defaults to 300.
        """
        figure_directory = directory / "figures"
        figure_directory.mkdir(parents=True, exist_ok=True)
        fig.savefig(figure_directory / filename, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    def write(self, stats: dict, do_print: bool = True):
        """Write statistics to CSV file and optionally print them.

        Only ``int``/``float`` values and single-element jax arrays are recorded; every
        other entry of ``stats`` is dropped. The CSV header is fixed by the keys of the
        first call.

        Args:
            stats (dict): Dictionary of statistics to record
            do_print (bool, optional): Whether to print stats to console. Defaults to true.
        """
        stats = {
            k: v.item() if isinstance(v, jax.Array) else v
            for k, v in stats.items()
            if isinstance(v, (int, float)) or (isinstance(v, jax.Array) and v.size == 1)
        }
        if self.fieldnames is None:
            self.fieldnames = list(stats.keys())
            self.writer = csv.DictWriter(self.csvfile, fieldnames=self.fieldnames)
            self.writer.writeheader()
        assert self.writer is not None
        self.writer.writerow(stats)
        # Flush after every row so the metrics file stays current during a run.
        self.csvfile.flush()

        if do_print:
            table = Table(box=None)
            # Key and value are both added as column headers, which renders the
            # stats on a single "key value key value ..." line.
            for k, v in stats.items():
                table.add_column(k)
                table.add_column(str(v))
            self.console.print(table)

    def log_detectors(
        self,
        iter_idx: int,
        objects: ObjectContainer,
        detector_states: dict[str, DetectorState],
        exclude: Sequence[str] = (),
    ):
        """Log detector states and generate visualization plots.

        For every detector that is not excluded and has plotting enabled, draws its plots
        and stores the outputs under ``<cwd>/detectors/<detector name>``. Figures are saved
        as PNG; string outputs are treated as file paths and copied.

        Args:
            iter_idx (int): Current iteration index
            objects (ObjectContainer): Container with simulation objects
            detector_states (dict[str, DetectorState]): Dictionary mapping detector names to their states
            exclude (Sequence[str], optional): List of detector names to exclude from logging

        Raises:
            Exception: If a detector returns an output that is neither a figure nor a string.
        """
        for detector in objects.detectors:
            # Decide whether to skip before touching the state, so detectors that
            # are not plotted cause no device transfer and need no state entry.
            if detector.name in exclude or not detector.plot:
                continue

            cur_state = jax.device_get(detector_states[detector.name])
            cur_state = cast_floating_to_numpy(cur_state, float)
            figure_dict = detector.draw_plot(
                state=cur_state,
                progress=self.progress,
            )

            detector_dir = self.cwd / "detectors" / detector.name
            detector_dir.mkdir(parents=True, exist_ok=True)

            for k, v in figure_dict.items():
                if isinstance(v, Figure):
                    self.savefig(
                        detector_dir,
                        f"{detector.name}_{k}_{iter_idx}.png",
                        v,
                        dpi=detector.plot_dpi,  # type: ignore
                    )
                elif isinstance(v, str):
                    # A string output is a path to an already written file.
                    shutil.copy(
                        v,
                        detector_dir / f"{detector.name}_{k}_{iter_idx}{Path(v).suffix}",
                    )
                else:
                    raise Exception(f"invalid detector output for plotting: {k}, {v}")

    def log_params(
        self,
        iter_idx: int,
        params: ParameterContainer,
        objects: ObjectContainer,
        export_figure: bool = False,
        export_stl: bool = False,
        export_background_stl: bool = False,
        **transformation_kwargs,
    ) -> int:
        """Log parameter states and export device visualizations.

        Saves the raw parameters and the resulting index matrix of every device, and
        counts how many matrix entries differ from the previous call. Figures and STL
        files are exported only for devices seen for the first time or whose matrix
        changed.

        Args:
            iter_idx (int): Current iteration index
            params (ParameterContainer): Container with device parameters
            objects (ObjectContainer): Container with simulation objects
            export_figure (bool, optional): Whether to export index matrix figures
            export_stl (bool, optional): Whether to export device geometry as STL
            export_background_stl (bool, optional): Whether to also export the background material as STL
            **transformation_kwargs: keyword arguments passed to the parameter transformation

        Returns:
            int: Number of voxels that changed since last iteration, summed over all devices
        """
        changed_voxels = 0
        for device in objects.devices:
            device_params = params[device.name]
            indices = device(device_params, **transformation_kwargs)

            # raw parameters and indices
            if isinstance(device_params, dict):
                device_params_dict = cast(dict[str, jax.Array], device_params)
                for k, v in device_params_dict.items():
                    jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}_{k}.npy", v)
            else:
                jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}.npy", device_params)
            jnp.save(self.params_dir / f"matrix_{iter_idx}_{device.name}.npy", indices)

            # Compare against the matrix stored by the previous call, if any.
            previous = self.last_indices[device.name]
            cur_changed_voxels = 0
            if previous is not None:
                cur_changed_voxels = int(jnp.sum(indices != previous))
                changed_voxels += cur_changed_voxels
            self.last_indices[device.name] = indices

            # Nothing new to visualise for an unchanged device.
            if previous is not None and cur_changed_voxels == 0:
                continue

            if export_stl:
                background_name = get_background_material_name(device.materials)
                ordered_name_list = compute_ordered_names(device.materials)
                background_idx = ordered_name_list.index(background_name)
                # Round once; the result is reused for every material below.
                rounded_indices = np.round(indices)

                for idx, name in enumerate(ordered_name_list):
                    if idx == background_idx and not export_background_stl:
                        continue
                    export_stl_fn(
                        matrix=rounded_indices == idx,
                        stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_{name}.stl",
                        voxel_grid_size=device.single_voxel_grid_shape,
                    )

                # With more than two materials, also export the union of all
                # non-background voxels as one mesh.
                if len(device.materials) > 2:
                    export_stl_fn(
                        matrix=rounded_indices != background_idx,
                        stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_non_air.stl",
                        voxel_grid_size=device.single_voxel_grid_shape,
                    )

            # image of indices
            if export_figure:
                fig = device_matrix_index_figure(
                    device_matrix_indices=indices,
                    material=device.materials,
                    parameter_type=device.output_type,
                )
                self.savefig(
                    self.cwd / "device",
                    f"matrix_indices_{iter_idx}_{device.name}.png",
                    fig,
                )

        return changed_voxels