First Basic Simulation
Now that we have covered the basics of JAX, simulation materials, and the placement of objects in the previous tutorials, let’s run our first actual simulation. In this simulation, we will use a source to show the interaction of light with a cuboid object floating in free space. Of course, this is not very practical in real life, but it is a good starting point to show the features of FDTDX.
import fdtdx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import pytreeclass as tc
import time
from IPython.display import Video
%matplotlib inline
Setup of the simulation scene
Let’s start with a basic setup of a simulation scene. We need to specify a random key for possible stochastic operations. This simulation will be entirely deterministic, but we still need to specify the key. Then we specify a SimulationConfig object with some basic information on how long the simulation should run and how accurate it needs to be (resolution, dtype and courant factor).
# Create a JAX random key for reproducibility and stochastic operations
key = jax.random.PRNGKey(seed=42)
# Initialize a list of objects
object_list = []
# Define simulation configuration (duration, resolution, data type, etc.)
config = fdtdx.SimulationConfig(
    time=100e-15,
    resolution=100e-9,
    dtype=jnp.float32,
    courant_factor=0.99,
)
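To get a feeling for what these settings imply, the following sketch estimates the time step and the resulting number of steps. It assumes the standard Courant-Friedrichs-Lewy limit for a cubic 3D Yee cell in vacuum, dt = courant_factor · Δx / (c · √3). Whether FDTDX computes its time step in exactly this way is an assumption, and all variable names are illustrative.

```python
import math

# Values copied from the SimulationConfig above
sim_time = 100e-15       # total simulated time in seconds
resolution = 100e-9      # grid cell edge length in metres
courant_factor = 0.99    # fraction of the stability limit that is used

c0 = 299_792_458.0       # speed of light in vacuum (m/s)

# Assumed 3D Courant limit for a cubic Yee cell in vacuum
dt = courant_factor * resolution / (c0 * math.sqrt(3))
approx_steps = sim_time / dt

print(f"time step: {dt:.3e} s")
print(f"approximate number of steps: {approx_steps:.1f}")
```

Under this assumption, halving the resolution halves the time step as well, so the number of steps for the same simulated time doubles.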
Next, we specify the simulation volume. This includes the background material, which is used for all the space where we do not place objects in the following specifications.
volume = fdtdx.SimulationVolume(
    partial_real_shape=(12.0e-6, 12e-6, 12e-6),
    material=fdtdx.Material(  # Background material
        permittivity=1.0,
        permeability=1.0,
    )
)
object_list.append(volume)
As we have seen in the object placement tutorial, in FDTDX objects are placed through constraints. We create an empty list of these constraints first and then iteratively add more constraints to the list.
constraints = []
At first, we add the boundaries of our simulation to the constraints. We are using absorbing PML boundaries to prevent any reflections from the boundary of the simulation volume.
We could specify the boundary for each of the six sides of the simulation volume manually, but this would be tedious. Instead, we will use a handy shortcut provided by FDTDX. This creates PML boundaries on all six sides with the corresponding constraints. Here we use a thickness of 10 grid cells for the PML, which should be enough for most applications.
bound_cfg = fdtdx.BoundaryConfig.from_uniform_bound(thickness=10, boundary_type="pml")
bound_dict, c_list = fdtdx.boundary_objects_from_config(bound_cfg, volume)
object_list.extend(bound_dict.values())
constraints.extend(c_list)
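Because the PML thickness is given in grid cells, its physical size depends on the resolution. The short calculation below converts the values used in this tutorial; the variable names are illustrative and not part of the FDTDX API.

```python
resolution = 100e-9   # grid cell size from the config above (m)
volume_edge = 12e-6   # edge length of the simulation volume (m)
pml_cells = 10        # PML thickness in grid cells

# Physical PML thickness on each side of the volume
pml_thickness = pml_cells * resolution
# Edge length of the region that is not covered by PML
interior_edge = volume_edge - 2 * pml_thickness

print(f"PML thickness per side: {pml_thickness * 1e6:.1f} um")
print(f"interior edge length:   {interior_edge * 1e6:.1f} um")
```

With these numbers, each PML layer is 1 µm thick, leaving a 10 µm interior along every axis.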
Next, we create a light source. The source is placed at the top (z-axis) of the simulation volume and the propagation direction of the light is set downwards ("-"). The polarization is set to Ex-polarized light. Radius and standard deviation determine the spatial profile of the mode. A larger radius makes the emission area larger. A larger standard deviation “flattens” the Gaussian profile, making it more similar to a plane source. The radius and standard deviation should be set such that there is very little energy at the boundary of the source, because energy there can lead to artifacts.
source = fdtdx.GaussianPlaneSource(
    partial_grid_shape=(None, None, 1),
    partial_real_shape=(10e-6, 10e-6, None),
    fixed_E_polarization_vector=(1, 0, 0),
    wave_character=fdtdx.WaveCharacter(wavelength=1.550e-6),
    radius=4e-6,
    std=1 / 3,
    direction="-",
)
object_list.append(source)
constraints.extend(
    [
        source.place_relative_to(
            volume,
            axes=(0, 1, 2),
            own_positions=(0, 0, 1),
            other_positions=(0, 0, 1),
            margins=(0, 0, -1.5e-6),
        ),
    ]
)
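To illustrate why these values leave little energy at the edge of the source, here is a minimal sketch of a Gaussian amplitude profile. It assumes that std is expressed as a fraction of the radius, so that the effective standard deviation is std · radius. This normalization is an assumption made for illustration and may differ from how FDTDX implements the profile internally.

```python
import math

radius = 4e-6   # source radius from the example above (m)
std = 1 / 3     # assumed to be a fraction of the radius

def gaussian_amplitude(r: float) -> float:
    """Relative amplitude at distance r from the beam center (assumed model)."""
    sigma = std * radius
    return math.exp(-0.5 * (r / sigma) ** 2)

for frac in (0.0, 0.5, 1.0):
    amp = gaussian_amplitude(frac * radius)
    print(f"r = {frac:.1f} * radius -> relative amplitude {amp:.4f}")
```

Under this assumption, the edge of the emission area sits three standard deviations from the center, where the amplitude has dropped to roughly one percent of its peak.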
Next, we place a uniform cuboid at the center of the simulation volume. This will make the simulation a bit more interesting to look at, because otherwise we would only see the light emitted from the source.
cube = fdtdx.UniformMaterialObject(
    partial_real_shape=(3e-6, 3e-6, 3e-6),
    material=fdtdx.Material(permittivity=2.0),
    name="Cube",
    color=fdtdx.colors.PINK,
)
object_list.append(cube)
constraints.append(cube.place_at_center(volume))
In order to actually see a result from the simulation, we need to define a Detector. While the simulation function returns the E and H fields after running the simulation, it is usually also necessary to read some physical metrics at intermediate time steps. This is exactly what Detectors are for!
Here we use an EnergyDetector, which calculates the energy at every grid point within its volume. We also specify a switch, which controls the time steps that the detector records. Our purpose here is to generate a video of the energy during the simulation. We do not need every single time step for this, so we only record every third time step.
The as_slices option is a memory optimization specific to creating images or videos. With this option set to True, only the values that are actually plotted are saved instead of the whole simulation volume. If you need to read values from the whole volume, simply disable this option.
video_energy_detector = fdtdx.EnergyDetector(
    name="Video",
    as_slices=True,
    switch=fdtdx.OnOffSwitch(interval=3),
    exact_interpolation=True,
    num_video_workers=8,
)
object_list.append(video_energy_detector)
constraints.extend(video_energy_detector.same_position_and_size(volume))
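As a rough illustration of what recording every third time step means, the sketch below lists the recorded step indices. It assumes that recording starts at step 0 and that the simulation runs for about 525 steps, an estimate derived from the simulation time, resolution, and Courant factor. Both are assumptions, and recorded_steps is a hypothetical helper, not an FDTDX function.

```python
def recorded_steps(num_steps: int, interval: int) -> list[int]:
    """Step indices kept when every `interval`-th step is recorded (assumed to start at 0)."""
    return list(range(0, num_steps, interval))

num_steps = 525   # assumed total number of time steps (estimate)
steps = recorded_steps(num_steps, interval=3)

print(f"recorded frames: {len(steps)}")
print(f"first recorded steps: {steps[:5]}")
```

A larger interval shortens the video and reduces the memory needed for the detector state proportionally.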
These are all the objects we need for our simulation! Let’s resolve the constraints and plot the simulation scene to see if we made any mistakes. Note that it is good practice to split the random key before each use to maintain randomness in JAX (see the JAX documentation on random numbers for more details).
key, subkey = jax.random.split(key)
objects, arrays, params, config, _ = fdtdx.place_objects(
    object_list=object_list,
    config=config,
    constraints=constraints,
    key=subkey,
)
fig = fdtdx.plot_setup(
    config=config,
    objects=objects,
    exclude_object_list=[video_energy_detector],
)
plt.show()
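The key-splitting pattern used in this step can be tried in isolation. The following sketch uses only standard jax.random functions and is independent of FDTDX; the variable names are illustrative.

```python
import jax

key = jax.random.PRNGKey(42)

# Split off a subkey for each consumer and keep `key` for later splits
key, subkey_a = jax.random.split(key)
key, subkey_b = jax.random.split(key)

# Different subkeys give different random draws
draw_a = jax.random.normal(subkey_a, (3,))
draw_b = jax.random.normal(subkey_b, (3,))

# Reusing the same subkey reproduces exactly the same numbers
draw_a_again = jax.random.normal(subkey_a, (3,))

print(draw_a)
print(draw_b)
print(draw_a_again)
```

Because JAX keys are explicit values rather than hidden global state, passing the same key to two functions would make them draw identical random numbers, which is why each consumer should receive its own subkey.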
Additionally, we can plot some statistics about the expected memory usage of our simulation. Note that this only includes the arrays that we specify before the simulation starts, not intermediate computational results during the simulation.
In this small simulation, the main memory requirement comes from the PML boundaries. In larger simulations, the PML requirement is outweighed by the other items in the list.
print(tc.tree_summary(arrays, depth=1))
┌──────────────────────┬──────────────────┬──────────┬────────┐
│Name │Type │Count │Size │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.E │f32[3,120,120,120]│5,184,000 │19.78MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.H │f32[3,120,120,120]│5,184,000 │19.78MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.psi_E │f32[6,120,120,120]│10,368,000│39.55MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.psi_H │f32[6,120,120,120]│10,368,000│39.55MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.alpha │f32[6,120,120,120]│10,368,000│39.55MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.kappa │f32[6,120,120,120]│10,368,000│39.55MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.sigma │f32[6,120,120,120]│10,368,000│39.55MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.inv_permittivities │f32[120,120,120] │1,728,000 │6.59MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.inv_permeabilities │float │1 │ │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.detector_states │dict │7,560,000 │28.84MB │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.recording_state │NoneType │ │ │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.electric_conductivity│NoneType │ │ │
├──────────────────────┼──────────────────┼──────────┼────────┤
│.magnetic_conductivity│NoneType │ │ │
├──────────────────────┼──────────────────┼──────────┼────────┤
│Σ │ArrayContainer │71,496,001│272.74MB│
└──────────────────────┴──────────────────┴──────────┴────────┘
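The sizes in this table can be reproduced by hand from the array shapes. The sketch below does so for float32 arrays, using 1 MB = 1024² bytes. The factorization of the detector state into 175 frames with 3 slices of 120 × 120 values each is inferred from the element count in the table, not confirmed by the FDTDX documentation, and array_megabytes is a hypothetical helper.

```python
def array_megabytes(shape: tuple[int, ...], bytes_per_value: int = 4) -> float:
    """Size of a dense array in MB (1 MB = 1024**2 bytes); float32 uses 4 bytes."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_value / 1024**2

grid = (120, 120, 120)   # 12 um / 100 nm along each axis

print(f"E or H field:        {array_megabytes((3, *grid)):.2f} MB")
print(f"one PML array:       {array_megabytes((6, *grid)):.2f} MB")
print(f"inv_permittivities:  {array_megabytes(grid):.2f} MB")
# Inferred shape: 175 recorded frames, 3 slices, 120 x 120 values per slice
print(f"detector state:      {array_megabytes((175, 3, 120, 120)):.2f} MB")
```

Since the field and PML arrays scale with the cube of the grid size, halving the resolution increases their memory footprint by a factor of eight.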
Running the simulation
Now let’s define a function that actually runs the simulation. In FDTDX, this is a two-part process.
First, we call apply_params, which performs some calculations before the start of the simulation. If we have parametric objects in the simulation, this function applies the given parameters and calculates the actual shapes of these objects. Additionally, some performance optimizations are done here by calculating values once before the simulation starts.
Then, we call run_fdtd, which performs the FDTD simulation as a loop. The computation terminates as soon as the required number of time steps is reached.
def sim_fn(
    params: fdtdx.ParameterContainer,
    arrays: fdtdx.ArrayContainer,
    key: jax.Array,
):
    # Apply parameters to objects and arrays
    arrays, new_objects, _ = fdtdx.apply_params(arrays, objects, params, key)
    # Run FDTD simulation (forward)
    final_state = fdtdx.run_fdtd(
        arrays=arrays,
        objects=new_objects,
        config=config,
        key=key,
    )
    _, arrays = final_state
    return arrays
In order to execute this function, we should first compile it. JAX provides just-in-time compilation with jax.jit, which automatically compiles a function the first time it is called. We extend this a little here by calling .lower() and .compile() to compile the function immediately and measure the compilation time. If this seems complicated, you can omit .lower() and .compile() and everything will still work. The function is then compiled on its first call, so the timing below would no longer measure the compilation.
start_time = time.time()
jitted_loss = jax.jit(sim_fn).lower(params, arrays, key).compile()
end_time = time.time()
print(f"Compilation time: {end_time - start_time} seconds")
Compilation time: 0.7427017688751221 seconds
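The same ahead-of-time pattern can be tried on a small function that does not depend on FDTDX. In this sketch, toy_fn is a hypothetical stand-in for the simulation function; the .lower() and .compile() calls are the standard JAX ahead-of-time API.

```python
import time

import jax
import jax.numpy as jnp

def toy_fn(x):
    # Stand-in for a real simulation function
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.ones((1000,))

# Ahead-of-time: trace and compile now, specialized to inputs shaped like `x`
start = time.time()
compiled = jax.jit(toy_fn).lower(x).compile()
print(f"Compilation time: {time.time() - start:.4f} seconds")

# The compiled object is callable with inputs of the same shape and dtype
result = compiled(x)
print(result)
```

Note that the compiled object is specialized to the shapes and dtypes it was lowered with, so calling it with a differently shaped input raises an error instead of silently recompiling.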
Now we are ready to run the simulation. We can see that the simulation time is smaller than the compilation time, which can happen for small simulations. This might seem inefficient, but in practice a few seconds usually don’t matter. In addition, we can now call the compiled function as often as we like.
start_time = time.time()
new_arrays = jitted_loss(params, arrays, subkey)
end_time = time.time()
print(f"Simulation runtime: {end_time - start_time} seconds")
Simulation runtime: 0.0021200180053710938 seconds
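One caveat when timing JAX code: JAX dispatches computations asynchronously, so a call can return before the result has actually been computed, and a wall-clock measurement taken right after the call may understate the true runtime. The sketch below shows the usual remedy, blocking on the result before stopping the timer. Here toy_step is a hypothetical stand-in workload. For a pytree result such as new_arrays, jax.block_until_ready should serve the same purpose, although this has not been verified against FDTDX here.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def toy_step(x):
    # Stand-in workload; any jitted function is dispatched the same way
    return jnp.tanh(x @ x)

x = jnp.ones((500, 500))
toy_step(x).block_until_ready()   # warm-up call triggers compilation

start = time.time()
y = toy_step(x)                   # may return before the computation finishes
dispatch_time = time.time() - start

y.block_until_ready()             # wait until the result is actually available
total_time = time.time() - start

print(f"dispatch only: {dispatch_time:.6f} s")
print(f"including compute: {total_time:.6f} s")
```

On a CPU the two numbers are often close, while on accelerators the difference can be substantial.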
Visualizing the results of a simulation
Now we have run the simulation, but how do we visualize the results? Our goal was to generate a video of the simulation, so let’s do this.
The syntax for generating a video in a Jupyter notebook is currently a bit complicated, but for actual scripts FDTDX offers utility functions that make this easier (see the FDTDX example scripts that use the ExperimentLogger class). There are two reasons for the complexity. First, JAX does not allow in-place updates, which shapes the syntax. Second, the plot function saves the video to a temporary location. We can either access the video from there or move it to a more permanent location.
video_path = objects["Video"].draw_plot(new_arrays.detector_states["Video"])
print(video_path)
Video(list(video_path.values())[0], embed=True, width=720)
{'sliced_video': '/var/folders/9c/sf1_5nb17v51rk1l8xcxshdc0000gn/T/tmpq47gqngf.mp4'}
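Since the video is written to a temporary location, it can be copied elsewhere with the Python standard library. The following is a minimal sketch. save_videos is a hypothetical helper, not part of FDTDX, and the demo uses a dummy file in place of a real video so that it runs on its own.

```python
import shutil
import tempfile
from pathlib import Path

def save_videos(video_paths: dict[str, str], target_dir: str) -> dict[str, Path]:
    """Copy each temporary video file into target_dir, named after its dict key."""
    out_dir = Path(target_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    saved = {}
    for name, tmp_path in video_paths.items():
        src = Path(tmp_path)
        dst = out_dir / f"{name}{src.suffix}"
        shutil.copy2(src, dst)   # copy rather than move, so the temp file stays valid
        saved[name] = dst
    return saved

# Demo with a dummy file standing in for the temporary video
with tempfile.TemporaryDirectory() as tmp:
    dummy = Path(tmp) / "tmp_example.mp4"
    dummy.write_bytes(b"not a real video")
    saved = save_videos({"sliced_video": str(dummy)}, str(Path(tmp) / "results"))
    saved_name = saved["sliced_video"].name
    print(saved_name)
```

With the dictionary returned by draw_plot above, a call such as save_videos(video_path, "results") would place the file at results/sliced_video.mp4.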