## Train a spatial ML model

Here, we show how to query, access, and combine several
SpatialData datasets across different technologies to train a
DenseNet that predicts the cell types of Xenium cells from an
associated H&E image. Specifically, we use the H&E image from the
Visium data and the cell type annotations from the overlapping Xenium
data. The two modalities are spatially aligned via an affine
transformation.

This tutorial is adapted from the SpatialData documentation.

 import warnings

 # Silence all warnings; this runs before the imports below so that their
 # import-time warnings are suppressed as well.
 warnings.filterwarnings("ignore")

 import lamindb as ln
 import numpy as np

 import spatialdata as sd
 from spatialdata import transform
 from spatialdata.dataloader.datasets import ImageTilesDataset

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import LearningRateMonitor
 from pytorch_lightning.callbacks.progress import TQDMProgressBar

 import torch.multiprocessing as mp

 # Start child processes with "spawn". NOTE(review): presumably this targets
 # the DataLoader worker processes (num_workers=8 in the training cell) —
 # confirm. force=True overrides a start method that was already set.
 mp.set_start_method("spawn", force=True)

 # Begin tracking this run in LaminDB; the run is closed by ln.finish() in
 # the last cell.
 ln.track()

First, we query for Xenium and Visium datasets that we curated and
ingested on the previous page:

 # Look up the curated breast-tissue Xenium and Visium artifacts.
 xenium_af = ln.Artifact.filter(
     tissue="breast",
     assay="10x Xenium",
 ).first()

 visium_af = ln.Artifact.filter(
     tissue="breast",
     assay="Visium Spatial Gene Expression",
 ).first()

 # ``.first()`` returns None when no artifact matches. Fail here with a clear
 # message instead of an opaque AttributeError on ``.load()`` further down.
 if xenium_af is None or visium_af is None:
     missing = [
         name
         for name, artifact in (("Xenium", xenium_af), ("Visium", visium_af))
         if artifact is None
     ]
     raise LookupError(
         f"No curated breast {' or '.join(missing)} artifact found; "
         "run the curation step from the previous page first."
     )

From the query results, we load the SpatialData datasets:

 # Load both artifacts as SpatialData objects (Xenium first, then Visium).
 xenium_sd, visium_sd = xenium_af.load(), visium_af.load()

Because both datasets were curated with matching tissue, disease, and
organism metadata, we can merge them for multi-modal analysis.

 # Build one SpatialData object from the Visium H&E image plus the Xenium
 # cell shapes and their annotation table.
 merged_sd = sd.SpatialData(
     images={
         "CytAssist_FFPE_Human_Breast_Cancer_full_image": visium_sd.images[
             "CytAssist_FFPE_Human_Breast_Cancer_full_image"
         ]
     },
     shapes={
         shape_key: xenium_sd.shapes[shape_key]
         for shape_key in ("cell_circles", "cell_boundaries")
     },
     tables={"table": xenium_sd["table"]},
 )

The Visium image is rotated with respect to the Xenium data, so the
tiles below are extracted in the shared "aligned" coordinate system.

Next, we create an "ImageTilesDataset" from our merged "SpatialData"
object. We also import an image tile transform, the corresponding
PyTorch Lightning "DataModule", and the final "DenseNet" model from an
existing script.

-[ Code of tile_transform, TilesDataModule and DenseNetModel ]-

Spatial cell type classification model definition

 from spatialdata import SpatialData
 import torch
 from torch.utils.data import DataLoader
 from pytorch_lightning import LightningDataModule, LightningModule
 from torch.nn import CrossEntropyLoss
 from torch.optim import Adam
 from monai.networks.nets import DenseNet121

 def tile_transform(sdata: SpatialData) -> tuple[torch.Tensor, torch.Tensor]:
     """Turn one cell tile into an ``(image, label)`` training pair.

     The image is the tile's ``CytAssist_FFPE_Human_Breast_Cancer_full_image``
     data as a float32 tensor. The label is the position of the tile's first
     ``celltype_major`` value among that categorical column's categories.
     """
     celltype = sdata["table"].obs["celltype_major"]
     category_names = celltype.cat.categories.tolist()

     pixels = sdata["CytAssist_FFPE_Human_Breast_Cancer_full_image"].data.compute()
     image = torch.tensor(pixels, dtype=torch.float32)

     label = category_names.index(celltype.values[0])
     return image, torch.tensor(label)

 class TilesDataModule(LightningDataModule):
     """Lightning data module that splits one tile dataset into train/val/test.

     Args:
         batch_size: Number of samples per batch for every dataloader.
         num_workers: Number of worker processes for every ``DataLoader``.
         dataset: Dataset to split (in this tutorial an ``ImageTilesDataset``
             whose transform yields ``(image, label)`` pairs).
         train_fraction: Share of ``dataset`` used for training (default 0.7).
         val_fraction: Share of ``dataset`` used for validation (default 0.2).
             The remainder becomes the test split.
         seed: Seed of the generator that drives the random split. Because it
             is fixed, repeated ``setup`` calls produce the same split.

     Raises:
         ValueError: If a fraction is negative or both together exceed 1.
     """

     def __init__(
         self,
         batch_size: int,
         num_workers: int,
         dataset: torch.utils.data.Dataset,
         train_fraction: float = 0.7,
         val_fraction: float = 0.2,
         seed: int = 42,
     ):
         super().__init__()

         if train_fraction < 0 or val_fraction < 0 or train_fraction + val_fraction > 1:
             raise ValueError(
                 "train_fraction and val_fraction must be non-negative and sum "
                 f"to at most 1, got {train_fraction} and {val_fraction}"
             )

         self.batch_size = batch_size
         self.num_workers = num_workers
         self.dataset = dataset
         self.train_fraction = train_fraction
         self.val_fraction = val_fraction
         self.split_seed = seed

     def setup(self, stage=None):
         """Randomly split ``dataset`` into ``train``, ``val`` and ``test`` subsets."""
         n_total = len(self.dataset)
         n_train = int(n_total * self.train_fraction)
         n_val = int(n_total * self.val_fraction)
         # The test split takes whatever int() truncation left over.
         n_test = n_total - n_train - n_val
         self.train, self.val, self.test = torch.utils.data.random_split(
             self.dataset,
             [n_train, n_val, n_test],
             generator=torch.Generator().manual_seed(self.split_seed),
         )

     def _make_dataloader(self, data, shuffle: bool) -> DataLoader:
         """Wrap *data* in a ``DataLoader`` using this module's batch settings."""
         return DataLoader(
             data,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=shuffle,
         )

     def train_dataloader(self):
         """Shuffled loader over the training split."""
         return self._make_dataloader(self.train, shuffle=True)

     def val_dataloader(self):
         """Loader over the validation split, in fixed order."""
         return self._make_dataloader(self.val, shuffle=False)

     def test_dataloader(self):
         """Loader over the test split, in fixed order."""
         return self._make_dataloader(self.test, shuffle=False)

     def predict_dataloader(self):
         """Loader over the full, unsplit dataset, in fixed order."""
         return self._make_dataloader(self.dataset, shuffle=False)

 class DenseNetModel(LightningModule):
     """DenseNet-121 classifier that maps an image tile to cell-type logits.

     Args:
         learning_rate: Learning rate passed to the Adam optimizer.
         in_channels: Number of channels of each input tile.
         num_classes: Number of cell-type classes (length of the logit vector).
     """

     def __init__(self, learning_rate: float, in_channels: int, num_classes: int):
         super().__init__()

         # Records the constructor arguments under ``self.hparams``;
         # ``configure_optimizers`` reads ``learning_rate`` from there.
         self.save_hyperparameters()

         self.loss_function = CrossEntropyLoss()

         self.model = DenseNet121(
             spatial_dims=2, in_channels=in_channels, out_channels=num_classes
         )

     def forward(self, x) -> torch.Tensor:
         """Return the class logits for a batch of tiles ``x``."""
         return self.model(x)

     @staticmethod
     def _accuracy_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         """Return the fraction of samples whose arg-max logit equals the label."""
         preds = logits.argmax(dim=-1)
         return (labels == preds).float().mean()

     def _compute_loss_from_batch(
         self, batch: list[torch.Tensor] | tuple[torch.Tensor, ...], batch_idx: int
     ) -> torch.Tensor:
         """Return the cross-entropy loss on ``batch`` (indexed as inputs, labels).

         ``batch_idx`` is accepted for symmetry with the step hooks and unused.
         """
         inputs = batch[0]
         labels = batch[1]

         outputs = self.model(inputs)
         return self.loss_function(outputs, labels)

     def training_step(
         self, batch: list[torch.Tensor] | tuple[torch.Tensor, ...], batch_idx: int
     ) -> dict[str, torch.Tensor]:
         """Compute the training loss for one batch and log it as ``training_loss``."""
         loss = self._compute_loss_from_batch(batch=batch, batch_idx=batch_idx)

         self.log("training_loss", loss, batch_size=len(batch[0]))

         return {"loss": loss}

     def validation_step(
         self, batch: list[torch.Tensor] | tuple[torch.Tensor, ...], batch_idx: int
     ) -> torch.Tensor:
         """Log validation loss and accuracy for one batch; return the loss."""
         imgs, labels = batch
         # A single forward pass serves both the loss and the accuracy.
         logits = self.model(imgs)
         loss = self.loss_function(logits, labels)
         acc = self._accuracy_from_logits(logits, labels)

         # Log under ``val_*`` so validation metrics are not reported as the
         # ``test_acc`` that ``test_step`` produces.
         self.log("val_loss", loss, batch_size=len(imgs))
         self.log("val_acc", acc, batch_size=len(imgs))

         return loss

     def test_step(self, batch, batch_idx):
         """Log the accuracy on one test batch as ``test_acc``."""
         imgs, labels = batch
         acc = self.compute_accuracy(imgs, labels)
         self.log("test_acc", acc, batch_size=len(imgs))

     def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
         """Return the predicted class index for every tile in ``batch``."""
         imgs, labels = batch
         preds = self.model(imgs).argmax(dim=-1)
         return preds

     def compute_accuracy(self, imgs, labels) -> torch.Tensor:
         """Run the model on ``imgs`` and return its accuracy against ``labels``."""
         return self._accuracy_from_logits(self.model(imgs), labels)

     def configure_optimizers(self) -> Adam:
         """Optimize all model parameters with Adam at ``hparams.learning_rate``."""
         return Adam(self.model.parameters(), lr=self.hparams.learning_rate)

 from spatial_ml import tile_transform, TilesDataModule, DenseNetModel

 # Tile side length: six times the mean cell-circle radius, measured in the
 # "aligned" coordinate system in which the tiles are extracted.
 tile_size = 6 * np.mean(
     transform(merged_sd["cell_circles"], to_coordinate_system="aligned").radius
 )

 # One tile per cell circle, cut from the H&E image, passed through
 # tile_transform and rasterized to a target width of 32.
 dataset = ImageTilesDataset(
     sdata=merged_sd,
     regions_to_images={"cell_circles": "CytAssist_FFPE_Human_Breast_Cancer_full_image"},
     regions_to_coordinate_systems={"cell_circles": "aligned"},
     table_name="table",
     tile_dim_in_units=tile_size,
     transform=tile_transform,
     rasterize=True,
     rasterize_kwargs={"target_width": 32},
 )

Now we only need to set up a DataModule and our model, and then we
can start training.

 pl.seed_everything(7)

 # The Trainer calls ``setup`` and the ``*_dataloader`` hooks on the data
 # module itself during ``fit``/``test``, so no loaders are built by hand.
 tiles_data_module = TilesDataModule(batch_size=64, num_workers=8, dataset=dataset)

 # Channel count comes from the first example tile. The class count is the
 # number of categories in the ``celltype_major`` annotation.
 model = DenseNetModel(
     learning_rate=1e-5,
     in_channels=dataset[0][0].shape[0],
     num_classes=len(merged_sd["table"].obs["celltype_major"].cat.categories),
 )

 trainer = pl.Trainer(
     max_epochs=1,
     callbacks=[
         LearningRateMonitor(logging_interval="step"),
         TQDMProgressBar(refresh_rate=5),
     ],
     log_every_n_steps=20,
 )

 trainer.fit(model, datamodule=tiles_data_module)
 trainer.test(model, datamodule=tiles_data_module)

If we ran predictions and evaluated them as outlined in the original
SpatialData guide, we would see predictions like those shown there
(the figure is not included in this text version).

 # Close the LaminDB run opened by ln.track() in the setup cell.
 ln.finish()

 # clean up test instance
 # The "!" lines are IPython/Jupyter shell escapes, not Python: remove the
 # local "test-spatial" directory, then delete the LaminDB instance
 # (--force presumably skips the confirmation prompt — confirm).
 !rm -rf test-spatial
 !lamin delete --force test-spatial