Build a model to predict gene activity
This notebook shows how to use `gene_activity_regression`, a function in the ChromBERT-tools Python API, to predict gene activity from TSS-centered representations that cover the promoter region together with its upstream and downstream flanking regions.
For the bash command-line usage, see `examples/cli/gene_activity_regression.ipynb <../cli/gene_activity_regression.ipynb>`__.
For more details, see the `gene_activity_regression <https://chrombert-tools.readthedocs.io/en/latest/commands/gene_activity_regression.html>`__ command documentation.
[1]:
from chrombert_tools import gene_activity_regression
Predict the fold change in gene activity between two cell states
[ ]:
# Note: The input files used in this test were downsampled to 5,000 genes for expression analysis.
# This test typically takes 40–100 minutes.
# For formal analyses, please use the full gene set.
results = gene_activity_regression(
    exp_tpm1="../data/fibroblast_expression_sample5000.csv",  # expression of state 1
    exp_tpm2="../data/myoblast_expression_sample5000.csv",  # expression of state 2
    odir="./output_gene_activity_repression",  # output directory
)
Stage 1: prepare expression dataset
Processing stage 1: prepare expression dataset
Mode: two states (log1p TPM fold change; per state, genes inner-merged across ';' files then TPM mean)
Finished stage 1
Stage 2: train ChromBERT (expression fold change, two states)
[Attempt 0/2] seed=55
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loading `train_dataloader` to estimate number of stepping batches.
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary. Estimated model size in MB will not be accurate. Using 32 bits instead.
| Name | Type | Params | Mode
-----------------------------------------------
0 | model | ChromBERTGEP | 62.8 M | train
-----------------------------------------------
18.9 M Trainable params
43.9 M Non-trainable params
62.8 M Total params
251.095 Total estimated model params size (MB)
154 Modules in train mode
0 Modules in eval mode
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:484: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
Metric default_validation/pcc improved. New best score: 0.007
Metric default_validation/pcc improved by 0.067 >= min_delta = 0.01. New best score: 0.073
Metric default_validation/pcc improved by 0.331 >= min_delta = 0.01. New best score: 0.405
Metric default_validation/pcc improved by 0.011 >= min_delta = 0.01. New best score: 0.416
Metric default_validation/pcc improved by 0.032 >= min_delta = 0.01. New best score: 0.449
Metric default_validation/pcc improved by 0.047 >= min_delta = 0.01. New best score: 0.496
Metric default_validation/pcc improved by 0.015 >= min_delta = 0.01. New best score: 0.511
Monitored metric default_validation/pcc did not improve in the last 10 records. Best score: 0.511. Signaling Trainer to stop.
Evaluating the finetuned model performance
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
Loading checkpoint from /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt
Loading from pl module, remove prefix 'model.'
Loading from pl module, replace 'pretrain_model' with 'pretrain_model.chrombert'
Loaded 111/111 parameters
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
warnings.warn(*args, **kwargs)
ft_ckpt: /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt, test_metrics: {'pearsonr': 0.367515504360199, 'spearmanr': 0.3869723677635193, 'mse': 0.7395175099372864, 'mae': 0.5959581732749939, 'r2': 0.05928433735034799}
Attempt metrics: pearsonr=0.367515504360199
Accepted run (pearsonr=0.3675 >= 0.2).
Finished stage 2: obtained a fine-tuned ChromBERT
Best pearsonr=0.367515504360199, metrics={'pearsonr': 0.367515504360199, 'spearmanr': 0.3869723677635193, 'mse': 0.7395175099372864, 'mae': 0.5959581732749939, 'r2': 0.05928433735034799, 'ft_ckpt': '/mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt'}
Finished stage 2 (trained)
Stage 3: Predicting
Region summary - total: 422, overlapping with ChromBERT: 422 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Predict input: ./output_gene_activity_repression/predict/model_input.tsv
Predicting: 100%|██████████| 106/106 [03:04<00:00, 1.74s/it]
Predictions saved: /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/predict/predictions.csv (422 rows)
Finished stage 3
============================================================
All stages completed!
============================================================
Fine-tuned checkpoint: /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt
Predictions: ./output_gene_activity_repression/predict/predictions.csv
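The stage 1 log above describes the training label as a log1p TPM fold change between the two states, computed on genes shared by both inputs. A minimal sketch of that transformation in plain Python might look like the following. The function name `log1p_fold_change`, the toy gene IDs and TPM values, the use of the natural logarithm, and the direction (state 2 minus state 1) are all assumptions made for illustration; the actual ChromBERT-tools implementation may differ.

```python
import math


def log1p_fold_change(tpm_state1, tpm_state2):
    """Return log1p(TPM state 2) - log1p(TPM state 1) per shared gene.

    Both arguments map a gene ID to its TPM value. Only genes present in
    both states are kept (an inner merge). The sign convention is an
    assumption, not something confirmed by the tool's documentation.
    """
    shared_genes = sorted(set(tpm_state1) & set(tpm_state2))
    return {
        gene: math.log1p(tpm_state2[gene]) - math.log1p(tpm_state1[gene])
        for gene in shared_genes
    }


# Toy TPM values, made up for illustration only.
state1 = {"GENE_A": 0.0, "GENE_B": 9.0, "GENE_C": 3.0}
state2 = {"GENE_A": 0.0, "GENE_B": 99.0, "GENE_D": 5.0}

labels = log1p_fold_change(state1, state2)
for gene, value in labels.items():
    print(f"{gene}\t{value:.4f}")
```

Adding 1 before taking the logarithm keeps genes with zero TPM finite, which is why a gene that is silent in both states gets a label of exactly 0.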
[ ]:
predict_df = results.predictions_df  # predictions on the test set
predict_df.head()
| | chrom | start | end | build_region_index | predicted_value | true_label |
|---|---|---|---|---|---|---|
| 0 | chr1 | 1540000 | 1541000 | 797 | 0.434477 | -0.993054 |
| 1 | chr1 | 2212000 | 2213000 | 1414 | 0.948955 | -0.082941 |
| 2 | chr1 | 6261000 | 6262000 | 4717 | 0.193692 | 0.916313 |
| 3 | chr1 | 15152000 | 15153000 | 12259 | 0.345492 | -0.136341 |
| 4 | chr1 | 19251000 | 19252000 | 15865 | -0.339500 | -0.583960 |
[ ]:
# Calculate the Pearson and Spearman correlations between true and predicted values
from scipy.stats import pearsonr, spearmanr

(
    pearsonr(predict_df.true_label, predict_df.predicted_value),
    spearmanr(predict_df.true_label, predict_df.predicted_value),
)
(PearsonRResult(statistic=0.36751555800453894, pvalue=6.09072052584172e-15),
SignificanceResult(statistic=0.386966464843722, pvalue=1.5931780224661528e-16))
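The training log above also reports `mse`, `mae`, and `r2` alongside the two correlations. The sketch below shows one conventional way to compute those metrics in plain Python, which may help explain why `r2` can be much lower than `pearsonr`: Pearson correlation ignores the offset and scale of the predictions, while R² penalizes them. The helper name `regression_metrics` is hypothetical, the formulas are the textbook definitions rather than the ones ChromBERT-tools is confirmed to use, and the sample values are only the five rows shown by `predict_df.head()`, so the printed numbers will not match the full test-set metrics.

```python
import math


def regression_metrics(y_true, y_pred):
    """Textbook Pearson r, MSE, MAE, and R^2 for two equal-length sequences."""
    n = len(y_true)
    mean_true = sum(y_true) / n
    mean_pred = sum(y_pred) / n
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    ss_pred = sum((p - mean_pred) ** 2 for p in y_pred)
    cov = sum((t - mean_true) * (p - mean_pred) for t, p in zip(y_true, y_pred))
    return {
        "pearsonr": cov / math.sqrt(ss_tot * ss_pred),
        "mse": ss_res / n,
        "mae": sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n,
        "r2": 1.0 - ss_res / ss_tot,
    }


# Predictions offset by a constant: perfectly correlated, yet R^2 is negative.
shifted = regression_metrics([1.0, 2.0, 3.0], [2.0, 3.0, 4.0])
print(shifted)

# The five rows displayed by predict_df.head() above (not the full test set).
true_label = [-0.993054, -0.082941, 0.916313, -0.136341, -0.583960]
predicted_value = [0.434477, 0.948955, 0.193692, 0.345492, -0.339500]
head_metrics = regression_metrics(true_label, predicted_value)
print(head_metrics)
```

With the full `predict_df`, the same function could be called as `regression_metrics(list(predict_df.true_label), list(predict_df.predicted_value))`.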
[ ]:
ft_ckpt = results.model_ckpt # the path of model checkpoint
model_config = results.model_config # model configuration file
data_config = results.data_config # data configuration file
[ ]:
ft_ckpt
'/mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt'
Interpretation
After building a model, we can use the interpretation layer of ChromBERT-tools to derive biological insights.
As an example, we identify the factors whose embeddings differ most between two region groups: genes with increased activity and genes with unchanged activity.
[ ]:
region1_file = "./output_gene_activity_repression/dataset/up.csv" # Genes with increased activity generated by the command above
region2_file = "./output_gene_activity_repression/dataset/nochange.csv" # Genes with unchanged activity generated by the command above
[ ]:
# `interpret_regulator_effects_between_region_groups` generates regulator embeddings
# for each region group and calculates the embedding difference of each regulator
# between the groups as 1 − cosine similarity. Regulators with larger embedding
# differences are considered more likely to be key regulators.
from chrombert_tools import interpret_regulator_effects_between_region_groups as run_key_regulator

factor_importance_rank = run_key_regulator(
    region1_file=region1_file,  # region group 1 (genes with increased activity)
    region2_file=region2_file,  # region group 2 (genes with unchanged activity)
    odir="output_gene_activity_repression",  # output directory
    genome="hg38",  # genome
    resolution="1kb",  # resolution
    ft_ckpt=ft_ckpt,  # fine-tuned checkpoint from the training step above
    batch_size=64,  # batch size
    model_config=results.model_config,  # model configuration file from the training step
    data_config=results.data_config,  # data configuration file from the training step
)
Region summary - total: 868, overlapping with ChromBERT: 868 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Region summary - total: 1000, overlapping with ChromBERT: 1000 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
Loading checkpoint from /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt
Loading from pl module, remove prefix 'model.'
Loading from pl module, replace 'pretrain_model' with 'pretrain_model.chrombert'
Loaded 111/111 parameters
100%|██████████| 434/434 [02:09<00:00, 3.34it/s]
100%|██████████| 500/500 [02:32<00:00, 3.28it/s]
Identify key regulators across regions(top 25)
                         factors  similarity  rank  embedding_shift
0     histone lysine acetylation    0.948817     1         0.051183
1   histone lysine crotonylation    0.950988     2         0.049012
2                           cbx6    0.955744     3         0.044256
3                           cbx7    0.957784     4         0.042216
4                           h2ax    0.962921     5         0.037079
5                        h2ak5ac    0.965769     6         0.034231
6                       h3k36me1    0.970262     7         0.029738
7                       h3k79me1    0.970611     8         0.029389
8                       h2bk20ac    0.973015     9         0.026985
9                           cbx8    0.973283    10         0.026717
10                          brd7    0.974076    11         0.025924
11                     h2bk120ac    0.974172    12         0.025828
12                     h4tetraac    0.975699    13         0.024301
13                       h4k91ac    0.975757    14         0.024243
14                         ring1    0.976644    15         0.023356
15                      h2bk12ac    0.976955    16         0.023045
16                      h2bk15ac    0.977530    17         0.022470
17                           rb1    0.979065    18         0.020935
18                         prkdc    0.980111    19         0.019889
19                     h3k9/14ac    0.980144    20         0.019856
20                         h2afy    0.980789    21         0.019211
21                         sumo1    0.980812    22         0.019188
22                         klf11    0.981067    23         0.018933
23                          hira    0.981255    24         0.018745
24                       h3k56ac    0.981475    25         0.018525
Finished
Used fine-tuned ChromBERT checkpoint: /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/api/output_gene_activity_repression/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=3-step=107.ckpt
Key regulators across regions saved to: output_gene_activity_repression/results/factor_importance_rank.csv
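In the ranking above, `embedding_shift` equals 1 minus `similarity`, matching the description of the embedding difference as 1 − cosine similarity. The sketch below illustrates that calculation and the resulting ranking in plain Python. It assumes each regulator has already been reduced to one embedding vector per region group; how ChromBERT-tools aggregates embeddings across the regions in a group is not shown in this notebook, and the function names, regulator names, and vectors here are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_by_embedding_shift(group1, group2):
    """Rank regulators by 1 - cosine similarity between two groups.

    group1 and group2 map a regulator name to a single embedding vector
    for that group (an assumed, pre-aggregated representation).
    """
    rows = []
    for name in group1.keys() & group2.keys():
        similarity = cosine_similarity(group1[name], group2[name])
        rows.append(
            {
                "factors": name,
                "similarity": similarity,
                "embedding_shift": 1.0 - similarity,
            }
        )
    # Largest shift first, so rank 1 is the most changed regulator.
    rows.sort(key=lambda row: row["embedding_shift"], reverse=True)
    for rank, row in enumerate(rows, start=1):
        row["rank"] = rank
    return rows


# Toy two-dimensional embeddings, made up for illustration only.
up_group = {"factor_x": [1.0, 0.0], "factor_y": [1.0, 1.0]}
nochange_group = {"factor_x": [0.0, 1.0], "factor_y": [2.0, 2.0]}

ranking = rank_by_embedding_shift(up_group, nochange_group)
for row in ranking:
    print(row["rank"], row["factors"], round(row["embedding_shift"], 6))
```

Because cosine similarity depends only on direction, a regulator whose embedding is merely rescaled between groups (like `factor_y` here) gets a shift near 0, while one whose embedding points in a different direction ranks first.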