
Train a machine learning model on a collection

Here, we iterate over the artifacts within a collection to train a machine learning model at scale.

import lamindb as ln

ln.track("Qr1kIHvK506r0002")
→ connected lamindb: testuser1/test-scrna
→ created Transform('Qr1kIHvK'), started new Run('YYW9Gux5') at 2024-11-21 06:54:32 UTC
→ notebook imports: lamindb==0.76.16 torch==2.5.1

Query our collection:

collection = ln.Collection.get(name="My versioned scRNA-seq collection", version="2")
collection.describe()
Collection(uid='6HBgmdrV4xz24AVj0001', version='2', is_latest=True, name='My versioned scRNA-seq collection', hash='PPKf6IQ4SPZoY0yMQQfI6w', visibility=1, created_at=2024-11-21 06:54:05 UTC)
  Provenance
    .created_by = 'testuser1'
    .transform = 'Standardize and append a dataset'
    .run = 2024-11-21 06:53:45 UTC
  Usage
    .input_of_runs = 2024-11-21 06:54:15 UTC

Create a map-style dataset

Let us create a map-style dataset using mapped(), which returns a MappedCollection.

Under the hood, it performs a virtual join of the features of the underlying AnnData objects without loading the datasets into memory. You can either perform an inner join, which keeps only the features shared by all AnnData objects:

with collection.mapped(obs_keys=["cell_type"], join="inner") as dataset:
    print("#observations", dataset.shape[0])
    print("#variables:", len(dataset.var_joint))
#observations 1718
#variables: 749

Or an outer join, which keeps every feature present in at least one AnnData object:

dataset = collection.mapped(obs_keys=["cell_type"], join="outer")
print("#variables:", len(dataset.var_joint))
#variables: 36508
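The difference between the two joins can be pictured on the variable axis alone. The sketch below is a toy illustration, not lamindb's implementation. virtual_join and the gene lists are hypothetical. It only computes the joint variable names (the role of var_joint above) and never touches expression data.

```python
def virtual_join(var_lists, join="inner"):
    """Return the joint variable names for a list of per-dataset variable lists.

    inner: names present in every dataset (order of the first dataset kept)
    outer: names present in any dataset (order of first appearance kept)
    """
    if join == "inner":
        common = set(var_lists[0]).intersection(*var_lists[1:])
        return [v for v in var_lists[0] if v in common]
    if join == "outer":
        seen = {}  # dicts preserve insertion order
        for var_names in var_lists:
            for v in var_names:
                seen.setdefault(v, None)
        return list(seen)
    raise ValueError(f"unknown join: {join!r}")


# hypothetical variable (gene) lists of two AnnData objects
vars_a = ["CD3E", "CD19", "MS4A1", "LYZ"]
vars_b = ["CD19", "LYZ", "NKG7"]

print(virtual_join([vars_a, vars_b], join="inner"))  # ['CD19', 'LYZ']
print(virtual_join([vars_a, vars_b], join="outer"))  # ['CD3E', 'CD19', 'MS4A1', 'LYZ', 'NKG7']
```

The outer join can only be as large as or larger than the inner join, which matches the 749 vs. 36508 variables reported above.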

This is compatible with a PyTorch DataLoader because it implements __getitem__ and __len__ over a list of backed AnnData objects. For instance, the observation at index 5 (the sixth observation in the collection) can be accessed via:

dataset[5]
{'X': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
 '_store_idx': 0,
 'cell_type': 29}
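To make the __getitem__ idea concrete, here is a toy, in-memory stand-in. ToyMappedCollection is hypothetical; the real MappedCollection reads from backed AnnData files instead of Python lists. The sketch shows how a global index can be translated into a store index (compare _store_idx in the output above) and a row within that store.

```python
import bisect
from itertools import accumulate


class ToyMappedCollection:
    """Toy map-style dataset over several in-memory "stores" (lists of rows)."""

    def __init__(self, stores, labels):
        self.stores = stores
        self.labels = labels  # one list of integer labels per store
        # cumulative row counts, e.g. [3, 5] for stores of length 3 and 2
        self.offsets = list(accumulate(len(s) for s in stores))

    def __len__(self):
        return self.offsets[-1] if self.offsets else 0

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # first store whose cumulative row count exceeds idx
        store_idx = bisect.bisect_right(self.offsets, idx)
        start = self.offsets[store_idx - 1] if store_idx > 0 else 0
        local_idx = idx - start
        return {
            "X": self.stores[store_idx][local_idx],
            "_store_idx": store_idx,
            "cell_type": self.labels[store_idx][local_idx],
        }


# two hypothetical stores with 3 and 2 observations
stores = [[[0.0, 1.0], [2.0, 0.0], [0.0, 0.0]], [[5.0, 5.0], [1.0, 0.0]]]
labels = [[0, 1, 1], [2, 0]]
ds = ToyMappedCollection(stores, labels)
print(len(ds))  # 5
print(ds[4])    # {'X': [1.0, 0.0], '_store_idx': 1, 'cell_type': 0}
```

Because the class defines both __len__ and __getitem__, an object like this could also be passed to torch.utils.data.DataLoader as a map-style dataset.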

The labels are encoded as integers, and the mapping is stored in dataset.encoders:

dataset.encoders
{'cell_type': {'B cell, CD19-positive': 0,
  'CD14-positive, CD16-negative classical monocyte': 1,
  'CD16-negative, CD56-bright natural killer cell, human': 2,
  'CD16-positive, CD56-dim natural killer cell, human': 3,
  'CD38-high pre-BCR positive cell': 4,
  'CD38-positive naive B cell': 5,
  'CD4-positive helper T cell': 6,
  'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 7,
  'CD8-positive, alpha-beta memory T cell': 8,
  'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 9,
  'T follicular helper cell': 10,
  'alpha-beta T cell': 11,
  'alveolar macrophage': 12,
  'animal cell': 13,
  'classical monocyte': 14,
  'conventional dendritic cell': 15,
  'cytotoxic T cell': 16,
  'dendritic cell': 17,
  'dendritic cell, human': 18,
  'effector memory CD4-positive, alpha-beta T cell': 19,
  'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 20,
  'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 21,
  'gamma-delta T cell': 22,
  'germinal center B cell': 23,
  'group 3 innate lymphoid cell': 24,
  'lymphocyte': 25,
  'macrophage': 26,
  'mast cell': 27,
  'megakaryocyte': 28,
  'memory B cell': 29,
  'mucosal invariant T cell': 30,
  'naive B cell': 31,
  'naive thymus-derived CD4-positive, alpha-beta T cell': 32,
  'naive thymus-derived CD8-positive, alpha-beta T cell': 33,
  'non-classical monocyte': 34,
  'plasma cell': 35,
  'plasmablast': 36,
  'plasmacytoid dendritic cell': 37,
  'progenitor cell': 38,
  'regulatory T cell': 39}}
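The mapping above appears to follow Python's default string ordering, in which uppercase letters sort before lowercase ones. The sketch below reproduces that kind of encoding for a few labels. It also builds the inverse mapping you would need to turn predicted integers back into cell type names. fit_encoder is a hypothetical helper, not a lamindb function, and we do not claim that lamindb builds encoders this way.

```python
def fit_encoder(labels):
    """Map each distinct label to an integer, in sorted string order."""
    return {label: code for code, label in enumerate(sorted(set(labels)))}


# hypothetical .obs["cell_type"] values
obs_cell_type = ["memory B cell", "B cell, CD19-positive", "macrophage", "memory B cell"]
encoder = fit_encoder(obs_cell_type)
decoder = {code: label for label, code in encoder.items()}  # integer -> label

codes = [encoder[label] for label in obs_cell_type]
print(encoder)     # {'B cell, CD19-positive': 0, 'macrophage': 1, 'memory B cell': 2}
print(codes)       # [2, 0, 1, 2]
print(decoder[2])  # memory B cell
```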

It is also possible to create a dataset that includes only observations with certain values in an .obs column. In the example below, setting obs_filter makes the dataset iterate only over observations whose cell_type value in .obs is either "CD16-positive, CD56-dim natural killer cell, human" or "macrophage", across all AnnData objects.

select_by_cell_type = ("CD16-positive, CD56-dim natural killer cell, human", "macrophage")

with collection.mapped(obs_filter=("cell_type", select_by_cell_type)) as dataset_filter:
    print(dataset_filter.shape)
(142, 749)
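Conceptually, such a filter only restricts which rows the dataset indexes into. The filtered dataset above still has 749 variables, the same as the inner join, but only 142 observations. The sketch below illustrates the row selection with plain dictionaries standing in for .obs tables. filter_indices and the toy data are hypothetical, and this is not a description of how lamindb implements obs_filter internally.

```python
def filter_indices(stores_obs, key, values):
    """Return (store_idx, row_idx) pairs whose .obs[key] is one of `values`."""
    allowed = set(values)
    selected = []
    for store_idx, obs in enumerate(stores_obs):
        for row_idx, value in enumerate(obs[key]):
            if value in allowed:
                selected.append((store_idx, row_idx))
    return selected


# two hypothetical .obs tables, one per AnnData object
stores_obs = [
    {"cell_type": ["macrophage", "mast cell", "plasma cell"]},
    {"cell_type": ["CD16-positive, CD56-dim natural killer cell, human", "macrophage"]},
]
select_by_cell_type = ("CD16-positive, CD56-dim natural killer cell, human", "macrophage")
print(filter_indices(stores_obs, "cell_type", select_by_cell_type))
# [(0, 0), (1, 0), (1, 1)]
```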

Create a PyTorch DataLoader

Let us use a weighted sampler, which draws observations according to per-observation weights derived from their cell_type labels:

from torch.utils.data import DataLoader, WeightedRandomSampler

# the label key used for the weights does not have to be among the obs_keys passed to mapped()
sampler = WeightedRandomSampler(
    weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
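One common way to define such weights is inverse label frequency, which gives every class the same total sampling mass. The sketch below shows that scheme using only the standard library. We assume, without confirming it here, that get_label_weights computes something similar, so check the lamindb API reference before relying on the exact formula. inverse_frequency_weights is a hypothetical helper. random.choices stands in for WeightedRandomSampler, which samples with replacement by default.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each observation by 1 / (number of observations with its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# hypothetical, imbalanced labels: 8 observations of class 0, 2 of class 1
labels = [0] * 8 + [1] * 2
weights = inverse_frequency_weights(labels)

# every class now carries the same total weight (1.0 each)
per_class = Counter()
for label, w in zip(labels, weights):
    per_class[label] += w
print(dict(per_class))  # {0: 1.0, 1: 1.0}

# draw indices with replacement in proportion to the weights
rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=10_000)
share_class_1 = sum(labels[i] == 1 for i in drawn) / len(drawn)
print(round(share_class_1, 2))  # close to 0.5, although class 1 is 20% of the data
```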

We can now iterate through the data loader:

for batch in dataloader:
    pass
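Each iteration of the loop above receives one batch. The sketch below is a rough, stdlib-only picture of what happens per batch with batch_size=128:

1. The sampled indices are grouped into chunks.
2. Each observation is fetched via __getitem__.
3. The per-observation dicts are merged into one batch.

PyTorch's default collation stacks the values into tensors rather than lists. iter_batches, collate, and the toy samples are hypothetical. With DataLoader's default drop_last=False, the final batch can be smaller than 128.

```python
import math


def iter_batches(indices, batch_size):
    """Yield consecutive chunks of `indices`; the last chunk may be smaller."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


def collate(samples):
    """Turn a list of per-observation dicts into one dict of lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# hypothetical dataset of 300 observations, shaped like dataset[i]
samples = [{"X": [float(i)], "cell_type": i % 3} for i in range(300)]
order = list(range(len(samples)))  # a sampler would supply this order

batches = [collate([samples[i] for i in chunk]) for chunk in iter_batches(order, 128)]
print(len(batches), math.ceil(len(samples) / 128))  # 3 3
print([len(b["cell_type"]) for b in batches])       # [128, 128, 44]
```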

Close the connections in MappedCollection:

dataset.close()
In practice, use a context manager so that the connections are closed automatically:
with collection.mapped(obs_keys=["cell_type"]) as dataset:
    sampler = WeightedRandomSampler(
        weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
    )
    dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
    for batch in dataloader:
        pass
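The advantage of the with form is that the connections are released even if training raises an exception partway through. A toy class implementing the context-manager protocol (__enter__ and __exit__) shows the mechanism. ToyBackedStore is hypothetical and only tracks a flag instead of real file handles. For objects that provide close() but not the protocol, the standard library's contextlib.closing gives the same guarantee.

```python
class ToyBackedStore:
    """Hypothetical stand-in for an object holding open file handles."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()  # runs on normal exit and when an exception propagates
        return False  # do not swallow the exception


store = ToyBackedStore()
try:
    with store:
        raise RuntimeError("training step failed")
except RuntimeError:
    pass
print(store.closed)  # True: closed even though the block raised
```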