lamindb.integrations.lightning.Callback

class lamindb.integrations.lightning.Callback(path, key, features=None)

Bases: lightning.pytorch.callbacks.Callback

Saves PyTorch Lightning model checkpoints to the LaminDB instance after each training epoch.

Creates a version family of artifacts for the given key (a relative file path), so every checkpoint saved under that key becomes a new version of the same artifact.
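The idea of a version family can be illustrated with a toy, in-memory model in plain Python. This is not LaminDB's API: ToyRegistry, ToyArtifact, save, versions, and latest are hypothetical names. The sketch only assumes the behavior described above: saving under a key that already exists adds a new version to that key's family, and a different key starts a separate family.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyArtifact:
    """Hypothetical stand-in for a stored checkpoint artifact."""
    key: str
    version: int
    content: bytes


@dataclass
class ToyRegistry:
    """Hypothetical in-memory registry holding one version family per key."""
    families: dict[str, list[ToyArtifact]] = field(default_factory=dict)

    def save(self, key: str, content: bytes) -> ToyArtifact:
        # Saving under an existing key appends a new version to that key's family.
        family = self.families.setdefault(key, [])
        artifact = ToyArtifact(key=key, version=len(family) + 1, content=content)
        family.append(artifact)
        return artifact

    def versions(self, key: str) -> list[int]:
        return [a.version for a in self.families.get(key, [])]

    def latest(self, key: str) -> ToyArtifact:
        return self.families[key][-1]


registry = ToyRegistry()
# Three "epochs" saved under one key form a single family with three versions.
for epoch in range(3):
    registry.save("checkpoints/model.ckpt", f"weights after epoch {epoch}".encode())
# A different key starts its own family.
registry.save("checkpoints/other.ckpt", b"separate family")

print(registry.versions("checkpoints/model.ckpt"))  # [1, 2, 3]
print(registry.latest("checkpoints/model.ckpt").content)  # b'weights after epoch 2'
```

Keying families on the relative file path means a training script can keep reusing one key and still retain every earlier checkpoint.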

See also: MLflow and Weights & Biases.

Parameters:
  • path (str | Path) – A local path to the checkpoint.

  • key (str) – The key for the checkpoint artifact.

  • features (dict[str, Any] | None, default: None) – Feature names and values with which to annotate the checkpoint.

Examples

Create a callback that saves checkpoints as artifacts and annotates them with the MLflow run ID:

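import lightning as pl
import mlflow
from lamindb.integrations import lightning as ll

# Placeholder values: substitute your own checkpoint path and artifact key.
checkpoint_filename = "model.ckpt"
artifact_key = "checkpoints/model.ckpt"

# The "mlflow_run_id" feature must already exist in the LaminDB instance.
with mlflow.start_run() as mlflow_run:
    lamindb_callback = ll.Callback(
        path=checkpoint_filename, key=artifact_key, features={"mlflow_run_id": mlflow_run.info.run_id}
    )
    trainer = pl.Trainer(callbacks=[lamindb_callback])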

Methods

on_train_start(trainer, pl_module)

Validates that every feature named in features already exists in the LaminDB instance.

Return type:

None
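The validation in on_train_start works as a fail-fast guard: it compares the names in features against the features already defined in the instance and stops before any training time is spent. The sketch below models the instance's registered features as a plain Python set. The helpers find_missing_features and validate_features, and the error text, are hypothetical illustrations, not LaminDB's implementation.

```python
from __future__ import annotations


def find_missing_features(features: dict, registered: set[str]) -> list[str]:
    """Return the names in `features` that are not registered (hypothetical helper)."""
    return [name for name in features if name not in registered]


def validate_features(features: dict | None, registered: set[str]) -> None:
    """Raise before training starts if any requested feature is not registered."""
    missing = find_missing_features(features or {}, registered)
    if missing:
        raise ValueError(
            f"Missing feature(s): {', '.join(missing)}. Create them before training."
        )


# Hypothetical: the features already defined in the instance, modeled as a set.
registered = {"mlflow_run_id"}

validate_features({"mlflow_run_id": "run-123"}, registered)  # passes silently

try:
    validate_features({"mlflow_run_id": "run-123", "val_loss": 0.25}, registered)
except ValueError as err:
    error_message = str(err)

print(error_message)  # Missing feature(s): val_loss. Create them before training.
```

Checking once at the start of training, rather than at every epoch end, means a typo in a feature name surfaces immediately instead of after the first epoch has finished.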

on_train_epoch_end(trainer, pl_module)

Saves the model checkpoint as an artifact at the end of each epoch and, when features are given, annotates it.

Return type:

None
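The two hooks rely on the order in which Lightning calls them: on_train_start runs once before the first epoch, and on_train_epoch_end runs after every epoch. The sketch below simulates that flow without Lightning or LaminDB installed, assuming each epoch overwrites the same local file at path and records a snapshot under key. ToyCheckpointCallback and ToyTrainer are hypothetical, and the in-memory saved list stands in for artifact versions. It is not the callback's actual implementation.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


class ToyCheckpointCallback:
    """Hypothetical stand-in that mirrors the two documented hooks."""

    def __init__(self, path: str | Path, key: str, features: dict | None = None):
        self.path = Path(path)
        self.key = key
        self.features = features or {}
        self.saved: list[dict] = []  # one record per epoch, in place of real artifacts

    def on_train_start(self, trainer, pl_module) -> None:
        # Feature validation would happen here, once, before the first epoch.
        pass

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Overwrite the same local checkpoint file every epoch...
        self.path.write_text(f"state at epoch {trainer.current_epoch}")
        # ...then record a snapshot, standing in for a new version under self.key.
        record = {"key": self.key, "content": self.path.read_text()}
        if self.features:  # annotate only when features were given
            record["features"] = dict(self.features)
        self.saved.append(record)


class ToyTrainer:
    """Hypothetical trainer that calls the hooks in Lightning's order."""

    def __init__(self, callbacks, max_epochs: int):
        self.callbacks = callbacks
        self.max_epochs = max_epochs
        self.current_epoch = 0

    def fit(self, model=None) -> None:
        for cb in self.callbacks:
            cb.on_train_start(self, model)  # once, before the first epoch
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            # A real trainer would run the training batches for this epoch here.
            for cb in self.callbacks:
                cb.on_train_epoch_end(self, model)  # after every epoch


with tempfile.TemporaryDirectory() as tmp:
    callback = ToyCheckpointCallback(
        path=Path(tmp) / "model.ckpt",
        key="checkpoints/model.ckpt",
        features={"mlflow_run_id": "run-123"},
    )
    ToyTrainer(callbacks=[callback], max_epochs=3).fit()

print(len(callback.saved))  # 3
print(callback.saved[-1]["content"])  # state at epoch 2
```

Because the local file is overwritten each epoch, only the latest checkpoint sits on disk at path, while the per-epoch history lives in the version family.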