Train a spatial ML model
Here, we show how to query, access, and combine several SpatialData datasets across different technologies to train a DenseNet that predicts the cell types of Xenium data from an associated H&E image. Specifically, we use the H&E image from a Visium dataset and the cell type annotations from an overlapping Xenium dataset. The two modalities are spatially aligned via an affine transformation.
This tutorial is adapted from the SpatialData documentation.
import warnings
warnings.filterwarnings("ignore")
import lamindb as ln
import numpy as np
import spatialdata as sd
from spatialdata import transform
from spatialdata.dataloader.datasets import ImageTilesDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)
ln.track(project="spatial guide datasets")
Show code cell output
→ connected lamindb: testuser1/test-spatial
→ created Transform('NFdTvSPxXzwJ0000'), started new Run('nPiXCB94...') at 2025-04-18 11:47:47 UTC
→ notebook imports: lamindb==1.4.0 numpy==1.26.4 pytorch-lightning==2.5.1 spatial_ml spatialdata==0.3.0 torch==2.6.0
First, we query for Visium and Xenium datasets and create a merged dataset:
xenium_1_sd = ln.Artifact.filter(key="xenium_aligned_1_guide_min.zarr").one().load()
visium_sd = ln.Artifact.filter(key="visium_aligned_guide_min.zarr").one().load()
merged_sd = sd.SpatialData(
images={
"CytAssist_FFPE_Human_Breast_Cancer_full_image": visium_sd.images[
"CytAssist_FFPE_Human_Breast_Cancer_full_image"
],
},
shapes={
"cell_circles": xenium_1_sd.shapes["cell_circles"],
"cell_boundaries": xenium_1_sd.shapes["cell_boundaries"],
},
tables={"table": xenium_1_sd["table"]},
)
The Visium image is rotated with respect to the Xenium data.

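To verify the alignment, we can inspect the transformation that maps the Visium image into the shared coordinate system. A minimal sketch, assuming the shared coordinate system is named "aligned" (the name used in the ImageTilesDataset call below):
from spatialdata.transformations import get_transformation
# Retrieve the transformation that maps the Visium H&E image into the
# coordinate system it shares with the Xenium data; it includes the rotation.
affine = get_transformation(
    merged_sd.images["CytAssist_FFPE_Human_Breast_Cancer_full_image"],
    to_coordinate_system="aligned",
)
print(affine)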
Next, we create an ImageTilesDataset using our merged SpatialData object. We further import an image tile transform, the corresponding PyTorch Lightning DataModule, and the final DenseNet model from an existing script.
Code of tile_transform, TilesDataModule and the DenseNetModel
from spatialdata import SpatialData
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, LightningModule
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from monai.networks.nets import DenseNet121
def tile_transform(sdata: SpatialData) -> tuple[torch.Tensor, torch.Tensor]:
    # Turn one tile-sized SpatialData object into an (image, label) pair:
    # the tile as a float tensor plus the index of its annotated cell type.
    cell_types = sdata["table"].obs["celltype_major"].cat.categories.tolist()
    tile = sdata["CytAssist_FFPE_Human_Breast_Cancer_full_image"].data.compute()
    tile = torch.tensor(tile, dtype=torch.float32)
    expected_category = sdata["table"].obs["celltype_major"].values[0]
    expected_idx = cell_types.index(expected_category)
    return tile, torch.tensor(expected_idx)
class TilesDataModule(LightningDataModule):
def __init__(
self, batch_size: int, num_workers: int, dataset: torch.utils.data.Dataset
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset = dataset
    def setup(self, stage=None):
        # 70/20/10 train/validation/test split, seeded for reproducibility
        n_train = int(len(self.dataset) * 0.7)
        n_val = int(len(self.dataset) * 0.2)
        n_test = len(self.dataset) - n_train - n_val
        self.train, self.val, self.test = torch.utils.data.random_split(
            self.dataset,
            [n_train, n_val, n_test],
            generator=torch.Generator().manual_seed(42),
        )
def train_dataloader(self):
return DataLoader(
self.train,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)
def val_dataloader(self):
return DataLoader(
self.val,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def test_dataloader(self):
return DataLoader(
self.test,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
def predict_dataloader(self):
return DataLoader(
self.dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
)
class DenseNetModel(LightningModule):
    def __init__(self, learning_rate: float, in_channels: int, num_classes: int):
        super().__init__()
        self.save_hyperparameters()
        self.loss_function = CrossEntropyLoss()
        # 2D DenseNet-121 from MONAI: image channels in, one logit per cell type out
        self.model = DenseNet121(
            spatial_dims=2, in_channels=in_channels, out_channels=num_classes
        )
def forward(self, x) -> torch.Tensor:
return self.model(x)
def _compute_loss_from_batch(
self, batch: dict[str | int, torch.Tensor], batch_idx: int
) -> float:
inputs = batch[0]
labels = batch[1]
outputs = self.model(inputs)
return self.loss_function(outputs, labels)
def training_step(
self, batch: dict[str | int, torch.Tensor], batch_idx: int
) -> dict[str, float]:
loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)
self.log("training_loss", loss, batch_size=len(batch[0]))
return {"loss": loss}
    def validation_step(
        self, batch: dict[str | int, torch.Tensor], batch_idx: int
    ) -> float:
        loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)
        imgs, labels = batch
        acc = self.compute_accuracy(imgs, labels)
        # log validation accuracy under its own key rather than the test metric
        self.log("val_acc", acc)
        return loss
def test_step(self, batch, batch_idx):
imgs, labels = batch
acc = self.compute_accuracy(imgs, labels)
self.log("test_acc", acc)
def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
imgs, labels = batch
preds = self.model(imgs).argmax(dim=-1)
return preds
def compute_accuracy(self, imgs, labels) -> float:
preds = self.model(imgs).argmax(dim=-1)
acc = (labels == preds).float().mean()
return acc
def configure_optimizers(self) -> Adam:
return Adam(self.model.parameters(), lr=self.hparams.learning_rate)
from spatial_ml import tile_transform, TilesDataModule, DenseNetModel
dataset = ImageTilesDataset(
sdata=merged_sd,
regions_to_images={"cell_circles": "CytAssist_FFPE_Human_Breast_Cancer_full_image"},
regions_to_coordinate_systems={"cell_circles": "aligned"},
table_name="table",
tile_dim_in_units=6
* np.mean(
transform(merged_sd["cell_circles"], to_coordinate_system="aligned").radius
),
transform=tile_transform,
rasterize=True,
rasterize_kwargs={"target_width": 32},
)
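As a quick sanity check (not part of the original guide), we can look at a single dataset item; each element is the (tile, label) pair produced by tile_transform:
tile, label = dataset[0]
print(tile.shape)  # (channels, height, width); about 32 px wide given target_width=32
print(label)       # integer index into the cell type categories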
Now we only need to set up the DataModule and our model, and we can start training.
pl.seed_everything(7)
tiles_data_module = TilesDataModule(batch_size=64, num_workers=8, dataset=dataset)
tiles_data_module.setup()
train_dl = tiles_data_module.train_dataloader()
val_dl = tiles_data_module.val_dataloader()
test_dl = tiles_data_module.test_dataloader()
model = DenseNetModel(
learning_rate=1e-5,
in_channels=dataset[0][0].shape[0],
num_classes=len(merged_sd["table"].obs["celltype_major"].cat.categories.tolist()),
)
trainer = pl.Trainer(
max_epochs=1,
callbacks=[
LearningRateMonitor(logging_interval="step"),
TQDMProgressBar(refresh_rate=5),
],
log_every_n_steps=20,
)
Show code cell output
Seed set to 7
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
trainer.fit(model, datamodule=tiles_data_module)
trainer.test(model, datamodule=tiles_data_module)
Show code cell output
  | Name          | Type             | Params | Mode
-----------------------------------------------------------
0 | loss_function | CrossEntropyLoss | 0      | train
1 | model         | DenseNet121      | 7.0 M  | train
-----------------------------------------------------------
7.0 M Trainable params
0 Non-trainable params
7.0 M Total params
27.852 Total estimated model params size (MB)
496 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=1` reached.
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_acc          │    0.29120880365371704    │
└───────────────────────────┴───────────────────────────┘
[{'test_acc': 0.29120880365371704}]
If we were to run prediction and evaluate it as outlined in the original guide, we would see predictions like:

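A minimal sketch of how such predictions could be computed and mapped back to cell type names, following the original guide (not executed here):
import torch
# Predict over the full dataset via the predict_dataloader defined above,
# then translate the predicted class indices back into cell type names.
predictions = trainer.predict(model, datamodule=tiles_data_module)
preds = torch.cat(predictions)
cell_types = merged_sd["table"].obs["celltype_major"].cat.categories
predicted_celltypes = cell_types[preds.numpy()]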