Train a machine learning model on a collection¶
Here, we iterate over the artifacts within a collection to train a machine learning model at scale.
import lamindb as ln
ln.track("Qr1kIHvK506r0002")
→ connected lamindb: testuser1/test-scrna
→ created Transform('Qr1kIHvK'), started new Run('bmHmhcgF') at 2024-12-20 15:05:42 UTC
→ notebook imports: lamindb==0.77.3 torch==2.5.1
Query our collection:
collection = ln.Collection.get(name="My versioned scRNA-seq collection", version="2")
collection.describe()
Collection
└── General
    ├── .uid = 'vU9vGhm9ozKWE66f0001'
    ├── .hash = 'luH-jPb6eJLsXvc1TWGpUg'
    ├── .version = '2'
    ├── .created_by = testuser1 (Test User1)
    ├── .created_at = 2024-12-20 15:05:15
    └── .transform = 'Standardize and append a dataset'
Create a map-style dataset¶
Let us create a map-style dataset using mapped(): a MappedCollection.
Under the hood, it performs a virtual join of the features of the underlying AnnData objects without loading the datasets into memory. You can either perform an inner join:
with collection.mapped(obs_keys=["cell_type"], join="inner") as dataset:
    print("#observations", dataset.shape[0])
    print("#variables:", len(dataset.var_joint))
#observations 1718
#variables: 749
Or an outer join:
dataset = collection.mapped(obs_keys=["cell_type"], join="outer")
print("#variables:", len(dataset.var_joint))
#variables: 36508
This is compatible with a PyTorch DataLoader because it implements __getitem__ over a list of backed AnnData objects.
For instance, the observation at index 5 in the collection can be accessed via:
dataset[5]
{'X': array([0., 0., 0., ..., 0., 0., 0.], dtype=float32),
'_store_idx': 0,
'cell_type': 29}
The labels are encoded into integers:
dataset.encoders
{'cell_type': {'B cell, CD19-positive': 0,
'CD14-positive monocyte': 1,
'CD16-negative, CD56-bright natural killer cell, human': 2,
'CD16-positive, CD56-dim natural killer cell, human': 3,
'CD34-positive, CD56-positive, CD117-positive common innate lymphoid precursor, human': 4,
'CD4-positive helper T cell': 5,
'CD4-positive, CD25-positive, alpha-beta regulatory T cell': 6,
'CD56-positive, CD161-positive immature natural killer cell, human': 7,
'CD8-positive, alpha-beta cytotoxic T cell': 8,
'CD8-positive, alpha-beta memory T cell': 9,
'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 10,
'T follicular helper cell': 11,
'alpha-beta T cell': 12,
'alveolar macrophage': 13,
'animal cell': 14,
'classical monocyte': 15,
'conventional dendritic cell': 16,
'dendritic cell': 17,
'dendritic cell, human': 18,
'effector memory CD4-positive, alpha-beta T cell': 19,
'effector memory CD45RA-positive, alpha-beta T cell, terminally differentiated': 20,
'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 21,
'gamma-delta T cell': 22,
'germinal center B cell': 23,
'group 3 innate lymphoid cell': 24,
'lymphocyte': 25,
'macrophage': 26,
'mast cell': 27,
'megakaryocyte': 28,
'memory B cell': 29,
'mucosal invariant T cell': 30,
'naive B cell': 31,
'naive thymus-derived CD4-positive, alpha-beta T cell': 32,
'naive thymus-derived CD8-positive, alpha-beta T cell': 33,
'non-classical monocyte': 34,
'plasma cell': 35,
'plasmablast': 36,
'plasmacytoid dendritic cell': 37,
'progenitor cell': 38,
'regulatory T cell': 39}}
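To map model predictions back to cell type names, you can invert this mapping. A minimal sketch (decoder is our name for the inverted dictionary, not part of the MappedCollection API):

# invert the integer encoding shown above to recover cell type names
decoder = {code: label for label, code in dataset.encoders["cell_type"].items()}
decoder[29]  # 'memory B cell', the label of the observation at index 5 shown above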
It is also possible to create a dataset by selecting only observations with certain values of an .obs column. Setting obs_filter in the example below makes the dataset iterate only over observations having CD16-positive, CD56-dim natural killer cell, human or macrophage in the .obs column cell_type across all AnnData objects.
select_by_cell_type = ("CD16-positive, CD56-dim natural killer cell, human", "macrophage")
with collection.mapped(obs_filter=("cell_type", select_by_cell_type)) as dataset_filter:
    print(dataset_filter.shape)
(139, 749)
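obs_filter composes with the other mapped() arguments shown above. A small sketch, assuming the same collection and selection tuple:

# request encoded labels and filter observations at the same time
with collection.mapped(
    obs_keys=["cell_type"], obs_filter=("cell_type", select_by_cell_type)
) as dataset_filter:
    print(dataset_filter[0]["cell_type"])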
Create a PyTorch DataLoader¶
Let us use a weighted sampler:
from torch.utils.data import DataLoader, WeightedRandomSampler
# the label_key used for the weights doesn't have to be among the obs_keys passed at init
sampler = WeightedRandomSampler(
    weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
We can now iterate through the data loader:
for batch in dataloader:
    pass
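From here, a training step is standard PyTorch, since each batch is a dictionary of tensors. A minimal sketch of a classifier on the encoded labels (the two-layer architecture, hidden size, and learning rate are illustrative assumptions, not part of this tutorial):

import torch

# input size = number of variables in the outer join, output size = number of cell types
model = torch.nn.Sequential(
    torch.nn.Linear(len(dataset.var_joint), 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, len(dataset.encoders["cell_type"])),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for batch in dataloader:
    x = batch["X"]  # float32 expression values, shape (batch_size, n_vars)
    y = batch["cell_type"]  # integer-encoded labels
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()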
Close the connections in MappedCollection:
dataset.close()
In practice, use a context manager so that the connections are closed automatically:
with collection.mapped(obs_keys=["cell_type"]) as dataset:
    sampler = WeightedRandomSampler(
        weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
    )
    dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
    for batch in dataloader:
        pass
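If you want multi-process loading with num_workers > 0, mapped() accepts a parallel option (check the MappedCollection signature in your lamindb version). A sketch under that assumption:

# parallel=True prepares the backed connections for use across DataLoader workers
with collection.mapped(obs_keys=["cell_type"], parallel=True) as dataset:
    dataloader = DataLoader(dataset, batch_size=128, num_workers=2)
    for batch in dataloader:
        pass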