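In this notebook we build several classifiers that predict cell type (`ct`): scANVI's built-in classifier (the saved model and one retrained with `n_samples_per_label=15`) and XGBoost classifiers trained on the expression reconstructed by scVI, scANVI and scGEN.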
!which pip
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/bin/pip
%matplotlib inline
import os
# Restrict CUDA to GPUs 1 and 2; set this before scvi (and therefore torch) is imported
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import warnings
from typing import Tuple

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scvi
import scgen
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning

# Silence these warning categories for the rest of the notebook
warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)
warnings.simplefilter('ignore', category=FutureWarning)
warnings.simplefilter('ignore', category=UserWarning)
scvi.settings.seed = 0
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/lib/python3.10/site-packages/scvi/_settings.py:63: UserWarning: Since v1.0.0, scvi-tools no longer uses a random seed by default. Run `scvi.settings.seed = 0` to reproduce results from previous versions.
self.seed = seed
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/lib/python3.10/site-packages/scvi/_settings.py:70: UserWarning: Setting `dl_pin_memory_gpu_training` is deprecated in v1.0 and will be removed in v1.1. Please pass in `pin_memory` to the data loaders instead.
self.dl_pin_memory_gpu_training = (
/projects/dan1/data/Brickman/conda/envs/scvi-1.0.0/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
[rank: 0] Global seed set to 0
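%run ../scripts/helpers.py

from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold, cross_val_score

def prediction_stats(label: str, model, y_test, y_pred, X, Y):
    """Return [label, accuracy, balanced accuracy, micro-F1, macro-F1, CV accuracy].

    The CV accuracy is only computed when a scikit-learn compatible `model`
    is passed; otherwise it stays NaN.
    """
    stats = [
        label,
        accuracy_score(y_test, y_pred),
        balanced_accuracy_score(y_test, y_pred),
        f1_score(y_test, y_pred, average="micro"),
        f1_score(y_test, y_pred, average="macro"),
        np.nan,
    ]
    if model is None:
        return stats
    # 10-fold stratified cross-validation on the full feature matrix X and labels Y
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    scores = cross_val_score(model, X, Y, cv=kfold, n_jobs=8)
    stats[-1] = scores.mean()
    return stats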
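`prediction_stats` reports both plain and class-balanced scores. For single-label multiclass predictions, micro-averaged F1 always equals accuracy (every false positive for one class is a false negative for another), which is why the `accuracy` and `f1_micro` columns in the results table coincide; balanced accuracy and macro-F1 instead give every cell type the same weight. The pure-Python sketch below re-derives the four scores on made-up labels to make the difference concrete; it is an illustration only, not a replacement for `sklearn.metrics`, and `classification_scores` is a hypothetical helper.

```python
from collections import Counter


def classification_scores(y_true, y_pred):
    """Return (accuracy, balanced accuracy, micro-F1, macro-F1) for two label lists."""
    classes = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct calls per class
    n_pred = Counter(y_pred)  # TP + FP per class
    n_true = Counter(y_true)  # TP + FN per class

    accuracy = sum(tp.values()) / len(y_true)

    # Balanced accuracy = mean per-class recall over classes present in y_true
    recalls = [tp[c] / n_true[c] for c in classes if n_true[c] > 0]
    balanced_accuracy = sum(recalls) / len(recalls)

    # Micro-F1 pools TP, FP and FN over all classes before computing F1
    total_tp = sum(tp.values())
    total_fp = sum(n_pred[c] - tp[c] for c in classes)
    total_fn = sum(n_true[c] - tp[c] for c in classes)
    micro_f1 = 2 * total_tp / (2 * total_tp + total_fp + total_fn)

    # Macro-F1 weights every class equally: F1_c = 2*TP_c / (n_pred_c + n_true_c)
    per_class = [2 * tp[c] / (n_pred[c] + n_true[c]) for c in classes]
    macro_f1 = sum(per_class) / len(per_class)

    return accuracy, balanced_accuracy, micro_f1, macro_f1


# Made-up labels: a large "16C" class and smaller "8C" and "4C" classes that are partly missed
toy_true = ["16C"] * 8 + ["8C"] * 4 + ["4C"] * 2
toy_pred = ["16C"] * 8 + ["8C"] * 3 + ["16C"] + ["16C", "4C"]
print(classification_scores(toy_true, toy_pred))
```

Here accuracy and micro-F1 are both 12/14 (about 0.857), while balanced accuracy drops to 0.75 and macro-F1 to about 0.80 because the smaller classes are partly misclassified.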
stats = []

lvae = scvi.model.SCANVI.load("../results/02_mouse_integration/scanvi/")
INFO File ../results/02_mouse_integration/scanvi/model.pt already downloaded
# scANVI's built-in classifier, scored on every cell in lvae.adata (no held-out split)
y_test = lvae.adata.obs.ct.values
y_pred = lvae.predict(lvae.adata)
stats.append(prediction_stats("scANVI", None, y_test, y_pred, None, None))

scANVI (`n_samples_per_label=15`)

vae = scvi.model.SCVI.load("../results/02_mouse_integration/scvi/")
lvae = scvi.model.SCANVI.from_scvi_model(vae, labels_key="ct", unlabeled_category="Unknown")
lvae.train(max_epochs=20, n_samples_per_label=15)
lvae.save("../results/02_mouse_integration/scanvi_ns_15/", overwrite=True, save_anndata=True)
INFO File ../results/02_mouse_integration/scvi/model.pt already downloaded
INFO Training for 20 epochs.
Epoch 20/20: 100%|███████████████████████████| 20/20 [00:06<00:00, 3.15it/s, v_num=1, train_loss_step=5.75e+3, train_loss_epoch=5.12e+3]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2]
`Trainer.fit` stopped: `max_epochs=20` reached.
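`n_samples_per_label=15` changes how scANVI sees the labelled cells during training: according to the scvi-tools documentation it is the number of subsamples drawn for each label class per epoch, so small cell types are not swamped by large ones. The sketch below only illustrates that idea on a made-up label list; it is not scvi-tools' implementation, and `subsample_per_label` is a hypothetical helper.

```python
import random
from collections import defaultdict


def subsample_per_label(labels, n_per_label, rng):
    """Return sorted indices with at most `n_per_label` cells drawn from each label class.

    Hypothetical helper illustrating per-epoch label balancing; classes with
    fewer than `n_per_label` cells contribute all of their cells.
    """
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    chosen = []
    for idxs in by_label.values():
        k = min(n_per_label, len(idxs))
        chosen.extend(rng.sample(idxs, k))  # sample without replacement within the class
    return sorted(chosen)


# Made-up, imbalanced label vector
labels = ["16C"] * 200 + ["8C"] * 60 + ["4C"] * 10
rng = random.Random(0)
epoch_indices = subsample_per_label(labels, n_per_label=15, rng=rng)
counts = {lab: sum(labels[i] == lab for i in epoch_indices) for lab in set(labels)}
print(counts)
```

With these toy labels each epoch sees 15 cells each of 16C and 8C and all 10 cells of 4C. This balancing is consistent with the scANVI_ns15 row in the results table trading some overall accuracy for much higher balanced accuracy than the default scANVI model.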
y_test = lvae.adata.obs.ct.values
y_pred = lvae.predict()
stats.append(prediction_stats("scANVI_ns15", None, y_test, y_pred, None, None))

vae = scvi.model.SCVI.load("../results/02_mouse_integration/scvi/")
INFO File ../results/02_mouse_integration/scvi/model.pt already downloaded
vae_df = pd.DataFrame(vae.get_normalized_expression(return_mean=True))
vae_df['target'] = vae.adata.obs.ct
vae_df.head()
| | sox17 | ppp1r42 | arfgef1 | prdm14 | xkr9 | msc | ube2w | gm7654 | tmem70 | ly96 | ... | stn1 | gsto1 | 1700054a03rik | gm50273 | habp2 | ccdc186 | afap1l2 | pnlip | pnliprp2 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| SRX259148 | 9.244410e-09 | 1.495744e-07 | 0.000106 | 0.000002 | 0.000030 | 2.020811e-04 | 0.000375 | 1.544236e-09 | 0.000513 | 0.000280 | ... | 0.000278 | 0.004330 | 7.987154e-08 | 3.248621e-10 | 5.573803e-09 | 0.000101 | 5.110550e-08 | 1.434869e-09 | 4.117884e-05 | 16C |
| SRX259191 | 2.290104e-07 | 2.459044e-09 | 0.000065 | 0.000006 | 0.000080 | 4.674523e-05 | 0.000086 | 4.516714e-07 | 0.000462 | 0.000328 | ... | 0.000394 | 0.001819 | 6.749668e-11 | 1.539256e-09 | 6.671179e-09 | 0.000118 | 2.692376e-07 | 8.009134e-07 | 9.071199e-05 | 8C |
| SRX259121 | 9.530742e-11 | 1.176371e-06 | 0.000106 | 0.000005 | 0.000014 | 4.688524e-04 | 0.000103 | 7.302969e-12 | 0.000299 | 0.000452 | ... | 0.000032 | 0.001475 | 1.854037e-07 | 1.135729e-09 | 1.288587e-10 | 0.000048 | 1.281698e-08 | 7.423926e-10 | 6.965978e-06 | 16C |
| SRX259140 | 2.935074e-08 | 1.263999e-08 | 0.000069 | 0.000011 | 0.000058 | 8.229597e-05 | 0.000256 | 1.905268e-08 | 0.000314 | 0.000802 | ... | 0.000497 | 0.003291 | 1.695690e-09 | 4.987342e-10 | 1.461652e-08 | 0.000127 | 1.668044e-07 | 5.001769e-08 | 9.694348e-05 | 16C |
| SRX259161 | 4.501800e-08 | 3.915447e-09 | 0.000240 | 0.000057 | 0.000007 | 3.080844e-07 | 0.000202 | 1.228954e-06 | 0.000794 | 0.000154 | ... | 0.000211 | 0.000063 | 4.561012e-12 | 9.231247e-08 | 3.405940e-10 | 0.000431 | 3.755741e-05 | 2.560960e-06 | 7.125635e-07 | 4C |
5 rows × 3001 columns
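`train_test_split_by_group` and `train_xgboost` come from `../scripts/helpers.py`, which is not shown here. Judging only by its name and its return order `(X_train, y_train, X_test, y_test)`, one plausible reading is a split that holds out a fixed fraction of each `target` class; the pandas sketch below implements that assumption. Treat it as hypothetical: the real helper may group by something else, and `split_by_target` is not part of the notebook.

```python
import pandas as pd


def split_by_target(df: pd.DataFrame, test_frac: float = 0.2, seed: int = 0):
    """Hypothetical stand-in for `train_test_split_by_group`.

    Holds out `test_frac` of the rows of every `target` class, so each cell type
    appears in both splits. Returns (X_train, y_train, X_test, y_test).
    """
    # observed=True skips empty categories when `target` is categorical
    test = df.groupby("target", observed=True).sample(frac=test_frac, random_state=seed)
    train = df.drop(index=test.index)
    features = [c for c in df.columns if c != "target"]
    return train[features], train["target"], test[features], test["target"]


# Made-up expression table with three balanced classes
toy_df = pd.DataFrame(
    {
        "gene_a": range(30),
        "gene_b": [x * 0.5 for x in range(30)],
        "target": ["16C"] * 10 + ["8C"] * 10 + ["4C"] * 10,
    },
    index=[f"cell{i}" for i in range(30)],
)
X_tr, y_tr, X_te, y_te = split_by_target(toy_df, test_frac=0.2)
print(y_te.value_counts())
```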
X_train, y_train, X_test, y_test = train_test_split_by_group(vae_df)

vae_xgboost = train_xgboost(vae_df, X_train, y_train, X_test, y_test)
[0] validation_0-merror:0.03004 validation_0-mlogloss:1.09462 validation_1-merror:0.08128 validation_1-mlogloss:1.19250
[1] validation_0-merror:0.01877 validation_0-mlogloss:0.77131 validation_1-merror:0.05419 validation_1-mlogloss:0.88954
[2] validation_0-merror:0.01690 validation_0-mlogloss:0.56525 validation_1-merror:0.05172 validation_1-mlogloss:0.69790
[3] validation_0-merror:0.01189 validation_0-mlogloss:0.42255 validation_1-merror:0.04433 validation_1-mlogloss:0.56222
[4] validation_0-merror:0.00876 validation_0-mlogloss:0.32044 validation_1-merror:0.04680 validation_1-mlogloss:0.46083
[5] validation_0-merror:0.00626 validation_0-mlogloss:0.24549 validation_1-merror:0.04926 validation_1-mlogloss:0.38825
[6] validation_0-merror:0.00438 validation_0-mlogloss:0.18819 validation_1-merror:0.04926 validation_1-mlogloss:0.33157
[7] validation_0-merror:0.00375 validation_0-mlogloss:0.14573 validation_1-merror:0.04926 validation_1-mlogloss:0.29007
[8] validation_0-merror:0.00188 validation_0-mlogloss:0.11385 validation_1-merror:0.04926 validation_1-mlogloss:0.25384
[9] validation_0-merror:0.00125 validation_0-mlogloss:0.08912 validation_1-merror:0.05172 validation_1-mlogloss:0.22724
[10] validation_0-merror:0.00125 validation_0-mlogloss:0.07097 validation_1-merror:0.04926 validation_1-mlogloss:0.20505
[11] validation_0-merror:0.00000 validation_0-mlogloss:0.05642 validation_1-merror:0.04433 validation_1-mlogloss:0.18941
[12] validation_0-merror:0.00000 validation_0-mlogloss:0.04556 validation_1-merror:0.04187 validation_1-mlogloss:0.17663
[13] validation_0-merror:0.00000 validation_0-mlogloss:0.03727 validation_1-merror:0.04187 validation_1-mlogloss:0.16720
[14] validation_0-merror:0.00000 validation_0-mlogloss:0.03082 validation_1-merror:0.04187 validation_1-mlogloss:0.16092
[15] validation_0-merror:0.00000 validation_0-mlogloss:0.02581 validation_1-merror:0.03941 validation_1-mlogloss:0.15553
[16] validation_0-merror:0.00000 validation_0-mlogloss:0.02199 validation_1-merror:0.03941 validation_1-mlogloss:0.14887
[17] validation_0-merror:0.00000 validation_0-mlogloss:0.01899 validation_1-merror:0.03941 validation_1-mlogloss:0.14508
[18] validation_0-merror:0.00000 validation_0-mlogloss:0.01656 validation_1-merror:0.03941 validation_1-mlogloss:0.14449
[19] validation_0-merror:0.00000 validation_0-mlogloss:0.01463 validation_1-merror:0.03941 validation_1-mlogloss:0.14215
[20] validation_0-merror:0.00000 validation_0-mlogloss:0.01305 validation_1-merror:0.03695 validation_1-mlogloss:0.14105
[21] validation_0-merror:0.00000 validation_0-mlogloss:0.01174 validation_1-merror:0.03941 validation_1-mlogloss:0.14084
[22] validation_0-merror:0.00000 validation_0-mlogloss:0.01067 validation_1-merror:0.03941 validation_1-mlogloss:0.14050
[23] validation_0-merror:0.00000 validation_0-mlogloss:0.00981 validation_1-merror:0.03941 validation_1-mlogloss:0.14013
[24] validation_0-merror:0.00000 validation_0-mlogloss:0.00909 validation_1-merror:0.03941 validation_1-mlogloss:0.14064
[25] validation_0-merror:0.00000 validation_0-mlogloss:0.00851 validation_1-merror:0.03941 validation_1-mlogloss:0.14045
[26] validation_0-merror:0.00000 validation_0-mlogloss:0.00804 validation_1-merror:0.03941 validation_1-mlogloss:0.14129
[27] validation_0-merror:0.00000 validation_0-mlogloss:0.00768 validation_1-merror:0.03941 validation_1-mlogloss:0.14126
[28] validation_0-merror:0.00000 validation_0-mlogloss:0.00735 validation_1-merror:0.03941 validation_1-mlogloss:0.14149
[29] validation_0-merror:0.00000 validation_0-mlogloss:0.00707 validation_1-merror:0.03941 validation_1-mlogloss:0.14218
[30] validation_0-merror:0.00000 validation_0-mlogloss:0.00686 validation_1-merror:0.03941 validation_1-mlogloss:0.14375
[31] validation_0-merror:0.00000 validation_0-mlogloss:0.00665 validation_1-merror:0.03941 validation_1-mlogloss:0.14360
[32] validation_0-merror:0.00000 validation_0-mlogloss:0.00648 validation_1-merror:0.03941 validation_1-mlogloss:0.14413

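In the log above, `validation_0` is most likely the training set (its error reaches zero by round 11) and `validation_1` the `X_test`/`y_test` pair passed to `train_xgboost`. The held-out `mlogloss` bottoms out at 0.14013 in round 23 and then creeps up until boosting ends at round 32, a pattern consistent with early stopping on that metric; the actual settings live in `../scripts/helpers.py` and are not shown. If the test set does drive early stopping, it also selects the number of trees, which makes the test-set scores somewhat optimistic. The pure-Python sketch below applies a simple patience rule to the logged losses to show how a best round is chosen; `early_stop` is a hypothetical helper, not XGBoost's API, and the patience values are illustrative.

```python
def early_stop(losses, patience):
    """Return (best_round, stop_round) under a simple patience rule.

    Training stops once `patience` rounds have passed without improving on the
    best loss seen so far; stop_round is None if that never happens in `losses`.
    """
    best_round, best_loss = 0, losses[0]
    for rnd, loss in enumerate(losses[1:], start=1):
        if loss < best_loss:
            best_round, best_loss = rnd, loss
        elif rnd - best_round >= patience:
            return best_round, rnd
    return best_round, None


# validation_1-mlogloss values copied from the XGB_scVI log above (rounds 0-32)
val_mlogloss = [
    1.19250, 0.88954, 0.69790, 0.56222, 0.46083, 0.38825, 0.33157, 0.29007,
    0.25384, 0.22724, 0.20505, 0.18941, 0.17663, 0.16720, 0.16092, 0.15553,
    0.14887, 0.14508, 0.14449, 0.14215, 0.14105, 0.14084, 0.14050, 0.14013,
    0.14064, 0.14045, 0.14129, 0.14126, 0.14149, 0.14218, 0.14375, 0.14360,
    0.14413,
]
for patience in (5, 9, 10):
    print(patience, early_stop(val_mlogloss, patience))
```

Whatever the patience, the best round stays at 23; only the point at which training would halt moves.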
y_pred = vae_xgboost.predict(X_test)

stats.append(prediction_stats("XGB_scVI", vae_xgboost, y_test, y_pred, vae_df.values[:, :-1], vae_df['target'].cat.codes.to_numpy()))

vae_xgboost.save_model("../results/02_mouse_integration/05_scVI_xgboost.json")

lvae = scvi.model.SCANVI.load("../results/02_mouse_integration/scanvi/")
INFO File ../results/02_mouse_integration/scanvi/model.pt already downloaded
lvae_df = pd.DataFrame(lvae.get_normalized_expression(return_mean=True))
lvae_df['target'] = lvae.adata.obs.ct
lvae_df.head()
| | sox17 | ppp1r42 | arfgef1 | prdm14 | xkr9 | msc | ube2w | gm7654 | tmem70 | ly96 | ... | stn1 | gsto1 | 1700054a03rik | gm50273 | habp2 | ccdc186 | afap1l2 | pnlip | pnliprp2 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| SRX259148 | 1.499321e-08 | 7.591232e-08 | 0.000068 | 0.000003 | 0.000065 | 0.000326 | 0.000131 | 1.483101e-08 | 0.000298 | 0.000150 | ... | 0.000185 | 0.005562 | 8.425851e-07 | 4.370481e-09 | 1.569241e-09 | 0.000072 | 4.824862e-08 | 1.690867e-09 | 0.000051 | 16C |
| SRX259191 | 1.535953e-07 | 3.323944e-08 | 0.000084 | 0.000012 | 0.000095 | 0.000061 | 0.000062 | 1.126642e-06 | 0.000531 | 0.000390 | ... | 0.000336 | 0.002567 | 1.206273e-08 | 3.697671e-08 | 8.405232e-09 | 0.000134 | 1.122398e-06 | 5.098362e-07 | 0.000258 | 8C |
| SRX259121 | 2.234061e-10 | 1.818486e-07 | 0.000059 | 0.000002 | 0.000023 | 0.000644 | 0.000123 | 3.998329e-11 | 0.000119 | 0.000196 | ... | 0.000069 | 0.001645 | 4.700184e-07 | 2.258681e-09 | 1.102188e-10 | 0.000037 | 5.974784e-09 | 1.568855e-09 | 0.000009 | 16C |
| SRX259140 | 8.150251e-08 | 4.920110e-08 | 0.000055 | 0.000010 | 0.000060 | 0.000158 | 0.000128 | 1.852338e-07 | 0.000286 | 0.000275 | ... | 0.000263 | 0.004427 | 1.616031e-07 | 1.186041e-08 | 5.657514e-09 | 0.000073 | 3.052841e-07 | 3.905111e-08 | 0.000125 | 16C |
| SRX259161 | 2.039545e-07 | 1.044988e-07 | 0.000329 | 0.000031 | 0.000166 | 0.000001 | 0.000171 | 3.242764e-05 | 0.000807 | 0.000285 | ... | 0.000279 | 0.000124 | 4.134634e-10 | 1.175266e-07 | 7.103967e-09 | 0.000540 | 9.310104e-05 | 2.019131e-06 | 0.000002 | 4C |
5 rows × 3001 columns
X_train, y_train, X_test, y_test = train_test_split_by_group(lvae_df)

lvae_xgboost = train_xgboost(lvae_df, X_train, y_train, X_test, y_test)
[0] validation_0-merror:0.03504 validation_0-mlogloss:1.11435 validation_1-merror:0.16502 validation_1-mlogloss:1.35171
[1] validation_0-merror:0.02566 validation_0-mlogloss:0.78891 validation_1-merror:0.12315 validation_1-mlogloss:1.04526
[2] validation_0-merror:0.01815 validation_0-mlogloss:0.58309 validation_1-merror:0.11084 validation_1-mlogloss:0.85486
[3] validation_0-merror:0.01314 validation_0-mlogloss:0.43796 validation_1-merror:0.09852 validation_1-mlogloss:0.72215
[4] validation_0-merror:0.01064 validation_0-mlogloss:0.33298 validation_1-merror:0.10099 validation_1-mlogloss:0.62165
[5] validation_0-merror:0.00688 validation_0-mlogloss:0.25603 validation_1-merror:0.09852 validation_1-mlogloss:0.54944
[6] validation_0-merror:0.00501 validation_0-mlogloss:0.19739 validation_1-merror:0.09606 validation_1-mlogloss:0.49209
[7] validation_0-merror:0.00250 validation_0-mlogloss:0.15305 validation_1-merror:0.09852 validation_1-mlogloss:0.44675
[8] validation_0-merror:0.00250 validation_0-mlogloss:0.12090 validation_1-merror:0.09606 validation_1-mlogloss:0.41005
[9] validation_0-merror:0.00188 validation_0-mlogloss:0.09542 validation_1-merror:0.09852 validation_1-mlogloss:0.38280
[10] validation_0-merror:0.00125 validation_0-mlogloss:0.07578 validation_1-merror:0.10099 validation_1-mlogloss:0.36633
[11] validation_0-merror:0.00000 validation_0-mlogloss:0.06099 validation_1-merror:0.10099 validation_1-mlogloss:0.35114
[12] validation_0-merror:0.00000 validation_0-mlogloss:0.04939 validation_1-merror:0.09852 validation_1-mlogloss:0.34063
[13] validation_0-merror:0.00000 validation_0-mlogloss:0.04041 validation_1-merror:0.10345 validation_1-mlogloss:0.33141
[14] validation_0-merror:0.00000 validation_0-mlogloss:0.03370 validation_1-merror:0.10099 validation_1-mlogloss:0.32415
[15] validation_0-merror:0.00000 validation_0-mlogloss:0.02832 validation_1-merror:0.10099 validation_1-mlogloss:0.31757
[16] validation_0-merror:0.00000 validation_0-mlogloss:0.02422 validation_1-merror:0.09852 validation_1-mlogloss:0.30989
[17] validation_0-merror:0.00000 validation_0-mlogloss:0.02089 validation_1-merror:0.09852 validation_1-mlogloss:0.30725
[18] validation_0-merror:0.00000 validation_0-mlogloss:0.01818 validation_1-merror:0.10099 validation_1-mlogloss:0.30426
[19] validation_0-merror:0.00000 validation_0-mlogloss:0.01608 validation_1-merror:0.09852 validation_1-mlogloss:0.30231
[20] validation_0-merror:0.00000 validation_0-mlogloss:0.01431 validation_1-merror:0.09852 validation_1-mlogloss:0.29995
[21] validation_0-merror:0.00000 validation_0-mlogloss:0.01292 validation_1-merror:0.09606 validation_1-mlogloss:0.29890
[22] validation_0-merror:0.00000 validation_0-mlogloss:0.01174 validation_1-merror:0.09852 validation_1-mlogloss:0.29805
[23] validation_0-merror:0.00000 validation_0-mlogloss:0.01074 validation_1-merror:0.09606 validation_1-mlogloss:0.29816
[24] validation_0-merror:0.00000 validation_0-mlogloss:0.00989 validation_1-merror:0.09606 validation_1-mlogloss:0.29815
[25] validation_0-merror:0.00000 validation_0-mlogloss:0.00927 validation_1-merror:0.09852 validation_1-mlogloss:0.29669
[26] validation_0-merror:0.00000 validation_0-mlogloss:0.00875 validation_1-merror:0.09852 validation_1-mlogloss:0.29431
[27] validation_0-merror:0.00000 validation_0-mlogloss:0.00831 validation_1-merror:0.09852 validation_1-mlogloss:0.29328
[28] validation_0-merror:0.00000 validation_0-mlogloss:0.00794 validation_1-merror:0.10099 validation_1-mlogloss:0.29464
[29] validation_0-merror:0.00000 validation_0-mlogloss:0.00760 validation_1-merror:0.09852 validation_1-mlogloss:0.29632
[30] validation_0-merror:0.00000 validation_0-mlogloss:0.00732 validation_1-merror:0.09852 validation_1-mlogloss:0.29805
[31] validation_0-merror:0.00000 validation_0-mlogloss:0.00708 validation_1-merror:0.10099 validation_1-mlogloss:0.29762
[32] validation_0-merror:0.00000 validation_0-mlogloss:0.00684 validation_1-merror:0.10099 validation_1-mlogloss:0.29781
[33] validation_0-merror:0.00000 validation_0-mlogloss:0.00666 validation_1-merror:0.10099 validation_1-mlogloss:0.29910
[34] validation_0-merror:0.00000 validation_0-mlogloss:0.00649 validation_1-merror:0.10099 validation_1-mlogloss:0.29847
[35] validation_0-merror:0.00000 validation_0-mlogloss:0.00634 validation_1-merror:0.10099 validation_1-mlogloss:0.29821
[36] validation_0-merror:0.00000 validation_0-mlogloss:0.00619 validation_1-merror:0.10099 validation_1-mlogloss:0.29706

y_pred = lvae_xgboost.predict(X_test)

stats.append(prediction_stats("XGB_scANVI", lvae_xgboost, y_test, y_pred, lvae_df.values[:, :-1], lvae_df['target'].cat.codes.to_numpy()))

lvae_xgboost.save_model("../results/02_mouse_integration/05_scANVI_xgboost.json")

mscgen = scgen.SCGEN.load("../results/02_mouse_integration/scgen/")
INFO File ../results/02_mouse_integration/scgen/model.pt already downloaded
mscgen_df = pd.DataFrame(mscgen.get_decoded_expression(),
                         index=mscgen.adata.obs_names, columns=mscgen.adata.var_names)
mscgen_df['target'] = mscgen.adata.obs['ct']
mscgen_df.head()
| | sox17 | ppp1r42 | arfgef1 | prdm14 | xkr9 | msc | ube2w | gm7654 | tmem70 | ly96 | ... | stn1 | gsto1 | 1700054a03rik | gm50273 | habp2 | ccdc186 | afap1l2 | pnlip | pnliprp2 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| SRX259148 | 0.129524 | -0.009126 | 0.580587 | -0.027672 | 0.165227 | 0.593042 | 0.868470 | -0.005957 | 1.724664 | 1.235015 | ... | 1.047719 | 3.994854 | 0.026034 | 0.019212 | -0.022841 | 0.752102 | 0.036959 | 0.001139 | 0.642708 | 16C |
| SRX259191 | 0.042759 | -0.009889 | 0.608965 | 0.035671 | 0.257579 | 0.512698 | 0.708813 | -0.007634 | 1.550365 | 1.364577 | ... | 1.072520 | 3.394770 | 0.016012 | 0.020248 | -0.048329 | 0.858443 | 0.036058 | 0.028643 | 0.485618 | 8C |
| SRX259121 | -0.062429 | -0.002099 | 0.516726 | -0.141183 | 0.291033 | 0.251529 | 0.438559 | -0.008837 | 1.417975 | 1.416850 | ... | 0.471408 | 2.301015 | 0.038988 | 0.019344 | 0.005064 | 0.552338 | -0.009445 | -0.026622 | 0.361133 | 16C |
| SRX259140 | -0.067408 | -0.014072 | 0.476634 | 0.038267 | 0.143380 | 0.544186 | 0.797441 | -0.009345 | 1.724418 | 1.436977 | ... | 0.957134 | 3.538107 | 0.023010 | 0.025978 | -0.022783 | 0.754117 | 0.064578 | 0.004726 | 0.577420 | 16C |
| SRX259161 | 0.009253 | 0.033485 | 1.305035 | 0.533173 | -0.025159 | 0.184909 | 1.239522 | 0.022633 | 2.215558 | 1.189655 | ... | 1.302740 | 1.088298 | 0.003135 | 0.037767 | 0.004102 | 1.716617 | 0.146538 | 0.081518 | 0.166842 | 4C |
5 rows × 3001 columns
X_train, y_train, X_test, y_test = train_test_split_by_group(mscgen_df)

mscgen_xgboost = train_xgboost(mscgen_df, X_train, y_train, X_test, y_test)
[0] validation_0-merror:0.02753 validation_0-mlogloss:1.10220 validation_1-merror:0.07882 validation_1-mlogloss:1.17242
[1] validation_0-merror:0.01815 validation_0-mlogloss:0.77531 validation_1-merror:0.07389 validation_1-mlogloss:0.88407
[2] validation_0-merror:0.01252 validation_0-mlogloss:0.56490 validation_1-merror:0.06650 validation_1-mlogloss:0.69409
[3] validation_0-merror:0.00876 validation_0-mlogloss:0.42008 validation_1-merror:0.06404 validation_1-mlogloss:0.56456
[4] validation_0-merror:0.00688 validation_0-mlogloss:0.31672 validation_1-merror:0.06158 validation_1-mlogloss:0.46782
[5] validation_0-merror:0.00501 validation_0-mlogloss:0.24033 validation_1-merror:0.06158 validation_1-mlogloss:0.40182
[6] validation_0-merror:0.00375 validation_0-mlogloss:0.18417 validation_1-merror:0.06404 validation_1-mlogloss:0.35392
[7] validation_0-merror:0.00125 validation_0-mlogloss:0.14143 validation_1-merror:0.06158 validation_1-mlogloss:0.31127
[8] validation_0-merror:0.00063 validation_0-mlogloss:0.10950 validation_1-merror:0.05911 validation_1-mlogloss:0.28430
[9] validation_0-merror:0.00063 validation_0-mlogloss:0.08584 validation_1-merror:0.05419 validation_1-mlogloss:0.26294
[10] validation_0-merror:0.00000 validation_0-mlogloss:0.06750 validation_1-merror:0.05419 validation_1-mlogloss:0.24574
[11] validation_0-merror:0.00000 validation_0-mlogloss:0.05349 validation_1-merror:0.05419 validation_1-mlogloss:0.23149
[12] validation_0-merror:0.00000 validation_0-mlogloss:0.04296 validation_1-merror:0.05172 validation_1-mlogloss:0.22194
[13] validation_0-merror:0.00000 validation_0-mlogloss:0.03500 validation_1-merror:0.05419 validation_1-mlogloss:0.21303
[14] validation_0-merror:0.00000 validation_0-mlogloss:0.02891 validation_1-merror:0.05172 validation_1-mlogloss:0.20760
[15] validation_0-merror:0.00000 validation_0-mlogloss:0.02419 validation_1-merror:0.05172 validation_1-mlogloss:0.20254
[16] validation_0-merror:0.00000 validation_0-mlogloss:0.02050 validation_1-merror:0.05419 validation_1-mlogloss:0.20035
[17] validation_0-merror:0.00000 validation_0-mlogloss:0.01761 validation_1-merror:0.05419 validation_1-mlogloss:0.19680
[18] validation_0-merror:0.00000 validation_0-mlogloss:0.01529 validation_1-merror:0.05419 validation_1-mlogloss:0.19592
[19] validation_0-merror:0.00000 validation_0-mlogloss:0.01347 validation_1-merror:0.05419 validation_1-mlogloss:0.19497
[20] validation_0-merror:0.00000 validation_0-mlogloss:0.01202 validation_1-merror:0.05419 validation_1-mlogloss:0.19207
[21] validation_0-merror:0.00000 validation_0-mlogloss:0.01080 validation_1-merror:0.05419 validation_1-mlogloss:0.19109
[22] validation_0-merror:0.00000 validation_0-mlogloss:0.00985 validation_1-merror:0.05419 validation_1-mlogloss:0.19257
[23] validation_0-merror:0.00000 validation_0-mlogloss:0.00908 validation_1-merror:0.05419 validation_1-mlogloss:0.19313
[24] validation_0-merror:0.00000 validation_0-mlogloss:0.00847 validation_1-merror:0.05419 validation_1-mlogloss:0.19265
[25] validation_0-merror:0.00000 validation_0-mlogloss:0.00794 validation_1-merror:0.05419 validation_1-mlogloss:0.19297
[26] validation_0-merror:0.00000 validation_0-mlogloss:0.00753 validation_1-merror:0.05419 validation_1-mlogloss:0.19289
[27] validation_0-merror:0.00000 validation_0-mlogloss:0.00718 validation_1-merror:0.05419 validation_1-mlogloss:0.19182
[28] validation_0-merror:0.00000 validation_0-mlogloss:0.00689 validation_1-merror:0.05419 validation_1-mlogloss:0.19347
[29] validation_0-merror:0.00000 validation_0-mlogloss:0.00664 validation_1-merror:0.05419 validation_1-mlogloss:0.19384
[30] validation_0-merror:0.00000 validation_0-mlogloss:0.00646 validation_1-merror:0.05419 validation_1-mlogloss:0.19353

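# Predict with the scGEN-based model. The original run skipped this step and reused
# `y_pred` from the scANVI XGBoost model, which is why the XGB_scGEN accuracy and F1
# values in the table below are identical to the XGB_scANVI row.
y_pred = mscgen_xgboost.predict(X_test)

stats.append(prediction_stats("XGB_scGEN", mscgen_xgboost, y_test, y_pred, mscgen_df.values[:, :-1], mscgen_df['target'].cat.codes.to_numpy()))

mscgen_xgboost.save_model("../results/02_mouse_integration/05_scGEN_xgboost.json")

stats_df = pd.DataFrame(stats, columns=['method', 'accuracy', 'balanced_accuracy', 'f1_micro', 'f1_macro', 'cv_accuracy']).set_index('method')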
stats_df.to_csv("../results/05_mouse_classifier_stats.csv")
stats_df
| method | accuracy | balanced_accuracy | f1_micro | f1_macro | cv_accuracy |
|---|---|---|---|---|---|
| scANVI | 0.830339 | 0.649818 | 0.830339 | 0.634290 | NaN |
| scANVI_ns15 | 0.793413 | 0.879503 | 0.793413 | 0.777624 | NaN |
| XGB_scVI | 0.960591 | 0.963041 | 0.960591 | 0.967656 | 0.935637 |
| XGB_scANVI | 0.901478 | 0.917235 | 0.901478 | 0.923392 | 0.920664 |
| XGB_scGEN | 0.901478 | 0.917235 | 0.901478 | 0.923392 | 0.942602 |
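The `cv_accuracy` column is the mean accuracy of the 10-fold `StratifiedKFold` (shuffled, `random_state=42`) that `prediction_stats` runs on all cells, so unlike the other columns it does not depend on the single train/test split; it is NaN for the scANVI rows because no scikit-learn model is passed. Stratification keeps each fold's cell-type proportions close to those of the whole dataset. The sketch below shows the idea with a round-robin assignment per class; scikit-learn's fold-assignment algorithm differs in detail and Python's `random` will not reproduce its folds, and `stratified_folds` is a hypothetical helper.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits, seed=42):
    """Split sample indices into `n_splits` folds that preserve class proportions.

    Each class's indices are shuffled and dealt round-robin across folds, continuing
    from where the previous class stopped, so every fold receives either
    floor(n_c / n_splits) or ceil(n_c / n_splits) samples of class c.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)

    folds = [[] for _ in range(n_splits)]
    position = 0
    for lab in sorted(by_class):
        idxs = by_class[lab][:]
        rng.shuffle(idxs)
        for idx in idxs:
            folds[position % n_splits].append(idx)
            position += 1
    return folds


# Made-up imbalanced labels; each fold in turn would serve as the test set
labels = ["16C"] * 23 + ["8C"] * 12 + ["4C"] * 5
folds = stratified_folds(labels, n_splits=10)
for k, fold in enumerate(folds):
    print(k, {lab: sum(labels[i] == lab for i in fold) for lab in sorted(set(labels))})
```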
stats_df.plot(kind='bar')
plt.gca().spines[['right', 'top']].set_visible(False)
plt.gca().legend(title='Metrics', bbox_to_anchor=(0.99, 1.02), loc='upper left')
