!which pip

# 02 - Mouse tuning
This notebook contains the methods we used to train the model. Below we summarize which parameters were helpful in generating a better integration.
- `n_layers`: should be 2–3
- `gene_dispersion`: `gene` proved to be the best
- `gene_likelihood`: `nb` preferred over `zinb`
- `dropout_rate`: smaller penalization keeps data points closer (0.005)
import scvi
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
from rich import print
from scib_metrics.benchmark import Benchmarker
from scvi.model.utils import mde
import warnings
from lightning_fabric.plugins.environments.slurm import PossibleUserWarning
# Silence noisy framework warnings: SLURM-detection notices from Lightning,
# plus generic user/future warnings emitted by the scvi/scanpy stack.
for _warning_cls in (PossibleUserWarning, UserWarning, FutureWarning):
    warnings.simplefilter(action='ignore', category=_warning_cls)
scvi.settings.seed = 42sc.set_figure_params(figsize=(10, 6))
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'adata = sc.read("../data/processed/01_mouse_reprocessed.h5ad")
adataadata.obs.experiment.unique().tolist()sc.pp.highly_variable_genes(
adata,
flavor="seurat_v3",
n_top_genes=3_000,
layer="counts",
batch_key="batch",
subset=True,
)1 2. Pimp my model: brute-force
In this method we brute-force the parameters. At each iteration we generate a PAGA graph, which we use to check whether the integration connects the correct cell types.
import itertools
import pandas as pd

# Reference connectivity matrix: a 1 at [row, col] marks an expected
# developmental transition (row -> col) between cell types; all else is 0.
ref_df = pd.DataFrame(0, index=adata.obs.ct.cat.categories, columns=adata.obs.ct.cat.categories)
ref_df.loc['Zygote', '2C'] = 1
ref_df.loc['2C', '4C'] = 1
ref_df.loc['4C', '8C'] = 1
ref_df.loc['16C', 'E3.25-ICM'] = 1
ref_df.loc['E3.25-ICM', 'E3.5-ICM'] = 1
# ref_df.loc['E3.5-ICM', 'E3.5-PrE'] = 1
# ref_df.loc['E3.5-ICM', 'E3.5-EPI'] = 1
# ref_df.loc['E3.5-ICM', 'E3.5-TE'] = 1
ref_df.loc['E3.5-EPI', 'E4.5-EPI'] = 1
ref_df.loc['E3.5-PrE', 'E4.5-PrE'] = 1
ref_df.loc['E3.5-TE', 'E4.5-TE'] = 1

# Register raw counts + batch covariate with scvi.
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")

# Hyper-parameter grid searched by the brute-force loop below.
params = [["nb", "zinb"], ["gene", "gene-batch"], [32, 64, 128], list(range(2, 6))]
tracked_params = []
for items in itertools.product(*params):
    # The grid's third slot (32/64/128) holds layer *widths* and the fourth
    # (2-5) holds *depths*; the original code unpacked them in the wrong
    # order, producing 32-128 "layers" with 2-5 hidden units each.
    gene_likelihood, dispersion, n_hidden, n_layers = items

    # SCVI: unsupervised integration with the current hyper-parameters.
    vae = scvi.model.SCVI(
        adata,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dispersion=dispersion,
        gene_likelihood=gene_likelihood,
    )
    vae.train(use_gpu=1, max_epochs=400, early_stopping=True)

    # SCANVI: refine the scVI latent space with the cell-type labels.
    lvae = scvi.model.SCANVI.from_scvi_model(
        vae,
        adata=adata,
        labels_key="ct",
        unlabeled_category="Unknown",
    )
    lvae.train(max_epochs=10)
    adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)

    try:
        sc.pp.neighbors(adata, use_rep='X_scANVI')
        sc.tl.diffmap(adata)
        sc.tl.paga(adata, groups='ct')
        sc.pl.paga(adata, color=['ct'], frameon=False, fontoutline=True)
        sc.tl.draw_graph(adata, init_pos='paga', n_jobs=10)
        df = pd.DataFrame(
            adata.uns['paga']['connectivities'].A,
            index=adata.obs.ct.cat.categories,
            columns=adata.obs.ct.cat.categories,
        )
        # Score = number of reference edges recovered by the PAGA graph,
        # with connectivities rounded to 0/1 (connections sit around 0.5).
        # The original computed np.sum(ref_df * ref_df) — a constant that
        # never looked at the PAGA result.
        n_ref = int(np.sum(df.round().values * ref_df.values))
        tracked_params.append(list(items) + [n_ref] + ["success"])
    except TypeError:
        # TypeError: sparse matrix length is ambiguous; use getnnz() or shape[0]
        # This error comes from PAGA, usually the integration was a fail
        tracked_params.append(list(items) + [0] + ["failed"])

# Column order mirrors the unpacking order of `items` above.
# (The original trailing backslash before a comment line was a SyntaxError.)
opt_params = pd.DataFrame(
    tracked_params,
    columns=['gene_likelihood', 'dispersion', 'n_hidden', 'n_layers', 'paga', 'run'],
).query('run == "success"')
# .query('n_hidden >= 64')
opt_params.to_csv("../results/02_mouse_integration/opt_params.csv")
opt_params

# 3. Pimp my model: ray tuner
import ray
import jax
import os
from ray import tune
from scvi import autotune

# Restrict training to GPUs 1 and 2; jax.devices() just echoes what is visible.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
jax.devices()

# Build a lean AnnData (counts matrix + minimal obs columns) for the tuner.
ref_tuner = sc.AnnData(adata.layers["counts"])
ref_tuner.obs = adata.obs[["total_counts", "technology", "batch"]].copy()

model_cls = scvi.model.SCVI
model_cls.setup_anndata(ref_tuner, batch_key="batch")

scvi_tuner = autotune.ModelTuner(model_cls)
scvi_tuner.info()

# Search space informed by the brute-force run above.
search_space = {
    "gene_likelihood": tune.choice(["nb", "zinb"]),
    "dispersion": tune.choice(["gene", "gene-batch"]),
    "n_hidden": tune.choice([128, 144, 256]),
    "n_layers": tune.choice([2, 3, 4, 5]),
    "lr": tune.loguniform(1e-4, 0.6),
}

ray.init(
    log_to_driver=False,
    num_cpus=10,
    num_gpus=2,
)

results = scvi_tuner.fit(
    ref_tuner,
    metric="validation_loss",
    search_space=search_space,
    num_samples=50,
    max_epochs=100,
)

print(results.model_kwargs)
print(results.train_kwargs)
print(results.metric)

import pandas as pd

# Rank trials by validation loss. Hyper-parameters are parsed out of the
# comma-separated trial directory name (the first token is the trial id,
# so it is dropped); trials that never reported a loss are skipped.
training = pd.DataFrame([
    [x.metrics['validation_loss']] + x.path.split(',')[1:]
    for x in results.results
    if 'validation_loss' in x.metrics
]).sort_values(by=0)
training.to_csv("../results/02_mouse_integration/tunning.csv")
display(training.head(10))