Slow PET Lightning training wrapper build due to eager YAML save #265

@taimoorsohail

Description

Constructing a PET Lightning training wrapper can be extremely slow, even when the underlying pipeline is fast to sample.

In my case, fetching a single sample:

```python
sample = next(iter(pipeline_i))
print(type(sample), sample.shape, sample.dtype)
```

prints roughly the following, in about 0.18 s:

```text
<class 'numpy.ndarray'> (1, 2, 300, 360) float32
```

However, constructing the training wrapper:

```python
import pyearthtools.training.lightning

trainer = pyearthtools.training.lightning.Train(
    lightning_model,
    datamodule,
    path="/g/data/v46/txs156/OM2-emulator/data/",
    trainer_kwargs={
        "max_epochs": 10,
        "num_sanity_val_steps": 1,
    },
)
```

takes about 5 minutes, before trainer.fit() is even called.
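Both measurements can be reproduced in one script with a small standard-library timing helper. This is a minimal sketch: `timed` is a hypothetical helper written for illustration (it is not part of PET), and the commented usage assumes the `pipeline_i`, `lightning_model` and `datamodule` objects from the snippets above.

```python
import time
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def timed(label: str, results: dict[str, float]) -> Iterator[None]:
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start
        print(f"{label}: {results[label]:.2f} s")


# Hypothetical usage with the objects from this issue:
# timings: dict[str, float] = {}
# with timed("first sample", timings):
#     sample = next(iter(pipeline_i))
# with timed("Train(...) construction", timings):
#     trainer = pyearthtools.training.lightning.Train(lightning_model, datamodule, path=...)
```

Timing the two steps side by side separates per-sample pipeline cost from wrapper construction cost.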

The delay appears to be caused by the datamodule being saved eagerly during Train(...) construction. My datamodule contains a PET pipeline whose normalisation transform holds large in-memory xarray.Dataset objects:

```python
# mean_stats_ds / std_stats_ds: large in-memory xarray.Dataset statistics
pipelines_normed = petpipe.operations.xarray.normalisation.Evaluated(
    normalisation_eval="(sample - mean) / deviation",
    unnormalisation_eval="(sample * deviation) + mean",
    mean=mean_stats_ds,
    deviation=std_stats_ds,
)
```

where mean_stats_ds and std_stats_ds are xarray Datasets (~188 MB each) with dimensions

time: 228
latitude: 300
longitude: 360

and variables such as:

ocean_heat_content_2d
total_surface_heat_flx
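As a rough check on the ~188 MB figure: each variable holds 228 × 300 × 360 = 24,624,000 values. Assuming two float32 variables per dataset (the dtype of the statistics is not stated here, so this is an assumption; a single float64 variable gives the same total), that comes to 196,992,000 bytes, or about 188 MiB per dataset:

```python
import math


def dataset_nbytes(dims: dict[str, int], n_variables: int, itemsize: int) -> int:
    """Bytes needed for n_variables dense arrays that share the same dimensions."""
    return math.prod(dims.values()) * n_variables * itemsize


dims = {"time": 228, "latitude": 300, "longitude": 360}
# Assumption: two float32 variables (e.g. ocean_heat_content_2d, total_surface_heat_flx).
nbytes = dataset_nbytes(dims, n_variables=2, itemsize=4)
print(nbytes, f"{nbytes / 2**20:.1f} MiB")  # 196992000 187.9 MiB
```

Under that assumption, mean and deviation together hold about 98.5 million values, which gives a sense of how much data an eager dump of the datamodule has to walk through.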

Because PET records init arguments and then YAML-dumps the full datamodule object graph, those large xarray objects end up being serialised during Train(...) construction.

This means wrapper construction time scales with the size of the runtime objects embedded in the pipeline, not with per-sample pipeline performance; in this toy case it takes ~5 minutes.
This significantly slows experimentation and makes PET wrapper startup much slower than plain Lightning usage, even when the pipeline itself is fast.
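The scaling can be illustrated without PET at all. The sketch below is hypothetical: `RecordedInit`, `Normalise`, `DataModule` and `to_plain` are stand-ins invented for illustration, and JSON replaces YAML so it runs on the standard library alone. It shows that when a config dump recursively expands recorded init arguments, the size (and therefore the cost) of the dump grows with the data embedded in those arguments.

```python
import json
from typing import Any


class RecordedInit:
    """Hypothetical base class that remembers the kwargs it was built with."""

    def __init__(self, **kwargs: Any) -> None:
        self.init_kwargs = kwargs


class Normalise(RecordedInit):
    pass


class DataModule(RecordedInit):
    pass


def to_plain(obj: Any) -> Any:
    """Recursively expand recorded objects into plain dicts, data included."""
    if isinstance(obj, RecordedInit):
        return {type(obj).__name__: {k: to_plain(v) for k, v in obj.init_kwargs.items()}}
    if isinstance(obj, (list, tuple)):
        return [to_plain(v) for v in obj]
    return obj


def dump_size(n_values: int) -> int:
    """Length of the eager dump for a datamodule whose statistics hold n_values numbers."""
    stats = [0.5] * n_values
    datamodule = DataModule(pipeline=Normalise(mean=stats, deviation=stats), batch_size=1)
    return len(json.dumps(to_plain(datamodule)))


small, large = dump_size(10), dump_size(100_000)
print(small, large)
```

The datamodule's structure is identical in both calls; only the embedded statistics change, yet the dump grows by several orders of magnitude.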

In my view, constructing pyearthtools.training.lightning.Train(...) should be fast, and it should not eagerly serialise large runtime state unless explicitly requested.

I would suggest one of the following fixes:

  • Add a datamodule_save_mode="yaml" | "manifest" | "snapshot" | "none" option, so users can choose how much is saved;
  • Make the default behaviour lightweight: save a small manifest/config rather than a full YAML dump of live objects; or
  • Defer datamodule saving until it is explicitly requested.
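A rough sketch of how the first two suggestions could fit together, assuming Train(...) accepted such an option: `SaveMode`, `describe` and `save_datamodule` are hypothetical names, not existing PET APIs. In "manifest" mode, large values are replaced by a short description (type and length) instead of being dumped; "none" skips saving entirely; "yaml" and "snapshot" fall through to whatever full serialisation PET already performs, passed in here as `full_dump`.

```python
from enum import Enum
from typing import Any, Callable, Optional


class SaveMode(str, Enum):
    YAML = "yaml"
    MANIFEST = "manifest"
    SNAPSHOT = "snapshot"
    NONE = "none"


def describe(value: Any, max_len: int = 16) -> Any:
    """Keep small scalars, short strings and short sequences; summarise anything large."""
    if value is None or isinstance(value, (bool, int, float)):
        return value
    if isinstance(value, str) and len(value) <= 80:
        return value
    if isinstance(value, (list, tuple)) and len(value) <= max_len:
        return [describe(v, max_len) for v in value]
    if isinstance(value, dict):
        return {k: describe(v, max_len) for k, v in value.items()}
    size = len(value) if hasattr(value, "__len__") else None
    return {"type": type(value).__name__, "len": size}


def save_datamodule(
    init_kwargs: dict[str, Any],
    mode: SaveMode,
    full_dump: Callable[[dict[str, Any]], str],
) -> Optional[Any]:
    """Return what would be written to disk for the chosen mode."""
    if mode is SaveMode.NONE:
        return None
    if mode is SaveMode.MANIFEST:
        return describe(init_kwargs)
    # YAML and SNAPSHOT: defer to the existing (potentially expensive) full serialisation.
    return full_dump(init_kwargs)
```

Deferring the save (the third suggestion) would then amount to calling something like save_datamodule only when the user asks for it, for example at fit time or from an explicit method, rather than inside the Train(...) constructor.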
