Constructing a PET Lightning training wrapper can be extremely slow, even when the underlying pipeline is fast to sample.
In my case:
```python
sample = next(iter(pipeline_i))
print(type(sample), sample.shape, sample.dtype)
```
returns roughly:
```
<class 'numpy.ndarray'> (1, 2, 300, 360) float32
```
in about 0.18 s.
However, constructing:
```python
trainer = pyearthtools.training.lightning.Train(
    lightning_model,
    datamodule,
    path="/g/data/v46/txs156/OM2-emulator/data/",
    trainer_kwargs={
        "max_epochs": 10,
        "num_sanity_val_steps": 1,
    },
)
```
takes about 5 minutes before `trainer.fit()` is even called.
This appears to be caused by eager datamodule saving during `Train(...)` construction. My datamodule contains a PET pipeline with a normalisation transform that holds large in-memory `xarray.Dataset` objects:
```python
pipelines_normed = petpipe.operations.xarray.normalisation.Evaluated(
    normalisation_eval="(sample - mean) / deviation",
    unnormalisation_eval="(sample * deviation) + mean",
    mean=mean_stats_ds,
    deviation=std_stats_ds,
)
```
Here `mean` and `deviation` are ~188 MB datasets with dimensions:

- `time`: 228
- `latitude`: 300
- `longitude`: 360

and variables such as `ocean_heat_content_2d` and `total_surface_heat_flx`.
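As a sanity check on the ~188 MB figure, the element count follows directly from the dimensions. The dtype of the statistics is not stated, so the sketch below shows both float32 and float64. For a real `xarray.Dataset`, the `nbytes` property reports the actual in-memory size, for example `mean_stats_ds.nbytes`.

```python
# Back-of-the-envelope size of one normalisation variable with the
# dimensions reported above (time=228, latitude=300, longitude=360).
# The dtype is an assumption, so both common float types are shown.
from __future__ import annotations

DIMS = {"time": 228, "latitude": 300, "longitude": 360}
BYTES_PER_ITEM = {"float32": 4, "float64": 8}


def n_elements(dims: dict[str, int]) -> int:
    """Number of array elements for the given dimension sizes."""
    total = 1
    for size in dims.values():
        total *= size
    return total


def size_mib(dims: dict[str, int], dtype: str) -> float:
    """In-memory size of one variable in MiB (2**20 bytes)."""
    return n_elements(dims) * BYTES_PER_ITEM[dtype] / 2**20


if __name__ == "__main__":
    print(f"elements per variable: {n_elements(DIMS):,}")
    for dtype in BYTES_PER_ITEM:
        print(f"{dtype}: {size_mib(DIMS, dtype):.1f} MiB per variable")
```

Each variable has 24,624,000 elements. A single float64 variable therefore comes to about 187.9 MiB, which matches the ~188 MB reported. Two float32 variables would give the same total.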
Because PET records init arguments and then YAML-dumps the full datamodule object graph, those large xarray objects end up being serialised during `Train(...)` construction.
As a result, wrapper construction time scales with the size of the embedded runtime objects rather than with per-sample pipeline performance. In this toy case, that cost is about 5 minutes.
This significantly slows experimentation and makes PET wrapper startup much slower than raw Lightning usage, even when the actual pipeline is performant.
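To confirm that the time really goes into serialisation rather than model or data setup, one could profile the constructor call with the standard-library `cProfile` module. The helper below is a minimal sketch. `profile_call` and `slow_dump` are hypothetical names, and the commented usage assumes the `Train(...)` call shown earlier. If the eager-saving explanation is right, YAML dumping and representer functions should dominate the cumulative-time report.

```python
# Minimal cProfile wrapper: run a callable, return its result plus a
# text report of the most expensive entries by cumulative time.
from __future__ import annotations

import cProfile
import io
import pstats
from typing import Any, Callable


def profile_call(fn: Callable[..., Any], *args: Any, top: int = 15, **kwargs: Any) -> tuple[Any, str]:
    """Run fn(*args, **kwargs) under cProfile; return (result, report)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    buffer = io.StringIO()
    pstats.Stats(profiler, stream=buffer).sort_stats("cumulative").print_stats(top)
    return result, buffer.getvalue()


def slow_dump(n: int) -> str:
    """Hypothetical stand-in for serialising n values to text."""
    return ",".join(repr(float(i)) for i in range(n))


if __name__ == "__main__":
    _, report = profile_call(slow_dump, 200_000)
    print(report)

    # Hypothetical usage with the objects from this issue:
    # trainer, report = profile_call(
    #     pyearthtools.training.lightning.Train,
    #     lightning_model,
    #     datamodule,
    #     path="/g/data/v46/txs156/OM2-emulator/data/",
    #     trainer_kwargs={"max_epochs": 10, "num_sanity_val_steps": 1},
    # )
    # print(report)
```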
In my view, constructing `pyearthtools.training.lightning.Train(...)` should be fast and should not eagerly serialise large runtime state unless explicitly requested.
I would suggest one of the following fixes:
- Add an option such as `datamodule_save_mode="yaml" | "manifest" | "snapshot" | "none"`,
- Make the default behavior lightweight: save a small manifest/config rather than full YAML dumping of live objects, or
- Defer datamodule saving until explicitly requested.
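The sketch below shows one way these suggestions could fit together. All names in it (`DeferredDatamoduleSaver`, `describe`, `full_dumper`) are hypothetical and do not exist in PET. The idea is to record init arguments by reference at construction time and write nothing until `save()` is called. In `"manifest"` mode, large objects are reduced to their type and `nbytes`. Full `"yaml"`/`"snapshot"` dumps are delegated to a caller-supplied function, because PET's own serialiser is not modelled here. JSON stands in for the manifest format to keep the sketch standard-library only.

```python
# Hypothetical sketch: cheap construction, opt-in saving, lightweight manifest.
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Callable, Literal

SaveMode = Literal["yaml", "manifest", "snapshot", "none"]


def describe(value: Any, max_inline: int = 1_000) -> Any:
    """Return a small JSON-friendly description of an init argument.

    Scalars and short strings are kept; containers are described
    recursively; any other object (e.g. an xarray.Dataset) is reduced
    to its type name plus `nbytes` when the object exposes an int one.
    """
    if value is None or isinstance(value, (bool, int, float)):
        return value
    if isinstance(value, str):
        return value if len(value) <= max_inline else f"<str len={len(value)}>"
    if isinstance(value, dict):
        return {str(k): describe(v, max_inline) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        if len(value) > max_inline:
            return f"<{type(value).__name__} len={len(value)}>"
        return [describe(v, max_inline) for v in value]
    summary: dict[str, Any] = {"type": f"{type(value).__module__}.{type(value).__qualname__}"}
    nbytes = getattr(value, "nbytes", None)
    if isinstance(nbytes, int):
        summary["nbytes"] = nbytes
    return summary


class DeferredDatamoduleSaver:
    """Records init args cheaply and saves only when save() is called."""

    def __init__(
        self,
        init_args: dict[str, Any],
        save_mode: SaveMode = "manifest",
        full_dumper: Callable[[dict[str, Any], Path], None] | None = None,
    ) -> None:
        if save_mode not in ("yaml", "manifest", "snapshot", "none"):
            raise ValueError(f"unknown save_mode: {save_mode!r}")
        # Only a reference is stored; nothing is serialised at construction.
        self._init_args = init_args
        self.save_mode = save_mode
        self._full_dumper = full_dumper

    def save(self, directory: Path) -> Path | None:
        """Write the datamodule description according to save_mode."""
        if self.save_mode == "none":
            return None
        directory.mkdir(parents=True, exist_ok=True)
        if self.save_mode == "manifest":
            target = directory / "datamodule_manifest.json"
            target.write_text(json.dumps(describe(self._init_args), indent=2))
            return target
        # "yaml" / "snapshot": delegate the expensive full dump to the caller.
        if self._full_dumper is None:
            raise ValueError(f"save_mode={self.save_mode!r} needs a full_dumper")
        target = directory / f"datamodule.{self.save_mode}"
        self._full_dumper(self._init_args, target)
        return target
```

With this design, constructing the wrapper costs only a dictionary reference. `save()` could then be called explicitly, or at the start of `fit`. The manifest records that `mean` is a large dataset of a given size without writing its values.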