11 - XGBoost vs scANVI

Author

Martin Proks, Nazmus Salehin

Published

May 1, 2025

In this notebook we compare the classification accuracy of XGBoost and scANVI on our mouse and human integrated datasets, testing how gracefully each classifier degrades when the most informative (highly variable) genes are withheld at prediction time.

!which pip
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/bin/pip
%matplotlib inline

import scvi
import xgboost as xgb
import pandas as pd
import numpy as np
import seaborn as sns
import scanpy as sc
from tqdm import tqdm
import matplotlib.pyplot as plt

from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)

scvi.settings.seed = 0

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/lib/python3.10/site-packages/scvi/_settings.py:63: UserWarning: Since v1.0.0, scvi-tools no longer uses a random seed by default. Run `scvi.settings.seed = 0` to reproduce results from previous versions.
  self.seed = seed
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/lib/python3.10/site-packages/scvi/_settings.py:70: UserWarning: Setting `dl_pin_memory_gpu_training` is deprecated in v1.0 and will be removed in v1.1. Please pass in `pin_memory` to the data loaders instead.
  self.dl_pin_memory_gpu_training = (
[rank: 0] Global seed set to 0
import anndata
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score


def cripple_classifiers(adata: anndata.AnnData, lvae: scvi.model.SCANVI, xgb_clf: xgb.XGBClassifier, groupby: str) -> dict:
    """
    To demonstrate that scANVI is the more robust predictor thanks to its
    ability to use the full list of HVGs, we sort features (genes) by their
    normalized dispersion and remove the top-ranked ones during prediction.
    We expect XGBoost performance to drop sharply once these informative
    features are no longer available to make a correct prediction.
    """

    denoised = lvae.get_normalized_expression(return_mean=True)
    features = adata.var.sort_values(by='dispersions_norm', ascending=False).index

    metrics = {
        'scANVI': { m: [] for m in ['accuracy', 'balanced', 'f1_micro', 'f1_macro'] },
        'XGBoost': { m: [] for m in ['accuracy', 'balanced', 'f1_micro', 'f1_macro'] }
    }
    
    for i in tqdm([0, 10, 20, 50, 100, 200, 500]):
        less_features = features[i:]

        # XGBoost prediction: zero out the removed top-i genes so the
        # classifier cannot use them (mirrors the scANVI subset below)
        xgb_denoised = denoised.copy()
        if i != 0:
            xgb_denoised.loc[:, features[:i]] = 0
        xgb_pred = xgb_clf.predict(xgb_denoised)

        # scANVI query: map the reduced gene set back onto the reference
        # model (scArches-style surgery; missing genes are zero-padded)
        subset = adata[:, less_features].copy()
        subset.obs = subset.obs[['batch']].copy()
        scvi.model.SCANVI.prepare_query_anndata(subset, lvae)
        lvae_q = scvi.model.SCANVI.load_query_data(subset, lvae)
        lvae_q.train(max_epochs=100, plan_kwargs=dict(weight_decay=0.0), check_val_every_n_epoch=10, enable_progress_bar=False)
        scANVI_pred = lvae_q.predict()

        metrics['scANVI']['accuracy'].append(accuracy_score(adata.obs[groupby], scANVI_pred))
        metrics['XGBoost']['accuracy'].append(accuracy_score(adata.obs[groupby].cat.codes.values, xgb_pred))
        
        metrics['scANVI']['balanced'].append(balanced_accuracy_score(adata.obs[groupby], scANVI_pred))
        metrics['XGBoost']['balanced'].append(balanced_accuracy_score(adata.obs[groupby].cat.codes.values, xgb_pred))

        metrics['scANVI']['f1_micro'].append(f1_score(adata.obs[groupby], scANVI_pred, average="micro"))
        metrics['XGBoost']['f1_micro'].append(f1_score(adata.obs[groupby].cat.codes.values, xgb_pred, average="micro"))

        metrics['scANVI']['f1_macro'].append(f1_score(adata.obs[groupby], scANVI_pred, average="macro"))
        metrics['XGBoost']['f1_macro'].append(f1_score(adata.obs[groupby].cat.codes.values, xgb_pred, average="macro"))
    
    return metrics
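
As a side note (not part of the original analysis), the nested dict returned above is handy to fill incrementally but awkward to plot directly. Below is a minimal sketch of how it can be reshaped into a long-format DataFrame, assuming the same hard-coded dropout steps as the loop above; tidy_metrics is a hypothetical helper, not used elsewhere in this notebook.

# Hypothetical helper: tidy the nested metrics dict for plotting/export
def tidy_metrics(metrics: dict) -> pd.DataFrame:
    dropouts = [0, 10, 20, 50, 100, 200, 500]  # must match the loop above
    frames = []
    for clf, scores in metrics.items():
        df = (pd.DataFrame(scores, index=dropouts)   # rows: dropout levels
                .rename_axis('n_removed')
                .reset_index()
                .melt(id_vars='n_removed', var_name='metric', value_name='score'))
        df['clf'] = clf
        frames.append(df)
    return pd.concat(frames, ignore_index=True)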

1 Mouse

# base dataset
mouse = sc.read("../results/02_mouse_integration/scanvi_ns_15/adata.h5ad")

# XGBoost
mouse_xg_scVI = xgb.XGBClassifier()
mouse_xg_scVI.load_model("../results/02_mouse_integration/05_scVI_xgboost.json")

# scANVI
mouse_scANVI = scvi.model.SCANVI.load("../results/02_mouse_integration/scanvi_ns_15/")
INFO     File ../results/02_mouse_integration/scanvi_ns_15/model.pt already downloaded                             
mouse_stats = cripple_classifiers(mouse, mouse_scANVI, mouse_xg_scVI, groupby='ct')
  0%|          | 0/7 [00:00<?, ?it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
`Trainer.fit` stopped: `max_epochs=100` reached.
[... the same Lightning messages repeat for each of the 7 feature-removal steps ...]
100%|██████████| 7/7 [02:19<00:00, 19.90s/it]
INFO     Found 100.0% reference vars in query data.                                                                
INFO     Training for 100 epochs.                                                                                  
INFO     Found 99.66666666666667% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 99.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 98.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 96.66666666666667% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 93.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 83.33333333333334% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  

2 Human

2.1 XGBoost [scANVI]

%run ../scripts/helpers.py
# base dataset
human = sc.read("../results/02_human_integration/05_scanvi_ns15/adata.h5ad")
vae = scvi.model.SCVI.load("../results/02_human_integration/scvi/")
vae_df = pd.DataFrame(vae.get_normalized_expression(return_mean=True))
vae_df['target'] = human.obs.C_scANVI_nsamples

X_train, y_train, X_test, y_test = train_test_split_by_group(vae_df)
vae_xgboost = train_xgboost(vae_df, X_train, y_train, X_test, y_test)
vae_xgboost.save_model("../results/02_human_integration/11_scVI_xgboost.json")
INFO     File ../results/02_human_integration/scvi/model.pt already downloaded                                     
[0] validation_0-merror:0.06048 validation_0-mlogloss:1.25520   validation_1-merror:0.22081 validation_1-mlogloss:1.52118
[1] validation_0-merror:0.03240 validation_0-mlogloss:0.89775   validation_1-merror:0.18471 validation_1-mlogloss:1.22860
[2] validation_0-merror:0.01674 validation_0-mlogloss:0.66799   validation_1-merror:0.18471 validation_1-mlogloss:1.04108
[3] validation_0-merror:0.01242 validation_0-mlogloss:0.50784   validation_1-merror:0.16985 validation_1-mlogloss:0.89968
[4] validation_0-merror:0.00756 validation_0-mlogloss:0.39042   validation_1-merror:0.16561 validation_1-mlogloss:0.79915
[5] validation_0-merror:0.00486 validation_0-mlogloss:0.30346   validation_1-merror:0.16348 validation_1-mlogloss:0.71425
[6] validation_0-merror:0.00162 validation_0-mlogloss:0.23622   validation_1-merror:0.15711 validation_1-mlogloss:0.64952
[7] validation_0-merror:0.00108 validation_0-mlogloss:0.18595   validation_1-merror:0.16136 validation_1-mlogloss:0.60371
[8] validation_0-merror:0.00000 validation_0-mlogloss:0.14786   validation_1-merror:0.15287 validation_1-mlogloss:0.56304
[9] validation_0-merror:0.00000 validation_0-mlogloss:0.11881   validation_1-merror:0.15287 validation_1-mlogloss:0.53307
[10]    validation_0-merror:0.00000 validation_0-mlogloss:0.09605   validation_1-merror:0.14650 validation_1-mlogloss:0.50828
[11]    validation_0-merror:0.00000 validation_0-mlogloss:0.07801   validation_1-merror:0.14437 validation_1-mlogloss:0.48943
[12]    validation_0-merror:0.00000 validation_0-mlogloss:0.06405   validation_1-merror:0.14650 validation_1-mlogloss:0.47413
[13]    validation_0-merror:0.00000 validation_0-mlogloss:0.05343   validation_1-merror:0.14225 validation_1-mlogloss:0.46252
[14]    validation_0-merror:0.00000 validation_0-mlogloss:0.04508   validation_1-merror:0.14013 validation_1-mlogloss:0.45311
[15]    validation_0-merror:0.00000 validation_0-mlogloss:0.03852   validation_1-merror:0.13376 validation_1-mlogloss:0.44334
[16]    validation_0-merror:0.00000 validation_0-mlogloss:0.03314   validation_1-merror:0.13588 validation_1-mlogloss:0.43681
[17]    validation_0-merror:0.00000 validation_0-mlogloss:0.02886   validation_1-merror:0.13376 validation_1-mlogloss:0.43222
[18]    validation_0-merror:0.00000 validation_0-mlogloss:0.02532   validation_1-merror:0.13376 validation_1-mlogloss:0.42889
[19]    validation_0-merror:0.00000 validation_0-mlogloss:0.02246   validation_1-merror:0.13376 validation_1-mlogloss:0.42132
[20]    validation_0-merror:0.00000 validation_0-mlogloss:0.02018   validation_1-merror:0.13163 validation_1-mlogloss:0.42073
[21]    validation_0-merror:0.00000 validation_0-mlogloss:0.01824   validation_1-merror:0.13800 validation_1-mlogloss:0.41866
[22]    validation_0-merror:0.00000 validation_0-mlogloss:0.01658   validation_1-merror:0.13163 validation_1-mlogloss:0.41665
[23]    validation_0-merror:0.00000 validation_0-mlogloss:0.01526   validation_1-merror:0.13376 validation_1-mlogloss:0.41581
[24]    validation_0-merror:0.00000 validation_0-mlogloss:0.01408   validation_1-merror:0.13163 validation_1-mlogloss:0.41359
[25]    validation_0-merror:0.00000 validation_0-mlogloss:0.01306   validation_1-merror:0.12739 validation_1-mlogloss:0.41323
[26]    validation_0-merror:0.00000 validation_0-mlogloss:0.01218   validation_1-merror:0.13376 validation_1-mlogloss:0.41286
[27]    validation_0-merror:0.00000 validation_0-mlogloss:0.01146   validation_1-merror:0.13163 validation_1-mlogloss:0.41128
[28]    validation_0-merror:0.00000 validation_0-mlogloss:0.01085   validation_1-merror:0.13163 validation_1-mlogloss:0.41166
[29]    validation_0-merror:0.00000 validation_0-mlogloss:0.01030   validation_1-merror:0.13163 validation_1-mlogloss:0.41150
[30]    validation_0-merror:0.00000 validation_0-mlogloss:0.00985   validation_1-merror:0.13163 validation_1-mlogloss:0.41179
[31]    validation_0-merror:0.00000 validation_0-mlogloss:0.00939   validation_1-merror:0.13376 validation_1-mlogloss:0.41202
[32]    validation_0-merror:0.00000 validation_0-mlogloss:0.00902   validation_1-merror:0.13163 validation_1-mlogloss:0.41252
[33]    validation_0-merror:0.00000 validation_0-mlogloss:0.00867   validation_1-merror:0.13376 validation_1-mlogloss:0.41329
[34]    validation_0-merror:0.00000 validation_0-mlogloss:0.00837   validation_1-merror:0.13376 validation_1-mlogloss:0.41391
[35]    validation_0-merror:0.00000 validation_0-mlogloss:0.00811   validation_1-merror:0.13588 validation_1-mlogloss:0.41318
[36]    validation_0-merror:0.00000 validation_0-mlogloss:0.00786   validation_1-merror:0.13588 validation_1-mlogloss:0.41324
[37]    validation_0-merror:0.00000 validation_0-mlogloss:0.00760   validation_1-merror:0.13588 validation_1-mlogloss:0.41341
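
`train_test_split_by_group` and `train_xgboost` come from `../scripts/helpers.py` (loaded via `%run` above), which is not reproduced in this notebook. For orientation only, here is a hypothetical sketch of what such helpers could look like, inferred from their call sites and from the eval log above (merror/mlogloss on a train and a validation set, with training halting a few rounds past the best iteration); the stratified split, the `target` column handling, and the early-stopping budget are all assumptions, not the original implementation.

from sklearn.model_selection import train_test_split

def train_test_split_by_group(df: pd.DataFrame, test_size: float = 0.2):
    # assumed: label-stratified split on the 'target' column
    X = df.drop(columns='target')
    y = df['target'].astype('category').cat.codes  # XGBoost wants integer labels
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=0)
    return X_train, y_train, X_test, y_test

def train_xgboost(df, X_train, y_train, X_test, y_test):
    # df is kept only to mirror the original call signature
    clf = xgb.XGBClassifier(
        objective='multi:softmax',
        eval_metric=['merror', 'mlogloss'],  # matches the log above
        early_stopping_rounds=10,            # assumption
    )
    clf.fit(X_train, y_train, eval_set=[(X_train, y_train), (X_test, y_test)])
    return clf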

2.2 Comparison

# XGBoost
human_xg_scVI = xgb.XGBClassifier()
human_xg_scVI.load_model("../results/02_human_integration/11_scVI_xgboost.json")

# scANVI
human_scANVI = scvi.model.SCANVI.load("../results/02_human_integration/05_scanvi_ns15/")
INFO     File ../results/02_human_integration/05_scanvi_ns15/model.pt already downloaded                           
human_stats = cripple_classifiers(human, human_scANVI, human_xg_scVI, groupby='C_scANVI_nsamples')
  0%|          | 0/7 [00:00<?, ?it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]
`Trainer.fit` stopped: `max_epochs=100` reached.
[... the same Lightning messages repeat for each of the 7 feature-removal steps ...]
100%|██████████| 7/7 [02:44<00:00, 23.53s/it]
INFO     Found 100.0% reference vars in query data.                                                                
INFO     Training for 100 epochs.                                                                                  
INFO     Found 99.66666666666667% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 99.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 98.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 96.66666666666667% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 93.33333333333333% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
INFO     Found 83.33333333333334% reference vars in query data.                                                    
INFO     Training for 100 epochs.                                                                                  
fig, ax = plt.subplots(2, 2, figsize=[10, 6], sharey=True, sharex=True)

pd.DataFrame.from_dict(mouse_stats['XGBoost']).plot.line(marker='x', ylim=(0,1), title='Mouse [XGBoost]', legend=None, ax=ax[0, 0])
pd.DataFrame.from_dict(mouse_stats['scANVI']).plot.line(marker='x', ylim=(0,1), title='Mouse [scANVI]', legend=None, ax=ax[0, 1])
pd.DataFrame.from_dict(human_stats['XGBoost']).plot.line(marker='x', ylim=(0,1), title='Human [XGBoost]', legend=None, ax=ax[1, 0])
pd.DataFrame.from_dict(human_stats['scANVI']).plot.line(marker='x', ylim=(0,1), title='Human [scANVI]', legend=None, ax=ax[1, 1])

dropouts = [0, 10, 20, 50, 100, 200, 500]
for i in range(2):
    # pin tick positions before relabelling so the labels cannot drift
    ax[1, i].set_xticks(range(len(dropouts)))
    ax[1, i].set_xticklabels(dropouts)
    ax[0, i].axhline(0.5, c='r', ls='--')
    ax[1, i].axhline(0.5, c='r', ls='--')

ax[0, 1].legend(('Accuracy','Bal. accuracy', 'F1 (micro)', 'F1 (macro)'))

fig.supxlabel('Number of dropouts')
fig.supylabel('Score')
fig.tight_layout()
fig.savefig('../figures/xgboost_vs_scanvi_v1.svg')

fig, ax = plt.subplots(1, 4, figsize=[14, 3.5], sharey=True, sharex=True)

pd.DataFrame.from_dict(mouse_stats['XGBoost']).plot.line(marker='x', ylim=(-0.1,1), title='Mouse [XGBoost]', legend=None, ax=ax[0])
pd.DataFrame.from_dict(mouse_stats['scANVI']).plot.line(marker='x', ylim=(-0.1,1), title='Mouse [scANVI]', legend=None, ax=ax[1])
pd.DataFrame.from_dict(human_stats['XGBoost']).plot.line(marker='x', ylim=(-0.1,1), title='Human [XGBoost]', legend=None, ax=ax[2])
pd.DataFrame.from_dict(human_stats['scANVI']).plot.line(marker='x', ylim=(-0.1,1), title='Human [scANVI]', legend=None, ax=ax[3])

for i in range(4):
    ax[i].set_xticks(range(len(dropouts)))
    ax[i].set_xticklabels(dropouts)
    ax[i].axhline(0.5, c='r', ls='--')

# ax[3].legend(('Accuracy','Bal. accuracy', 'F1 (micro)', 'F1 (macro)'), ncol=4)

fig.supxlabel('Number of dropouts')
fig.supylabel('Score')
fig.tight_layout()
fig.savefig('../figures/xgboost_vs_scanvi_v2.svg')

pd.concat([
    pd.DataFrame.from_dict(mouse_stats['XGBoost']).assign(clf = 'XGBoost').assign(species = 'mouse'),
    pd.DataFrame.from_dict(mouse_stats['scANVI']).assign(clf = 'scANVI').assign(species = 'mouse'),
    pd.DataFrame.from_dict(human_stats['XGBoost']).assign(clf = 'XGBoost').assign(species = 'human'),
    pd.DataFrame.from_dict(human_stats['scANVI']).assign(clf = 'scANVI').assign(species = 'human')
]).to_excel('../results/suppl-tab-3.xlsx')
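
As a quick sanity check (a sketch, not in the original notebook), the per-metric gap between the two classifiers at each dropout level can be read straight off these frames; positive values favour scANVI.

# Hypothetical follow-up: scANVI minus XGBoost scores for the mouse run
gap = (pd.DataFrame.from_dict(mouse_stats['scANVI'])
       - pd.DataFrame.from_dict(mouse_stats['XGBoost']))
gap.index = [0, 10, 20, 50, 100, 200, 500]  # number of removed genes
print(gap.round(3))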