02 - mouse tuning

Author

Martin Proks

Published

July 7, 2023

This notebook contains multiple methods on how we trained the model. We summarize below which params were helpful in generating better integration.

!which pip
import scvi
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt


from rich import print
from scib_metrics.benchmark import Benchmarker
from scvi.model.utils import mde


import warnings
from lightning_fabric.plugins.environments.slurm import PossibleUserWarning
warnings.simplefilter(action='ignore', category=PossibleUserWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

# Reproducibility: pin the scvi-tools global random seed.
scvi.settings.seed = 42
sc.set_figure_params(figsize=(10, 6))

# Notebook display tweaks: white figure background, retina-resolution inline plots.
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'
# Load the reprocessed mouse dataset produced by notebook 01.
adata = sc.read("../data/processed/01_mouse_reprocessed.h5ad")
adata
adata.obs.experiment.unique().tolist()
# Select the 3,000 most highly variable genes (Seurat v3 flavour, computed on
# the raw "counts" layer, batch-aware) and subset the AnnData to them in place.
sc.pp.highly_variable_genes(
    adata,
    flavor="seurat_v3",
    n_top_genes=3_000,
    layer="counts",
    batch_key="batch",
    subset=True,
)

1 2. Pimp my model: brute-force

In this method we brute-force params. At each iteration we generate PAGA graph which we use to check if the integration connects the correct cell types.

import itertools
import pandas as pd

# Reference adjacency matrix of expected developmental transitions:
# ref_df.loc[parent, child] == 1 marks an edge the integrated PAGA graph
# should recover. All other entries stay 0.
_cts = adata.obs.ct.cat.categories
ref_df = pd.DataFrame(0, index=_cts, columns=_cts)

_expected_edges = [
    ("Zygote", "2C"),
    ("2C", "4C"),
    ("4C", "8C"),
    # NOTE(review): no 8C -> 16C edge is listed — confirm this is intentional.
    ("16C", "E3.25-ICM"),
    ("E3.25-ICM", "E3.5-ICM"),
    # E3.5-ICM -> E3.5-{PrE, EPI, TE} edges were deliberately left disabled.
    ("E3.5-EPI", "E4.5-EPI"),
    ("E3.5-PrE", "E4.5-PrE"),
    ("E3.5-TE", "E4.5-TE"),
]
for _parent, _child in _expected_edges:
    ref_df.loc[_parent, _child] = 1
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="batch")

# Hyper-parameter grid: likelihood, dispersion, hidden width, number of layers.
# BUG FIX: the original unpacked the width grid [32, 64, 128] into `n_layers`
# and range(2, 6) into `n_hidden` — swapped relative to their intent. Widths
# now go to n_hidden and depths to n_layers.
params = [["nb", "zinb"], ["gene", "gene-batch"], [32, 64, 128], list(range(2, 6))]
tracked_params = []

for items in itertools.product(*params):
    gene_likelihood, dispersion, n_hidden, n_layers = items

    # SCVI: unsupervised model for this parameter combination.
    vae = scvi.model.SCVI(
        adata,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dispersion=dispersion,
        gene_likelihood=gene_likelihood,
    )
    vae.train(use_gpu=1, max_epochs=400, early_stopping=True)

    # SCANVI: refine the latent space with the cell-type labels.
    lvae = scvi.model.SCANVI.from_scvi_model(
        vae,
        adata=adata,
        labels_key="ct",
        unlabeled_category="Unknown",
    )
    lvae.train(max_epochs=10)
    adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)

    try:
        sc.pp.neighbors(adata, use_rep='X_scANVI')
        sc.tl.diffmap(adata)
        sc.tl.paga(adata, groups='ct')
        sc.pl.paga(adata, color=['ct'], frameon=False, fontoutline=True)
        sc.tl.draw_graph(adata, init_pos='paga', n_jobs=10)
        df = pd.DataFrame(
            adata.uns['paga']['connectivities'].A,
            index=adata.obs.ct.cat.categories,
            columns=adata.obs.ct.cat.categories,
        )
        # Binarize the PAGA connectivities (real edges hover around 0.5) so
        # they can be compared against the 0/1 reference adjacency.
        df = df.round()

        # Count how many expected developmental edges the integration
        # recovered. BUG FIX: the original computed ref_df * ref_df, a
        # constant that never looked at the observed graph `df`.
        n_ref = np.sum(df.values * ref_df.values)

        tracked_params.append(list(items) + [n_ref] + ["success"])
    except TypeError:
        # TypeError: sparse matrix length is ambiguous; use getnnz() or shape[0]
        # This error comes from PAGA, usually the integration was a fail
        tracked_params.append(list(items) + [0] + ["failed"])
# Tabulate the grid-search runs and keep only the successful ones.
# BUG FIX: each row's third element comes from the [32, 64, 128] width grid
# and the fourth from range(2, 6), so the original column labels had
# n_layers/n_hidden swapped relative to the data. Also replaced the fragile
# backslash continuations with parentheses.
opt_params = (
    pd.DataFrame(
        tracked_params,
        columns=['gene_likelihood', 'dispersion', 'n_hidden', 'n_layers', 'paga', 'run'],
    )
    .query('run == "success"')
    # .query('n_hidden >= 64')  # optional extra filter, left disabled
)
opt_params.to_csv("../results/02_mouse_integration/opt_params.csv")
opt_params

2 3. Pimp my model: ray tuner

import ray
import jax
import os

from ray import tune
from scvi import autotune

# Restrict the sweep to two GPUs and confirm JAX can see them.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'
jax.devices()
# Lightweight AnnData for tuning: raw counts in .X plus only the obs columns
# the tuner needs (total_counts, technology, batch).
ref_tuner = sc.AnnData(adata.layers["counts"])
ref_tuner.obs = adata.obs[["total_counts", "technology", "batch"]].copy()

model_cls = scvi.model.SCVI
model_cls.setup_anndata(ref_tuner, 
                        batch_key="batch")

scvi_tuner = autotune.ModelTuner(model_cls)
# List the tunable hyper-parameters exposed by the model class.
scvi_tuner.info()
# Ray Tune search space; the learning rate is sampled log-uniformly.
search_space = {
    "gene_likelihood": tune.choice(["nb", "zinb"]),
    "dispersion": tune.choice(["gene", "gene-batch"]),
    "n_hidden": tune.choice([128, 144, 256]),
    "n_layers": tune.choice([2, 3, 4, 5]),
    "lr": tune.loguniform(1e-4, 0.6),
}
# Start a local Ray cluster with a fixed CPU/GPU budget.
ray.init(
    log_to_driver=False,
    num_cpus=10,
    num_gpus=2,
)
# Run 50 trials, each up to 100 epochs, tracking the validation loss.
results = scvi_tuner.fit(
    ref_tuner,
    metric="validation_loss",
    search_space=search_space,
    num_samples=50,
    max_epochs=100,
)
print(results.model_kwargs)
print(results.train_kwargs)
print(results.metric)
import pandas as pd

# Summarize the finished trials: column 0 holds the validation loss, the
# remaining columns are the hyper-parameter settings parsed from each
# trial's path. Sorting by column 0 puts the best run first.
_rows = []
for trial in results.results:
    if 'validation_loss' not in trial.metrics:
        continue
    _rows.append([trial.metrics['validation_loss']] + trial.path.split(',')[1:])

training = pd.DataFrame(_rows).sort_values(by=0)

training.to_csv("../results/02_mouse_integration/tunning.csv")
display(training.head(10))