Geneformer Cell Type Classification Benchmark¶
Here we benchmark five models, two of which serve as baselines. The models are tasked with cell type classification, using the Crohn's disease small intestine dataset from Elmentaite et al. (2020), Developmental Cell. The dataset contains approximately 22,500 single cells from healthy children (aged 4-13) and children with Crohn's disease, covering 31 unique cell types which we assume to be annotated accurately. This dataset was held out of our pre-training dataset, since all diseased samples were removed before pre-training.
- Baseline (1) scRNA workflow: this model uses PCA with 10 components and random forest on normalized and log transformed expression counts to produce a result.
- Baseline (2) geneformer with random weight initialization. Some performance can come from large random projections, but we want to do better than that.
- geneformer-10M and geneformer-106M as described in the model cards, plus a geneformer-10M checkpoint re-trained with BioNeMo2.
First, we download the dataset of interest from the CZI CELLxGENE Census, and then create the requisite memmap-backed SCDL dataset object.
#NBVAL_CHECK_OUTPUT
import cellxgene_census
CENSUS_VERSION = "2023-12-15"
with cellxgene_census.open_soma(census_version=CENSUS_VERSION) as census:
    adata = cellxgene_census.get_anndata(
        census,
        "Homo sapiens",
        obs_value_filter='dataset_id=="8e47ed12-c658-4252-b126-381df8d52a3d"',
    )
uq_cells = sorted(adata.obs['cell_type'].unique().tolist())
uq_cells
['B cell', 'CD4-positive, alpha-beta T cell', 'CD8-positive, alpha-beta T cell', 'IgA plasma cell', 'IgG plasma cell', 'M cell of gut', 'T follicular helper cell', 'activated CD4-positive, alpha-beta T cell, human', 'conventional dendritic cell', 'dendritic cell, human', 'endothelial cell of artery', 'endothelial cell of lymphatic vessel', 'enterocyte', 'enteroendocrine cell', 'fibroblast', 'gamma-delta T cell', 'glial cell', 'intestinal crypt stem cell', 'intestinal tuft cell', 'intestine goblet cell', 'mast cell', 'memory B cell', 'monocyte', 'myeloid cell', 'myofibroblast cell', 'pericyte', 'plasma cell', 'plasmacytoid dendritic cell', 'regulatory T cell', 'transit amplifying cell', 'vein endothelial cell']
#NBVAL_CHECK_OUTPUT
import random
from contextlib import contextmanager
@contextmanager
def random_seed(seed: int):
    state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        # Go back to previous state
        random.setstate(state)


with random_seed(32):
    indices = list(range(len(adata)))
    random.shuffle(indices)

micro_batch_size: int = 32
num_steps: int = 256
selection = sorted(indices[: micro_batch_size * num_steps])

# NOTE: there's a current constraint that predict_step needs to be a function of micro-batch-size.
#  this is something we are working on fixing. A quick hack is to set micro-batch-size=1, but this is
#  slow. In this notebook we are going to use mbs=32 and subsample the anndata.
adata = adata[selection].copy()  # so it's not a view
adata.shape
(8192, 60664)
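As a quick, optional consistency check: the subsampled AnnData should contain exactly micro_batch_size * num_steps cells, which matches the shape printed above.

# Sanity check (optional): 32 * 256 = 8192 cells after subsampling.
assert adata.shape[0] == micro_batch_size * num_steps == 8192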
import shutil
from bionemo.core import BIONEMO_CACHE_DIR
cleanup: bool = True
notebook_workdir = BIONEMO_CACHE_DIR / "notebook_tutorials" / "geneformer_celltype_classification"
if cleanup and notebook_workdir.exists():
    shutil.rmtree(notebook_workdir)
notebook_workdir.mkdir(parents=True, exist_ok=True)
input_dir = notebook_workdir / "celltype-bench-dataset-input"
data_dir = notebook_workdir / "celltype-bench-dataset"
input_dir.mkdir(parents=True, exist_ok=True)
h5ad_outfile = input_dir / "hs-celltype-bench.h5ad"
adata.write_h5ad(h5ad_outfile)
Create the SCDL memmap dataset, check outputs¶
!convert_h5ad_to_scdl --data-path {input_dir} --save-path {data_dir}
Importantly, the .npy files are used by the BioNeMo dataset object. The features output contains the requested metadata, in this case cell_type. It is important that the output of our model preserves the same cell order as these features, since they provide the labels used in the following benchmark.
#NBVAL_CHECK_OUTPUT
from glob import glob
files = sorted([f.split("/")[-1] for f in glob(str(data_dir/"*"))]) # strip off the directory name and sort for the test
files
['col_ptr.npy', 'data.npy', 'features', 'metadata.json', 'row_ptr.npy', 'version.json']
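Optionally, you can take a quick look inside the converted dataset before running inference. The sketch below only reads array shapes and the stored metadata; it assumes the .npy files were written with numpy.save (as the extension suggests) and makes no assumptions about the exact sparse layout.

import json
import numpy as np

# Peek at the converted arrays (shape and dtype only) and the stored metadata.
for name in ["data.npy", "row_ptr.npy", "col_ptr.npy"]:
    arr = np.load(data_dir / name, mmap_mode="r")
    print(f"{name}: shape={arr.shape}, dtype={arr.dtype}")
with open(data_dir / "metadata.json") as fh:
    print(json.dumps(json.load(fh), indent=2)[:500])  # first part of the metadata only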
# NOTE: calling the load(...) function directly does not currently work for downloads through NGC in an
# interactive notebook environment. We get around this below by calling the CLI download endpoint, which
# executes in a subshell.
# 106m checkpoint
geneformer_106m_out = !download_bionemo_data "geneformer/106M_240530:2.0"
# 10m checkpoint
geneformer_10m_out = !download_bionemo_data "geneformer/10M_240530:2.0"
# 10m bionemo2 trained checkpoint
geneformer_10m_bnmo2_out = !download_bionemo_data "geneformer/10M_241113:2.0"
# Each result is a list of output lines; the last line is the downloaded path, so grab that from each:
geneformer_106m = geneformer_106m_out[-1]
geneformer_10m = geneformer_10m_out[-1]
geneformer_10m_bnmo2 = geneformer_10m_bnmo2_out[-1]
result_path_10m = notebook_workdir / "results_10m.pt"
result_path_10m_bnmo2 = notebook_workdir / "results_10m_bnmo2.pt"
results_path_10m_random = notebook_workdir / "results_10m_randomweights.pt"
result_path_106m = notebook_workdir / "results_106m.pt"
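Optionally, we can sanity-check that the last line of each CLI call really is a path that exists on disk before handing it to the inference script.

from pathlib import Path

# Each download_bionemo_data call should have left a checkpoint directory (or archive) in the cache.
for ckpt in (geneformer_10m, geneformer_10m_bnmo2, geneformer_106m):
    assert Path(ckpt).exists(), f"Missing checkpoint: {ckpt}"
    print(ckpt)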
Execute inference¶
We run inference with each of our models, using the checkpoints downloaded in the previous cell. We have a one-off inference script for geneformer that is installed as part of the bionemo-geneformer
package. See the pyproject.toml
in the source directory if you are curious, or if you want to use this as a template for your own inference scripts. This script should work for any SCDL-converted geneformer dataset and any geneformer BioNeMo2 model checkpoint.
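If you want to see the full list of supported flags, the console script should expose the usual argparse help (this is an assumption about the CLI; we do not rely on it elsewhere in this notebook):

!infer_geneformer --help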
!infer_geneformer \
--data-dir {data_dir} \
--checkpoint-path {geneformer_10m} \
--results-path {result_path_10m} \
--micro-batch-size {micro_batch_size} \
--seq-len 2048 \
--num-dataset-workers 10 \
--num-gpus 1 \
--include-input-ids
[NeMo W 2024-12-20 00:13:55 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning) [NeMo W 2024-12-20 00:13:55 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead. cm = get_cmap("Set1") [NeMo W 2024-12-20 00:13:56 ssm:31] The package `megatron.core` was not imported in this environment which is needed for SSMs. [NeMo W 2024-12-20 00:13:56 preprocess:101] Tokenizer vocab file: /workspace/bionemo2/.cache/bionemo/d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3-singlecell-scdltestdata-20241203.tar.gz.untar/cellxgene_2023-12-15_small_processed_scdl/train/geneformer.vocab already exists. Overwriting... [NeMo I 2024-12-20 00:13:56 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:13:56 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true [NeMo I 2024-12-20 00:13:56 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:13:56 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:13:56 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl?download=true [NeMo I 2024-12-20 00:13:56 remote:136] No checksum provided, filename exists. Assuming it is complete. 
[NeMo I 2024-12-20 00:13:56 infer_geneformer:83] *************** Preprocessing Finished ************ GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs [NeMo I 2024-12-20 00:13:56 megatron_strategy:315] Fixing mis-match between ddp-config & mcore-optimizer config [NeMo I 2024-12-20 00:13:56 megatron_init:396] Rank 0 has data parallel group : [0] [NeMo I 2024-12-20 00:13:56 megatron_init:402] Rank 0 has combined group of data parallel and context parallel : [0] [NeMo I 2024-12-20 00:13:56 megatron_init:407] All data parallel group ranks with context parallel combined: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:410] Ranks 0 has data parallel rank: 0 [NeMo I 2024-12-20 00:13:56 megatron_init:418] Rank 0 has context parallel group: [0] [NeMo I 2024-12-20 00:13:56 megatron_init:421] All context parallel group ranks: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:422] Ranks 0 has context parallel rank: 0 [NeMo I 2024-12-20 00:13:56 megatron_init:429] Rank 0 has model parallel group: [0] [NeMo I 2024-12-20 00:13:56 megatron_init:430] All model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:439] Rank 0 has tensor model parallel group: [0] [NeMo I 2024-12-20 00:13:56 megatron_init:443] All tensor model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:444] Rank 0 has tensor model parallel rank: 0 [NeMo I 2024-12-20 00:13:56 megatron_init:464] Rank 0 has pipeline model parallel group: [0] [NeMo I 2024-12-20 00:13:56 megatron_init:476] Rank 0 has embedding group: [0] [NeMo I 2024-12-20 00:13:56 megatron_init:482] All pipeline model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:483] Rank 0 has pipeline model parallel rank 0 [NeMo I 2024-12-20 00:13:56 megatron_init:484] All embedding group ranks: [[0]] [NeMo I 2024-12-20 00:13:56 megatron_init:485] Rank 0 has embedding rank: 0 Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1 ---------------------------------------------------------------------------------------------------- distributed_backend=nccl All distributed processes registered. Starting with 1 processes ---------------------------------------------------------------------------------------------------- WARNING: Logging before flag parsing goes to stderr. W1220 00:13:57.369251 136514873635264 config.py:85] Loading /workspace/bionemo2/.cache/bionemo/a27061ee347f453b1bf175e288df31e9813903ebcb4924a77ac50dccc730889d-geneformer_10M_240530_nemo2.tar.gz.untar [NeMo I 2024-12-20 00:13:58 base:44] Padded vocab_size: 25472, original vocab_size: 25429, dummy tokens: 43. [NeMo W 2024-12-20 00:13:58 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py:25: UserWarning: 'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead. warnings.warn( LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [NeMo W 2024-12-20 00:13:58 megatron_strategy:329] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored. [NeMo I 2024-12-20 00:13:58 megatron_parallel:549] > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 10300032
!infer_geneformer \
--data-dir {data_dir} \
--checkpoint-path {geneformer_10m_bnmo2} \
--results-path {result_path_10m_bnmo2} \
--micro-batch-size {micro_batch_size} \
--seq-len 2048 \
--num-dataset-workers 10 \
--num-gpus 1 \
--include-input-ids
[NeMo W 2024-12-20 00:14:31 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning) [NeMo W 2024-12-20 00:14:31 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead. cm = get_cmap("Set1") [NeMo W 2024-12-20 00:14:32 ssm:31] The package `megatron.core` was not imported in this environment which is needed for SSMs. [NeMo W 2024-12-20 00:14:32 preprocess:101] Tokenizer vocab file: /workspace/bionemo2/.cache/bionemo/d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3-singlecell-scdltestdata-20241203.tar.gz.untar/cellxgene_2023-12-15_small_processed_scdl/train/geneformer.vocab already exists. Overwriting... [NeMo I 2024-12-20 00:14:32 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:14:32 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true [NeMo I 2024-12-20 00:14:32 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:14:32 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:14:32 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl?download=true [NeMo I 2024-12-20 00:14:32 remote:136] No checksum provided, filename exists. Assuming it is complete. 
[NeMo I 2024-12-20 00:14:32 infer_geneformer:83] *************** Preprocessing Finished ************ GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs [NeMo I 2024-12-20 00:14:32 megatron_strategy:315] Fixing mis-match between ddp-config & mcore-optimizer config [NeMo I 2024-12-20 00:14:32 megatron_init:396] Rank 0 has data parallel group : [0] [NeMo I 2024-12-20 00:14:32 megatron_init:402] Rank 0 has combined group of data parallel and context parallel : [0] [NeMo I 2024-12-20 00:14:32 megatron_init:407] All data parallel group ranks with context parallel combined: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:410] Ranks 0 has data parallel rank: 0 [NeMo I 2024-12-20 00:14:32 megatron_init:418] Rank 0 has context parallel group: [0] [NeMo I 2024-12-20 00:14:32 megatron_init:421] All context parallel group ranks: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:422] Ranks 0 has context parallel rank: 0 [NeMo I 2024-12-20 00:14:32 megatron_init:429] Rank 0 has model parallel group: [0] [NeMo I 2024-12-20 00:14:32 megatron_init:430] All model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:439] Rank 0 has tensor model parallel group: [0] [NeMo I 2024-12-20 00:14:32 megatron_init:443] All tensor model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:444] Rank 0 has tensor model parallel rank: 0 [NeMo I 2024-12-20 00:14:32 megatron_init:464] Rank 0 has pipeline model parallel group: [0] [NeMo I 2024-12-20 00:14:32 megatron_init:476] Rank 0 has embedding group: [0] [NeMo I 2024-12-20 00:14:32 megatron_init:482] All pipeline model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:483] Rank 0 has pipeline model parallel rank 0 [NeMo I 2024-12-20 00:14:32 megatron_init:484] All embedding group ranks: [[0]] [NeMo I 2024-12-20 00:14:32 megatron_init:485] Rank 0 has embedding rank: 0 Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1 ---------------------------------------------------------------------------------------------------- distributed_backend=nccl All distributed processes registered. Starting with 1 processes ---------------------------------------------------------------------------------------------------- WARNING: Logging before flag parsing goes to stderr. W1220 00:14:33.385426 140118637597120 config.py:85] Loading /workspace/bionemo2/.cache/bionemo/fb6e70cd6bd98fb8941b5de978e95db17a6b8596f1c03f4d641a6d2ba6599757-geneformer_10M_241113_nemo2.tar.gz.untar [NeMo I 2024-12-20 00:14:34 base:44] Padded vocab_size: 25472, original vocab_size: 25429, dummy tokens: 43. [NeMo W 2024-12-20 00:14:34 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py:25: UserWarning: 'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead. warnings.warn( LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [NeMo W 2024-12-20 00:14:34 megatron_strategy:329] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored. [NeMo I 2024-12-20 00:14:34 megatron_parallel:549] > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 10300032
!infer_geneformer \
--data-dir {data_dir} \
--results-path {results_path_10m_random} \
--micro-batch-size {micro_batch_size} \
--seq-len 2048 \
--num-dataset-workers 10 \
--num-gpus 1 \
--include-input-ids
[NeMo W 2024-12-20 00:15:07 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning) [NeMo W 2024-12-20 00:15:08 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead. cm = get_cmap("Set1") [NeMo W 2024-12-20 00:15:08 ssm:31] The package `megatron.core` was not imported in this environment which is needed for SSMs. [NeMo W 2024-12-20 00:15:09 preprocess:101] Tokenizer vocab file: /workspace/bionemo2/.cache/bionemo/d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3-singlecell-scdltestdata-20241203.tar.gz.untar/cellxgene_2023-12-15_small_processed_scdl/train/geneformer.vocab already exists. Overwriting... [NeMo I 2024-12-20 00:15:09 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:09 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true [NeMo I 2024-12-20 00:15:09 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:09 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:09 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl?download=true [NeMo I 2024-12-20 00:15:09 remote:136] No checksum provided, filename exists. Assuming it is complete. 
[NeMo I 2024-12-20 00:15:09 infer_geneformer:83] *************** Preprocessing Finished ************ GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs [NeMo I 2024-12-20 00:15:09 megatron_strategy:315] Fixing mis-match between ddp-config & mcore-optimizer config [NeMo I 2024-12-20 00:15:09 megatron_init:396] Rank 0 has data parallel group : [0] [NeMo I 2024-12-20 00:15:09 megatron_init:402] Rank 0 has combined group of data parallel and context parallel : [0] [NeMo I 2024-12-20 00:15:09 megatron_init:407] All data parallel group ranks with context parallel combined: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:410] Ranks 0 has data parallel rank: 0 [NeMo I 2024-12-20 00:15:09 megatron_init:418] Rank 0 has context parallel group: [0] [NeMo I 2024-12-20 00:15:09 megatron_init:421] All context parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:422] Ranks 0 has context parallel rank: 0 [NeMo I 2024-12-20 00:15:09 megatron_init:429] Rank 0 has model parallel group: [0] [NeMo I 2024-12-20 00:15:09 megatron_init:430] All model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:439] Rank 0 has tensor model parallel group: [0] [NeMo I 2024-12-20 00:15:09 megatron_init:443] All tensor model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:444] Rank 0 has tensor model parallel rank: 0 [NeMo I 2024-12-20 00:15:09 megatron_init:464] Rank 0 has pipeline model parallel group: [0] [NeMo I 2024-12-20 00:15:09 megatron_init:476] Rank 0 has embedding group: [0] [NeMo I 2024-12-20 00:15:09 megatron_init:482] All pipeline model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:483] Rank 0 has pipeline model parallel rank 0 [NeMo I 2024-12-20 00:15:09 megatron_init:484] All embedding group ranks: [[0]] [NeMo I 2024-12-20 00:15:09 megatron_init:485] Rank 0 has embedding rank: 0 Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1 ---------------------------------------------------------------------------------------------------- distributed_backend=nccl All distributed processes registered. Starting with 1 processes ---------------------------------------------------------------------------------------------------- [NeMo I 2024-12-20 00:15:09 base:44] Padded vocab_size: 25472, original vocab_size: 25429, dummy tokens: 43. LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [NeMo W 2024-12-20 00:15:09 megatron_strategy:329] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored. [NeMo I 2024-12-20 00:15:09 megatron_parallel:549] > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 10300032
!infer_geneformer \
--data-dir {data_dir} \
--checkpoint-path {geneformer_106m} \
--results-path {result_path_106m} \
--micro-batch-size {micro_batch_size} \
--seq-len 2048 \
--num-dataset-workers 10 \
--num-gpus 1 \
--include-input-ids
[NeMo W 2024-12-20 00:15:44 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work warn("Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work", RuntimeWarning) [NeMo W 2024-12-20 00:15:45 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/pyannote/core/notebook.py:134: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead. cm = get_cmap("Set1") [NeMo W 2024-12-20 00:15:45 ssm:31] The package `megatron.core` was not imported in this environment which is needed for SSMs. [NeMo W 2024-12-20 00:15:46 preprocess:101] Tokenizer vocab file: /workspace/bionemo2/.cache/bionemo/d8e3ea569bc43768c24aa651aff77722df202078415528497c22394046b08cc3-singlecell-scdltestdata-20241203.tar.gz.untar/cellxgene_2023-12-15_small_processed_scdl/train/geneformer.vocab already exists. Overwriting... [NeMo I 2024-12-20 00:15:46 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:46 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true [NeMo I 2024-12-20 00:15:46 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:46 remote:136] No checksum provided, filename exists. Assuming it is complete. [NeMo I 2024-12-20 00:15:46 remote:124] Resource already exists, skipping download: https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl?download=true [NeMo I 2024-12-20 00:15:46 remote:136] No checksum provided, filename exists. Assuming it is complete. 
[NeMo I 2024-12-20 00:15:46 infer_geneformer:83] *************** Preprocessing Finished ************ GPU available: True (cuda), used: True TPU available: False, using: 0 TPU cores HPU available: False, using: 0 HPUs [NeMo I 2024-12-20 00:15:46 megatron_strategy:315] Fixing mis-match between ddp-config & mcore-optimizer config [NeMo I 2024-12-20 00:15:46 megatron_init:396] Rank 0 has data parallel group : [0] [NeMo I 2024-12-20 00:15:46 megatron_init:402] Rank 0 has combined group of data parallel and context parallel : [0] [NeMo I 2024-12-20 00:15:46 megatron_init:407] All data parallel group ranks with context parallel combined: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:410] Ranks 0 has data parallel rank: 0 [NeMo I 2024-12-20 00:15:46 megatron_init:418] Rank 0 has context parallel group: [0] [NeMo I 2024-12-20 00:15:46 megatron_init:421] All context parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:422] Ranks 0 has context parallel rank: 0 [NeMo I 2024-12-20 00:15:46 megatron_init:429] Rank 0 has model parallel group: [0] [NeMo I 2024-12-20 00:15:46 megatron_init:430] All model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:439] Rank 0 has tensor model parallel group: [0] [NeMo I 2024-12-20 00:15:46 megatron_init:443] All tensor model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:444] Rank 0 has tensor model parallel rank: 0 [NeMo I 2024-12-20 00:15:46 megatron_init:464] Rank 0 has pipeline model parallel group: [0] [NeMo I 2024-12-20 00:15:46 megatron_init:476] Rank 0 has embedding group: [0] [NeMo I 2024-12-20 00:15:46 megatron_init:482] All pipeline model parallel group ranks: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:483] Rank 0 has pipeline model parallel rank 0 [NeMo I 2024-12-20 00:15:46 megatron_init:484] All embedding group ranks: [[0]] [NeMo I 2024-12-20 00:15:46 megatron_init:485] Rank 0 has embedding rank: 0 Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1 ---------------------------------------------------------------------------------------------------- distributed_backend=nccl All distributed processes registered. Starting with 1 processes ---------------------------------------------------------------------------------------------------- WARNING: Logging before flag parsing goes to stderr. W1220 00:15:46.932231 126658719736256 config.py:85] Loading /workspace/bionemo2/.cache/bionemo/7d67a526379eb8581f2aaaf03425ae9ec81a38570b24ddc8b22818e5d26ea772-geneformer_106M_240530_nemo2.tar.gz.untar [NeMo I 2024-12-20 00:15:47 base:44] Padded vocab_size: 25472, original vocab_size: 25429, dummy tokens: 43. [NeMo W 2024-12-20 00:15:47 nemo_logging:361] /usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py:25: UserWarning: 'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead. warnings.warn( LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] [NeMo W 2024-12-20 00:15:48 megatron_strategy:329] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored. [NeMo I 2024-12-20 00:15:48 megatron_parallel:549] > number of parameters on (tensor, pipeline) model parallel rank (0, 0): 106808960
Benchmarking¶
See the benchmarking snippet below. We take in a data matrix and a set of labels, optionally fit PCA, and then fit a random forest model inside cross-validation. Metrics use the macro average (an unweighted average over classes) to handle the multi-class labels. Additionally, we return the confusion matrix for further investigation.
def run_benchmark(data, labels, use_pca=True):
    '''
    data - contains the single cell expression (or whatever feature) in each row.
    labels - contains the string label for each cell

    data_shape (R, C)
    labels_shape (R,)
    '''
    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import StratifiedKFold, cross_validate, cross_val_predict
    from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
    from sklearn.decomposition import PCA

    np.random.seed(1337)

    # Define the target dimension 'n_components'
    n_components = 10  # for example, adjust based on your specific needs

    # Create a pipeline that optionally reduces dimensionality with PCA before the RandomForestClassifier
    if use_pca:
        pipeline = Pipeline([
            ('projection', PCA(n_components=n_components)),
            ('classifier', RandomForestClassifier(class_weight='balanced')),
        ])
    else:
        pipeline = Pipeline([
            ('classifier', RandomForestClassifier(class_weight='balanced')),
        ])

    # Set up StratifiedKFold to ensure each fold reflects the overall distribution of labels
    cv = StratifiedKFold(n_splits=5)

    # Define the scoring functions
    scoring = {
        'accuracy': make_scorer(accuracy_score),
        'precision': make_scorer(precision_score, average='macro'),  # 'macro' averages over classes
        'recall': make_scorer(recall_score, average='macro'),
        'f1_score': make_scorer(f1_score, average='macro'),
        # 'roc_auc' requires probabilities or a decision function; use one-vs-rest for the multi-class case
        'roc_auc': make_scorer(roc_auc_score, multi_class='ovr', needs_proba=True),
    }

    # Perform stratified cross-validation with multiple metrics using the pipeline
    results = cross_validate(pipeline, data, labels, cv=cv, scoring=scoring, return_train_score=False)

    # Print the cross-validation results
    print("Cross-validation metrics:")
    results_out = {}
    for metric, scores in results.items():
        if metric.startswith('test_'):
            results_out[metric] = (scores.mean(), scores.std())
            print(f"{metric[5:]}: {scores.mean():.3f} (+/- {scores.std():.3f})")

    # Out-of-fold predictions for the confusion matrix
    predictions = cross_val_predict(pipeline, data, labels, cv=cv)

    # Return the metrics and the confusion matrix.
    conf_matrix = confusion_matrix(labels, predictions)
    return results_out, conf_matrix
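To make the macro averaging concrete, here is a tiny, self-contained illustration that is not part of the benchmark: macro F1 is the unweighted mean of per-class F1 scores, so a rare class that is always missed drags the score down even when overall accuracy looks high.

from sklearn.metrics import f1_score

# Toy labels: class 1 is rare and never predicted.
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(f1_score(y_true, y_pred, average="micro"))  # 0.8, dominated by the majority class
print(f1_score(y_true, y_pred, average="macro"))  # ~0.444, the mean of per-class F1s (0.889 and 0.0)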
#NBVAL_CHECK_OUTPUT
import torch
infer_Xs_10m = torch.load(result_path_10m / "predictions__rank_0.pt")['embeddings'].float().cpu().numpy()
assert len(adata) == len(infer_Xs_10m), (len(adata), len(infer_Xs_10m))
infer_Xs_10m.shape
(8192, 256)
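The results file is a torch-serialized dictionary. Besides the embeddings, other keys may be present depending on the flags passed to infer_geneformer (for example the input ids requested with --include-input-ids); a quick way to see exactly what was stored is to print each key with its shape.

preds_10m = torch.load(result_path_10m / "predictions__rank_0.pt")
# Print every stored key with its shape (for tensors) or its type (for anything else).
for key, value in preds_10m.items():
    print(key, getattr(value, "shape", type(value)))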
#NBVAL_CHECK_OUTPUT
import torch
infer_Xs_10m_bnmo2 = torch.load(result_path_10m_bnmo2 / "predictions__rank_0.pt")['embeddings'].float().cpu().numpy()
assert len(adata) == len(infer_Xs_10m_bnmo2), (len(adata), len(infer_Xs_10m_bnmo2))
infer_Xs_10m_bnmo2.shape
(8192, 256)
#NBVAL_CHECK_OUTPUT
infer_Xs_106m = torch.load(result_path_106m / "predictions__rank_0.pt")['embeddings'].float().cpu().numpy()
assert len(adata) == len(infer_Xs_106m), (len(adata), len(infer_Xs_106m))
infer_Xs_106m.shape
(8192, 768)
#NBVAL_CHECK_OUTPUT
import torch
infer_Xs_10m_random = torch.load(results_path_10m_random / "predictions__rank_0.pt")['embeddings'].float().cpu().numpy()
assert len(adata) == len(infer_Xs_10m_random), (len(adata), len(infer_Xs_10m_random))
infer_Xs_10m_random.shape
(8192, 256)
import pandas as pd
import numpy as np
# Now fetch the class labels and raw expression for the same dataset. These are used as labels in classification and as one of our baselines.
infer_metadata = adata.obs
raw_Xs = np.asarray(adata.X.todense())
# Here we perform a norm over the total counts for each cell, adding a pseudocount to assist with the following logarithm.
normed_Xs = (raw_Xs + 1) / raw_Xs.sum(axis=1, keepdims=True)
logp1_Xs = np.log(normed_Xs)
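For comparison only: a more conventional preprocessing route is counts-per-10k normalization followed by log1p (for example scanpy's normalize_total + log1p). The transform above is similar in spirit but not numerically identical, since it adds the pseudocount before normalizing. A sketch of the conventional version, which is not used anywhere downstream in this notebook:

# Alternative (unused) preprocessing: counts-per-10k normalization, then log(1 + x).
cp10k_Xs = raw_Xs / raw_Xs.sum(axis=1, keepdims=True) * 1e4
log1p_cp10k_Xs = np.log1p(cp10k_Xs)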
# Now we look at our dataset: how are the cell counts distributed? It is clear that certain cell types dominate the
# dataset, which is good to keep in mind when investigating models. We expect the macro averages and the F1-score
# to be the most reliable metrics for overall performance.
from collections import Counter
import seaborn as sb
labels = infer_metadata['cell_type'].values
label_counts = Counter(labels)
ax = sb.barplot(x=label_counts.keys(), y=label_counts.values())
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_title("Cell type counts for classification dataset")
/tmp/ipykernel_106608/2938980837.py:10: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator. ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
Text(0.5, 1.0, 'Cell type counts for classification dataset')
# Now we assign integer labels to each of our strings. These do not need to be one-hot encoded, since
# scikit-learn's Random Forest accepts integer class labels directly.
from sklearn.preprocessing import LabelEncoder
label_encoder = LabelEncoder()
integer_labels = label_encoder.fit_transform(labels)
print(integer_labels)
[ 1 1 19 ... 17 14 14]
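If you later want to map the integer codes back to cell type names (for example when reading the confusion matrices), LabelEncoder.inverse_transform does the reverse lookup:

# Illustration: decode a few of the integer labels printed above back to their cell type strings.
print(label_encoder.inverse_transform([1, 19, 17]))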
# Distribution of log transforms, looks decent.
from matplotlib import pyplot
pyplot.hist(logp1_Xs.flatten());
def plot_cm(cm, labels=label_encoder.classes_):
    '''
    Helper function for visualizing accuracy across labels.
    '''
    # Normalize each column by the number of predictions for that class, then plot as a heatmap.
    _ = sb.heatmap(
        cm / cm.sum(axis=0),
        cmap=sb.color_palette("Blues", as_cmap=True),
        vmin=0,
        vmax=1,
        linewidth=0.1,
        linecolor='lightgrey',
        xticklabels=labels,
        yticklabels=labels,
    )
    pyplot.xticks(rotation=45, ha='right')
    _ = pyplot.yticks(rotation=0)
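Note that plot_cm normalizes each column by its sum (cm / cm.sum(axis=0)), i.e. by the number of predictions per class, so each column reads roughly as per-class precision; columns for classes that are never predicted contain 0/0, which is what produces the "invalid value encountered in divide" RuntimeWarnings below. If you would rather have each row show recall for the true class, a small variant is sketched here (an alternative only; the rest of the notebook uses plot_cm as defined above):

def plot_cm_recall(cm, labels=label_encoder.classes_):
    '''Row-normalized variant: each row sums to 1 and shows recall per true class.'''
    _ = sb.heatmap(
        cm / cm.sum(axis=1, keepdims=True),
        cmap=sb.color_palette("Blues", as_cmap=True),
        vmin=0,
        vmax=1,
        linewidth=0.1,
        linecolor='lightgrey',
        xticklabels=labels,
        yticklabels=labels,
    )
    pyplot.xticks(rotation=45, ha='right')
    _ = pyplot.yticks(rotation=0)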
Execute benchmarks¶
Finally, we execute our benchmarks and collect the results and confusion matrices. In the figures below we plot performance by cell type for each model as a confusion matrix heatmap. Perhaps unsurprisingly, the most frequent cell type (enterocyte) has the highest accuracy across all models. This suggests some bias due to the unbalanced data; further investigation is beyond the scope of this tutorial. We also see steadily improved performance as we move from the baselines to our provided pretrained models.
Perhaps most interesting is the 106M parameter model, which clearly outperforms all other models on every metric, especially F1-score. This suggests that larger models based on geneformer perform well, and that further scaling is worth exploring.
logp1_results, logp1_cm = run_benchmark(logp1_Xs, integer_labels)
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
Cross-validation metrics: accuracy: 0.776 (+/- 0.033) precision: 0.630 (+/- 0.046) recall: 0.549 (+/- 0.015) f1_score: 0.562 (+/- 0.024) roc_auc: 0.970 (+/- 0.009)
plot_cm(logp1_cm)
/tmp/ipykernel_106608/3742577664.py:16: RuntimeWarning: invalid value encountered in divide _ = sb.heatmap(cm / cm.sum(axis=0), cmap=sb.color_palette("Blues", as_cmap=True), vmin=0, vmax=1, linewidth=0.1, linecolor='lightgrey', xticklabels=labels, yticklabels=labels)
results_10m_random, cm_10m_random = run_benchmark(infer_Xs_10m_random, integer_labels, use_pca=False)
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
Cross-validation metrics: accuracy: 0.396 (+/- 0.013) precision: 0.158 (+/- 0.046) recall: 0.091 (+/- 0.006) f1_score: 0.079 (+/- 0.007) roc_auc: 0.741 (+/- 0.007)
plot_cm(cm_10m_random)
/tmp/ipykernel_106608/3742577664.py:16: RuntimeWarning: invalid value encountered in divide _ = sb.heatmap(cm / cm.sum(axis=0), cmap=sb.color_palette("Blues", as_cmap=True), vmin=0, vmax=1, linewidth=0.1, linecolor='lightgrey', xticklabels=labels, yticklabels=labels)
results_10m, cm_10m = run_benchmark(infer_Xs_10m, integer_labels, use_pca=False)
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
Cross-validation metrics: accuracy: 0.839 (+/- 0.016) precision: 0.788 (+/- 0.029) recall: 0.677 (+/- 0.015) f1_score: 0.702 (+/- 0.017) roc_auc: 0.986 (+/- 0.006)
plot_cm(cm_10m)
results_10m_bnmo2, cm_10m_bnmo2 = run_benchmark(infer_Xs_10m_bnmo2, integer_labels, use_pca=False)
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) /usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
Cross-validation metrics: accuracy: 0.834 (+/- 0.021) precision: 0.790 (+/- 0.052) recall: 0.675 (+/- 0.031) f1_score: 0.703 (+/- 0.037) roc_auc: 0.990 (+/- 0.007)
plot_cm(cm_10m_bnmo2)
results_106M, cm_106M = run_benchmark(infer_Xs_106m, integer_labels, use_pca=False)
/usr/local/lib/python3.10/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result))
Cross-validation metrics: accuracy: 0.905 (+/- 0.015) precision: 0.912 (+/- 0.025) recall: 0.819 (+/- 0.015) f1_score: 0.843 (+/- 0.016) roc_auc: 0.991 (+/- 0.006)
plot_cm(cm_106M)
data = {
    'model': [
        'Baseline Logp1 PCA+RF',
        '10M RandomWeights',
        '10M parameters',
        '10M parameters BioNeMo2 re-trained',
        '106M parameters',
    ],
    'f1_score_mean': [
        logp1_results['test_f1_score'][0],
        results_10m_random['test_f1_score'][0],
        results_10m['test_f1_score'][0],
        results_10m_bnmo2['test_f1_score'][0],
        results_106M['test_f1_score'][0],
    ],
    'f1_score_std': [
        logp1_results['test_f1_score'][1],
        results_10m_random['test_f1_score'][1],
        results_10m['test_f1_score'][1],
        results_10m_bnmo2['test_f1_score'][1],
        results_106M['test_f1_score'][1],
    ],
    'accuracy_mean': [
        logp1_results['test_accuracy'][0],
        results_10m_random['test_accuracy'][0],
        results_10m['test_accuracy'][0],
        results_10m_bnmo2['test_accuracy'][0],
        results_106M['test_accuracy'][0],
    ],
    'accuracy_std': [
        logp1_results['test_accuracy'][1],
        results_10m_random['test_accuracy'][1],
        results_10m['test_accuracy'][1],
        results_10m_bnmo2['test_accuracy'][1],
        results_106M['test_accuracy'][1],
    ],
}
df = pd.DataFrame(data)
fig, ax = pyplot.subplots(figsize=(10, 10))
# F1 Score plot
sb.barplot(x='model', y='f1_score_mean', data=df, capsize=0.2, palette='viridis', ax=ax)
ax.set_title('F1 Score Comparison')
ax.set_xlabel('Model')
ax.set_ylabel('F1 Score')
ax.set_yticks(np.arange(.0, 1.05, .05))
ax.set_ylim(.0, 1.0)
pyplot.xticks(rotation=45, ha='right')
pyplot.savefig("F1-score-models.png")
# Accuracy plot
fig, ax = pyplot.subplots(figsize=(10, 10))
sb.barplot(x='model', y='accuracy_mean', data=df, ax=ax, capsize=0.2, palette='viridis')
ax.set_title('Accuracy Comparison')
ax.set_xlabel('Model')
ax.set_ylabel('Accuracy')
ax.set_yticks(np.arange(.0, 1.05, .05))
ax.set_ylim(.0, 1.0)
pyplot.xticks(rotation=45, ha='right')
pyplot.savefig("average-accuracy-models.png")
/tmp/ipykernel_106608/805283967.py:42: FutureWarning: Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect. sb.barplot(x='model', y='f1_score_mean', data=df, capsize=0.2, palette='viridis', ax=ax) /tmp/ipykernel_106608/805283967.py:53: FutureWarning: Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect. sb.barplot(x='model', y='accuracy_mean', data=df, ax=ax, capsize=0.2, palette='viridis')
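As a final, optional step, the collected metrics can be printed as a single table so the models can be compared at a glance (this simply reuses the df built above):

# Sort by macro F1 so the strongest model appears first.
print(df.sort_values("f1_score_mean", ascending=False).to_string(index=False))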