Weights & Biases

We show how LaminDB can be integrated with W&B to track the training process and associate datasets & parameters with models.

# !pip install -q 'lamindb[jupyter,aws]' torch torchvision lightning wandb
!lamin init --storage ./lamin-mlops
!wandb login
 connected lamindb: anonymous/lamin-mlops
wandb: Currently logged in as: felix_lamin (lamin-mlops-demo). Use `wandb login --relogin` to force relogin
import lamindb as ln
import wandb

ln.context.uid = "tULn4Va2yERp0000"
ln.context.track()
 connected lamindb: anonymous/lamin-mlops
 created Transform('tULn4Va2'), started new Run('HS9qEfUY') at 2024-12-20 15:03:47 UTC
 notebook imports: lamindb==0.77.3 lightning==2.5.0 torch==2.5.1 torchvision==0.20.1 wandb==0.19.1

Define a model

Define a simple autoencoder as an example model using PyTorch Lightning.

from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning


class LitAutoEncoder(lightning.LightningModule):
    def __init__(self, hidden_size, bottleneck_size):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, bottleneck_size)
        )
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_size, hidden_size), 
            nn.ReLU(), 
            nn.Linear(hidden_size, 28 * 28)
        )
        # store the __init__ arguments in self.hparams; WandbLogger logs them automatically
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
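As a check on the architecture, the parameter counts that Lightning reports in the training summary can be reproduced by hand. An `nn.Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases, and `nn.ReLU` has no parameters. A minimal stdlib sketch (the helper names `linear_params` and `autoencoder_params` are hypothetical):

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer with bias=True."""
    return in_features * out_features + out_features


def autoencoder_params(hidden_size: int, bottleneck_size: int, n_pixels: int = 28 * 28):
    """Return (encoder, decoder) parameter counts for LitAutoEncoder's layer layout."""
    encoder = linear_params(n_pixels, hidden_size) + linear_params(hidden_size, bottleneck_size)
    decoder = linear_params(bottleneck_size, hidden_size) + linear_params(hidden_size, n_pixels)
    return encoder, decoder


encoder, decoder = autoencoder_params(hidden_size=32, bottleneck_size=16)
total = encoder + decoder
print(encoder, decoder, total)      # 25648 26416 52064
# float32 parameters take 4 bytes each; 1 MB = 1e6 bytes here
print(f"{total * 4 / 1e6:.3f} MB")  # 0.208 MB
```

These values correspond to the 25.6 K, 26.4 K, 52.1 K and 0.208 MB figures in the training summary further down.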

Query & download the MNIST dataset

We saved the MNIST dataset in the curation notebook, so it now shows up in the artifact registry:

ln.Artifact.filter(type="dataset").df()
id                1
uid               9hJz1vrzYAkOOxXx0000
key               testdata/mnist
description       None
suffix            ''
type              dataset
size              54950048
hash              amFx_vXqnUtJr0kmxxWK2Q
n_objects         4
n_observations    None
_hash_type        md5-d
_accessor         None
visibility        1
_key_is_virtual   True
storage_id        1
transform_id      1
version           None
is_latest         True
run_id            1
created_at        2024-12-20 15:03:38.601606+00:00
created_by_id     1

If your instance is connected to lamin.ai, you can also browse the artifact there.

Let’s get the dataset:

artifact = ln.Artifact.get(key="testdata/mnist")
artifact
Artifact(uid='9hJz1vrzYAkOOxXx0000', is_latest=True, key='testdata/mnist', suffix='', type='dataset', size=54950048, hash='amFx_vXqnUtJr0kmxxWK2Q', n_objects=4, _hash_type='md5-d', visibility=1, _key_is_virtual=True, storage_id=1, transform_id=1, run_id=1, created_by_id=1, created_at=2024-12-20 15:03:38 UTC)

And download it to a local cache:

path = artifact.cache()
path
PosixUPath('/home/runner/work/lamin-mlops/lamin-mlops/docs/lamin-mlops/.lamindb/9hJz1vrzYAkOOxXx')

Create a PyTorch-compatible dataset:

dataset = MNIST(path.as_posix(), transform=ToTensor())
dataset
Dataset MNIST
    Number of datapoints: 60000
    Root location: /home/runner/work/lamin-mlops/lamin-mlops/docs/lamin-mlops/.lamindb/9hJz1vrzYAkOOxXx
    Split: Train
    StandardTransform
Transform: ToTensor()
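`ToTensor` turns each 28×28 grayscale MNIST image into a float tensor of shape (1, 28, 28) with pixel values scaled from [0, 255] to [0.0, 1.0]. `training_step` then flattens each image to 784 values and scores the reconstruction with the mean squared error. The following dependency-free sketch walks through those three steps on plain Python lists. The toy 2×2 image, the made-up reconstruction, and all helper names are hypothetical stand-ins, not torchvision or PyTorch code:

```python
def to_unit_range(image):
    """Scale uint8 pixel rows in [0, 255] to floats in [0.0, 1.0], as ToTensor does."""
    return [[pixel / 255.0 for pixel in row] for row in image]


def flatten(image):
    """Row-major flattening: the per-image equivalent of x.view(x.size(0), -1)."""
    return [pixel for row in image for pixel in row]


def mse(prediction, target):
    """Mean squared error over all elements (mse_loss's default reduction='mean')."""
    assert len(prediction) == len(target)
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


toy_image = [[0, 255], [51, 102]]       # hypothetical 2x2 image instead of 28x28
x = flatten(to_unit_range(toy_image))   # [0.0, 1.0, 0.2, 0.4]
x_hat = [0.0, 0.5, 0.2, 0.4]            # hypothetical decoder output
print(mse(x_hat, x))                    # 0.0625
```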

Monitor training with wandb

Train our example model and track the training progress with wandb.

from lightning.pytorch.loggers import WandbLogger

MODEL_CONFIG = {
    "hidden_size": 32,
    "bottleneck_size": 16,
    "batch_size": 32
}

# create the data loader
train_loader = utils.data.DataLoader(dataset, batch_size=MODEL_CONFIG["batch_size"], shuffle=True)

# init model
autoencoder = LitAutoEncoder(MODEL_CONFIG["hidden_size"], MODEL_CONFIG["bottleneck_size"])

# initialize the logger
wandb_logger = WandbLogger(project="lamin")

# add batch size to the wandb config
wandb_logger.experiment.config["batch_size"] = MODEL_CONFIG["batch_size"]
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: felix_lamin (lamin-mlops-demo). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.19.1
wandb: Run data is saved locally in ./wandb/run-20241220_150351-iuiq2m5w
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run playful-planet-149
wandb: ⭐️ View project at https://wandb.ai/lamin-mlops-demo/lamin
wandb: 🚀 View run at https://wandb.ai/lamin-mlops-demo/lamin/runs/iuiq2m5w
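`save_hyperparameters()` records the arguments of `LitAutoEncoder.__init__`, so `hidden_size` and `bottleneck_size` reach wandb automatically. `batch_size` only configures the `DataLoader`, so it has to be added to the wandb config by hand, as done above. The stdlib sketch below makes this split explicit by inspecting the constructor signature. Both the `split_config` helper and the stand-in class are hypothetical:

```python
import inspect


class StandInAutoEncoder:
    """Hypothetical stand-in with the same constructor signature as LitAutoEncoder."""

    def __init__(self, hidden_size, bottleneck_size):
        self.hidden_size = hidden_size
        self.bottleneck_size = bottleneck_size


def split_config(cls, config):
    """Split a config dict into constructor arguments and everything else."""
    init_params = inspect.signature(cls.__init__).parameters
    init_args = {k: v for k, v in config.items() if k in init_params}
    extra = {k: v for k, v in config.items() if k not in init_params}
    return init_args, extra


MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}
init_args, extra = split_config(StandInAutoEncoder, MODEL_CONFIG)
print(init_args)  # {'hidden_size': 32, 'bottleneck_size': 16}
print(extra)      # {'batch_size': 32}
model = StandInAutoEncoder(**init_args)
```

Every key in `extra` is a value that `save_hyperparameters()` will not capture and must be logged separately.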
from lightning.pytorch.callbacks import ModelCheckpoint

# keep the checkpoint with the lowest train_loss on disk; it is uploaded to LaminDB after training
checkpoint_callback = ModelCheckpoint(
    dirpath=f"model_checkpoints/{wandb_logger.version}", 
    filename="last_epoch",
    save_top_k=1,
    monitor="train_loss"
)

# train model
trainer = lightning.Trainer(
    accelerator="cpu",
    limit_train_batches=3, 
    max_epochs=2,
    logger=wandb_logger,
    callbacks=[checkpoint_callback]
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 25.6 K | train
1 | decoder | Sequential | 26.4 K | train
-----------------------------------------------
52.1 K    Trainable params
0         Non-trainable params
52.1 K    Total params
0.208     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/opt/hostedtoolcache/Python/3.10.15/x64/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (3) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
`Trainer.fit` stopped: `max_epochs=2` reached.
Epoch 1: 100%|██████████| 3/3 [00:00<00:00, 85.95it/s, v_num=2m5w] 
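The logging-interval warning in this output follows from simple arithmetic. At `batch_size=32`, the 60,000 MNIST training images form 1,875 batches per epoch. `limit_train_batches=3` caps each epoch at 3 batches, so the two epochs perform only 6 optimizer steps in total, well below the `log_every_n_steps=50` quoted in the warning. As the warning itself notes, lowering `log_every_n_steps` is the way to see training logs for such short runs. A sketch of the calculation:

```python
import math

N_TRAIN_IMAGES = 60_000   # "Number of datapoints" of the MNIST train split above
BATCH_SIZE = 32           # MODEL_CONFIG["batch_size"]
LIMIT_TRAIN_BATCHES = 3   # Trainer(limit_train_batches=3)
MAX_EPOCHS = 2            # Trainer(max_epochs=2)
LOG_EVERY_N_STEPS = 50    # value quoted in the Lightning warning

# DataLoader keeps the last partial batch by default (drop_last=False), hence ceil
batches_per_epoch = math.ceil(N_TRAIN_IMAGES / BATCH_SIZE)
steps_per_epoch = min(batches_per_epoch, LIMIT_TRAIN_BATCHES)
total_steps = steps_per_epoch * MAX_EPOCHS

print(batches_per_epoch, steps_per_epoch, total_steps)  # 1875 3 6
print(total_steps < LOG_EVERY_N_STEPS)                   # True
```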

wandb_logger.experiment.name
'playful-planet-149'
wandb_logger.version
'iuiq2m5w'
wandb.finish()
wandb:                                                                                
wandb: 🚀 View run playful-planet-149 at: https://wandb.ai/lamin-mlops-demo/lamin/runs/iuiq2m5w
wandb: ⭐️ View project at: https://wandb.ai/lamin-mlops-demo/lamin
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20241220_150351-iuiq2m5w/logs

The training progress can be inspected in the wandb UI at the run URL printed above.

Save model in LaminDB

# save checkpoint as a model in LaminDB
artifact = ln.Artifact(
    f"model_checkpoints/{wandb_logger.version}",
    key="testmodels/litautoencoder",  # is automatically versioned
    type="model",
).save()

# create a label with the wandb experiment name
experiment_label = ln.ULabel(
    name=wandb_logger.experiment.name, 
    description="wandb experiment name"
).save()

# annotate the model artifact
artifact.ulabels.add(experiment_label)

# define the associated model hyperparameters in ln.Param
for k, v in MODEL_CONFIG.items():
    ln.Param(name=k, dtype=type(v).__name__).save()
artifact.params.add_values(MODEL_CONFIG)

# describe the artifact
artifact.describe()
Artifact 
├── General
│   ├── .uid = 'EiCCbOMVTDpwDpJp0000'
│   ├── .key = 'testmodels/litautoencoder'
│   ├── .size = 636275
│   ├── .hash = 'o0OHh1W5CYoOPqcVBzhj3Q'
│   ├── .n_objects = 1
│   ├── .path = /home/runner/work/lamin-mlops/lamin-mlops/docs/lamin-mlops/.lamindb/EiCCbOMVTDpwDpJp
│   ├── .created_by = anonymous
│   ├── .created_at = 2024-12-20 15:03:53
│   └── .transform = 'Weights & Biases'
├── Params
│   └── batch_size                  int                        32
│       bottleneck_size             int                        16
│       hidden_size                 int                        32
└── Labels
    └── .ulabels                    ULabel                     playful-planet-149                       
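The dtype registered for each `ln.Param` above is the Python type name of the corresponding value. Every entry in `MODEL_CONFIG` is therefore registered as `int`, matching the Params section of the description. The stdlib sketch below reproduces that mapping. The extra `learning_rate` entry is hypothetical and only shows what a non-integer value would yield:

```python
MODEL_CONFIG = {"hidden_size": 32, "bottleneck_size": 16, "batch_size": 32}

# same expression as in the loop above: type(v).__name__
dtypes = {k: type(v).__name__ for k, v in MODEL_CONFIG.items()}
print(dtypes)  # {'hidden_size': 'int', 'bottleneck_size': 'int', 'batch_size': 'int'}

# hypothetical: if the optimizer's learning rate were also added to the config
extended = {**MODEL_CONFIG, "learning_rate": 1e-3}
print(type(extended["learning_rate"]).__name__)  # float
```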

See the checkpoints:

If you later want to reuse the checkpoint, you can download it like so:

ln.Artifact.get(key='testmodels/litautoencoder').cache()
PosixUPath('/home/runner/work/lamin-mlops/lamin-mlops/docs/lamin-mlops/.lamindb/EiCCbOMVTDpwDpJp')

Or on the CLI:

lamin get artifact --key 'testmodels/litautoencoder'
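The downloaded artifact is the checkpoint directory. Given `filename="last_epoch"` above, it should contain a single `last_epoch.ckpt` file, since `ModelCheckpoint` appends the `.ckpt` extension. Below is a stdlib sketch for locating that file. The `find_checkpoint` helper is hypothetical. The returned path could then be passed to Lightning's `LitAutoEncoder.load_from_checkpoint`, which can rebuild the model from the hyperparameters stored by `save_hyperparameters()`:

```python
import tempfile
from pathlib import Path


def find_checkpoint(directory, name="last_epoch"):
    """Return the single <name>*.ckpt file in a checkpoint directory (hypothetical helper)."""
    matches = sorted(Path(directory).glob(f"{name}*.ckpt"))
    if len(matches) != 1:
        raise FileNotFoundError(f"expected one {name}*.ckpt in {directory}, found {len(matches)}")
    return matches[0]


# self-contained demo on a temporary directory that mimics the cached checkpoint folder
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "last_epoch.ckpt").touch()
    print(find_checkpoint(tmp).name)  # last_epoch.ckpt

# in the notebook (not run here):
# ckpt_path = find_checkpoint(ln.Artifact.get(key="testmodels/litautoencoder").cache())
# model = LitAutoEncoder.load_from_checkpoint(ckpt_path)
```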
# save notebook
# ln.context.finish()