
Train a spatial ML model

Here, we show how to query, access, and combine several SpatialData datasets across different technologies to train a DenseNet that predicts cell types in Xenium data from an associated H&E image. Specifically, we use the H&E image from Visium data and the cell type information from overlapping Xenium data. Both modalities are spatially aligned via an affine transformation.

This tutorial is adapted from the SpatialData documentation.

import warnings

warnings.filterwarnings("ignore")

import lamindb as ln
import numpy as np

import spatialdata as sd
from spatialdata import transform
from spatialdata.dataloader.datasets import ImageTilesDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar

import torch.multiprocessing as mp

mp.set_start_method("spawn", force=True)

ln.track(project="spatial guide datasets")
 connected lamindb: testuser1/test-spatial
 created Transform('NFdTvSPxXzwJ0000'), started new Run('nPiXCB94...') at 2025-04-18 11:47:47 UTC
 notebook imports: lamindb==1.4.0 numpy==1.26.4 pytorch-lightning==2.5.1 spatial_ml spatialdata==0.3.0 torch==2.6.0

First, we query for Visium and Xenium datasets and create a merged dataset:

xenium_1_sd = ln.Artifact.filter(key="xenium_aligned_1_guide_min.zarr").one().load()
visium_sd = ln.Artifact.filter(key="visium_aligned_guide_min.zarr").one().load()
merged_sd = sd.SpatialData(
    images={
        "CytAssist_FFPE_Human_Breast_Cancer_full_image": visium_sd.images[
            "CytAssist_FFPE_Human_Breast_Cancer_full_image"
        ],
    },
    shapes={
        "cell_circles": xenium_1_sd.shapes["cell_circles"],
        "cell_boundaries": xenium_1_sd.shapes["cell_boundaries"],
    },
    tables={"table": xenium_1_sd["table"]},
)

The Visium image is rotated with respect to the Xenium data.
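As an optional sanity check (a minimal sketch, not part of the original guide), we can print the merged object and look at the affine transformation that maps the Visium image into the shared "aligned" coordinate system used below; the rotation is encoded in this affine:

from spatialdata.transformations import get_transformation

# Print the elements contained in the merged SpatialData object
print(merged_sd)

# Retrieve the transformation of the Visium image to the "aligned" coordinate
# system (assuming the image carries such a transformation, as in this guide)
affine = get_transformation(
    merged_sd.images["CytAssist_FFPE_Human_Breast_Cancer_full_image"],
    to_coordinate_system="aligned",
)
print(affine)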

A DenseNet for cell type prediction

Next, we create an ImageTilesDataset from our merged SpatialData object. We further import an image tile transform, the corresponding PyTorch Lightning DataModule, and the DenseNet model from an existing script.

Code of tile_transform, TilesDataModule and the DenseNetModel
Spatial cell type classification model definition
from spatialdata import SpatialData
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, LightningModule
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from monai.networks.nets import DenseNet121


def tile_transform(sdata: SpatialData) -> tuple[torch.Tensor, torch.Tensor]:
    cell_types = sdata["table"].obs["celltype_major"].cat.categories.tolist()
    tile = sdata["CytAssist_FFPE_Human_Breast_Cancer_full_image"].data.compute()
    tile = torch.tensor(tile, dtype=torch.float32)

    expected_category = sdata["table"].obs["celltype_major"].values[0]
    expected_idx = cell_types.index(expected_category)
    return tile, torch.tensor(expected_idx)


class TilesDataModule(LightningDataModule):
    def __init__(
        self, batch_size: int, num_workers: int, dataset: torch.utils.data.Dataset
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = dataset

    def setup(self, stage=None):
        n_train = int(len(self.dataset) * 0.7)
        n_val = int(len(self.dataset) * 0.2)
        n_test = len(self.dataset) - n_train - n_val
        self.train, self.val, self.test = torch.utils.data.random_split(
            self.dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(42),
        )

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def predict_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )


class DenseNetModel(LightningModule):
    def __init__(self, learning_rate: float, in_channels: int, num_classes: int):
        super().__init__()

        self.save_hyperparameters()

        self.loss_function = CrossEntropyLoss()

        self.model = DenseNet121(
            spatial_dims=2, in_channels=in_channels, out_channels=num_classes
        )

    def forward(self, x) -> torch.Tensor:
        return self.model(x)

    def _compute_loss_from_batch(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        inputs = batch[0]
        labels = batch[1]

        outputs = self.model(inputs)
        return self.loss_function(outputs, labels)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> dict[str, torch.Tensor]:
        loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)

        self.log("training_loss", loss, batch_size=len(batch[0]))

        return {"loss": loss}

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)

        imgs, labels = batch
        acc = self.compute_accuracy(imgs, labels)
        self.log("test_acc", acc)

        return loss

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        acc = self.compute_accuracy(imgs, labels)
        self.log("test_acc", acc)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        return preds

    def compute_accuracy(self, imgs, labels) -> float:
        preds = self.model(imgs).argmax(dim=-1)

        acc = (labels == preds).float().mean()
        return acc

    def configure_optimizers(self) -> Adam:
        return Adam(self.model.parameters(), lr=self.hparams.learning_rate)

from spatial_ml import tile_transform, TilesDataModule, DenseNetModel

dataset = ImageTilesDataset(
    sdata=merged_sd,
    regions_to_images={"cell_circles": "CytAssist_FFPE_Human_Breast_Cancer_full_image"},
    regions_to_coordinate_systems={"cell_circles": "aligned"},
    table_name="table",
    tile_dim_in_units=6
    * np.mean(
        transform(merged_sd["cell_circles"], to_coordinate_system="aligned").radius
    ),
    transform=tile_transform,
    rasterize=True,
    rasterize_kwargs={"target_width": 32},
)
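Before training, it can be useful to inspect a single tile (a small sketch, not part of the original guide) to confirm the shapes returned by tile_transform; the channel dimension is what we pass as in_channels below:

# Fetch one (image tile, cell type index) pair from the dataset
tile, label = dataset[0]
print(tile.shape, tile.dtype)  # e.g. torch.Size([3, 32, 32]) torch.float32
print(label)  # integer index into the celltype_major categories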

Now, we only need to set up a DataModule and our model, and then we can start training.

pl.seed_everything(7)

tiles_data_module = TilesDataModule(batch_size=64, num_workers=8, dataset=dataset)

tiles_data_module.setup()
train_dl = tiles_data_module.train_dataloader()
val_dl = tiles_data_module.val_dataloader()
test_dl = tiles_data_module.test_dataloader()

model = DenseNetModel(
    learning_rate=1e-5,
    in_channels=dataset[0][0].shape[0],
    num_classes=len(merged_sd["table"].obs["celltype_major"].cat.categories.tolist()),
)

trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(refresh_rate=5),
    ],
    log_every_n_steps=20,
)
Seed set to 7
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
trainer.fit(model, datamodule=tiles_data_module)
trainer.test(model, datamodule=tiles_data_module)
  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | loss_function | CrossEntropyLoss | 0      | train
1 | model         | DenseNet121      | 7.0 M  | train
-----------------------------------------------------------
7.0 M     Trainable params
0         Non-trainable params
7.0 M     Total params
27.852    Total estimated model params size (MB)
496       Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=1` reached.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc              0.29120880365371704    │
└───────────────────────────┴───────────────────────────┘
[{'test_acc': 0.29120880365371704}]

If we were to perform a prediction and evaluate it as outlined in the original guide, we would see predictions like:

Model predictions
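As a rough sketch (assuming the model, trainer, and tiles_data_module objects defined above), such predictions could be computed with trainer.predict and mapped back to cell type names along these lines:

import torch

# Run prediction over all tiles; trainer.predict returns one tensor of
# predicted class indices per batch (see predict_step above)
batched_preds = trainer.predict(model, datamodule=tiles_data_module)
preds = torch.cat(batched_preds)

# Map predicted indices back to cell type names, using the same category
# ordering that tile_transform used to build the labels
cell_types = merged_sd["table"].obs["celltype_major"].cat.categories.tolist()
predicted_celltypes = [cell_types[i] for i in preds.tolist()]
print(predicted_celltypes[:10])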