Generate context-specific representations from ChromBERT

Note: The examples below show only the Bash command-line usage for extracting region embeddings.

embed_region subcommand: Generate embeddings for input regions (BED file) or for gene promoter regions.

For the Python API, see `examples/api/embed_region.ipynb <../api/embed_region.ipynb>`__.

If you need to run chrombert-tools in an Apptainer container, see the `apptainer_use.ipynb <apptainer_use.ipynb>`__ tutorial for detailed instructions on using apptainer exec with chrombert-tools.

For more details, see the `embed_region <https://chrombert-tools.readthedocs.io/en/latest/commands/embed_region.html>`__ command documentation.

Generate region embeddings (pre-trained and general)

[1]:
# Show all options of the embed_region subcommand
!chrombert-tools embed_region -h
Usage: chrombert-tools embed_region [OPTIONS]

  Generate region embeddings for specified regions or gene promoter regions

Options:
  --region FILE                   Region BED file.
  --gene TEXT                     Gene symbols or IDs separated by ';'.
  --cell-type-bw FILE             Cell type accessibility BigWig file.
  --cell-type-peak FILE           Cell type accessibility Peak BED file.
  --ft-ckpt FILE                  Fine-tuned checkpoint. If provided, skip
                                  fine-tuning.
  --odir DIRECTORY                [default: ./output]
  --oname TEXT                    [default: embedding]
  --genome [hg38|mm10]            [default: hg38]
  --resolution [1kb|200bp|2kb|4kb]
                                  [default: 1kb]
  --mode [fast|full]              Used when training cell-specific model.
                                  [default: fast]
  --batch-size INTEGER            [default: 4]
  --chrombert-cache-dir DIRECTORY
                                  [default: ~/.cache/chrombert/data]
  --chrombert-region-file FILE
  --chrombert-region-emb-file FILE
  --chrombert-gene-meta FILE
  -h, --help                      Show this message and exit.
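If you prefer to drive the CLI from Python instead of a %%bash cell, one minimal approach is to build the argument list yourself and hand it to subprocess. The helper below (build_embed_region_cmd) is hypothetical and not part of chrombert-tools; it only uses flags listed in the help text above, converting underscores to hyphens. Passing a list rather than a shell string also means a --gene value containing ';' reaches the tool intact without extra quoting.

```python
import shlex
from typing import Dict, List, Optional


def build_embed_region_cmd(options: Dict[str, Optional[str]]) -> List[str]:
    """Build an argv list for `chrombert-tools embed_region`.

    Hypothetical helper: option names use underscores (e.g. "cell_type_bw")
    and are turned into the hyphenated flags shown by -h ("--cell-type-bw").
    Options whose value is None are skipped.
    """
    cmd = ["chrombert-tools", "embed_region"]
    for name, value in options.items():
        if value is None:
            continue
        cmd += ["--" + name.replace("_", "-"), str(value)]
    return cmd


cmd = build_embed_region_cmd({
    "region": "../data/umap_region_1kb_downsample.bed",
    "odir": "./output_emb_region_1kb",
    "genome": "hg38",
    "resolution": "1kb",
    "ft_ckpt": None,  # not provided, so the flag is omitted
})
# Shell-quoted preview of the command line
print(shlex.join(cmd))

# To actually run it (requires chrombert-tools on PATH):
# import subprocess
# subprocess.run(cmd, check=True)
```

With a list argument, subprocess does not go through a shell, so no quoting of ';' or spaces is needed.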
[2]:
%%bash
# --region: Input regions to embed (BED file).
# --odir: Output directory.
# --genome: Reference genome (hg38 or mm10).
# --resolution: ChromBERT resolution (bin size) to use.
chrombert-tools embed_region \
    --region ../data/umap_region_1kb_downsample.bed \
    --odir ./output_emb_region_1kb \
    --genome hg38 \
    --resolution 1kb
Region summary - total: 3000, overlapping with ChromBERT: 3000 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Using cached region embeddings...

Finished!
Focus region summary - total: 3000, overlapping with ChromBERT: 3000, non-overlapping: 0
Note: It is possible for a single region to overlap multiple ChromBERT regions.
Overlapping regions BED file: ./output_emb_region_1kb/overlap_region.bed
Non-overlapping regions BED file: ./output_emb_region_1kb/no_overlap_region.bed
Region embeddings saved to: ./output_emb_region_1kb/region_emb_embedding.npy
Embedding type: general
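The region summary above states the overlap rule: an input region is linked to a ChromBERT bin when the overlap covers at least 50% of the bin or at least 50% of the input region. The sketch below applies that rule to a single region, assuming half-open BED coordinates and 1 kb bins aligned to multiples of 1,000 bp (as the bin coordinates in model_input.tsv suggest). It does not model bins that are absent from ChromBERT's reference set, which is what leaves some regions non-overlapping. overlapping_bins is a hypothetical helper, not the tool's implementation.

```python
from typing import List, Tuple


def overlapping_bins(start: int, end: int, bin_size: int = 1000,
                     min_frac: float = 0.5) -> List[Tuple[int, int]]:
    """Return the fixed-size bins an input region [start, end) is linked to.

    A bin is kept if the overlap covers at least `min_frac` of the bin
    OR at least `min_frac` of the input region (the rule printed in the log).
    Assumes bins are aligned to multiples of `bin_size`.
    """
    region_len = end - start
    kept = []
    first = (start // bin_size) * bin_size  # start of the bin containing `start`
    for b_start in range(first, end, bin_size):
        b_end = b_start + bin_size
        overlap = min(end, b_end) - max(start, b_start)
        if overlap <= 0:
            continue
        if overlap >= min_frac * bin_size or overlap >= min_frac * region_len:
            kept.append((b_start, b_end))
    return kept


print(overlapping_bins(54000, 55000))    # [(54000, 55000)]  bin-aligned input
print(overlapping_bins(180791, 180871))  # [(180000, 181000)]  short peak in one bin
print(overlapping_bins(180800, 181200))  # [(180000, 181000), (181000, 182000)]
print(overlapping_bins(180900, 181150))  # [(181000, 182000)]  only the larger share kept
```

The last two calls show how a peak straddling a bin boundary can map to one or two bins, which is how 100 input regions can yield 101 embedding rows in the cell-type-specific example.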
[3]:
import numpy as np
import pandas as pd
[4]:
# overlap_region_emb:
# Region embedding array with shape ``(n_regions, 768)``. Each row corresponds to one ChromBERT bin overlapping the input regions.

overlap_region_emb = np.load("./output_emb_region_1kb/region_emb_embedding.npy")
print(overlap_region_emb.shape)

# overlap_region.bed: Input regions that overlap ChromBERT bins.
# Contains columns: chrom, start, end, build_region_index

# model_input.tsv: ChromBERT bins that overlap the input regions, one row per embedding row.
# Contains columns: chrom, start, end (bin coordinates), build_region_index, start_input, end_input (input region coordinates)

overlap_region = pd.read_csv("./output_emb_region_1kb/model_input.tsv",sep='\t')
len(overlap_region),overlap_region.head()
(3000, 768)
[4]:
(3000,
   chrom    start      end  build_region_index  start_input  end_input
 0  chr1    54000    55000                   6        54000      55000
 1  chr1   916000   917000                 213       916000     917000
 2  chr1   960000   961000                 255       960000     961000
 3  chr1  1004000  1005000                 296      1004000    1005000
 4  chr1  1013000  1014000                 305      1013000    1014000)
[5]:
# Map each region to its user-defined annotation class (column "anno" of the input BED file)
region_anno = pd.read_csv(
    "../data/umap_region_1kb_downsample.bed",
    sep="\t",
    header=None,
    names=["chrom", "start_input", "end_input", "build_region_index", "anno"],
)
# An inner merge keeps the row order of overlap_region, so labels stay aligned with the embedding rows
overlap_region = overlap_region.merge(
    region_anno[["chrom", "start_input", "end_input", "anno"]],
    on=["chrom", "start_input", "end_input"],
)

# UMAP plot of the embeddings, colored by annotation class
from chrombert_tools import umap_plot
umap_plot(overlap_region_emb, overlap_region["anno"], "./output_emb_region_1kb")
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.
  warn(
[Figure: UMAP of the general region embeddings, colored by annotation class]

Generate gene (promoter region) embeddings

[6]:
%%bash
# --gene: Gene symbols or Ensembl IDs to embed, separated by ';' (matched case-insensitively).
# --odir: Output directory.
# --genome: Reference genome (hg38 or mm10).
# --resolution: ChromBERT resolution (bin size) to use.

chrombert-tools embed_region \
    --gene "ENSG00000170921;TANC2;ENSG00000200997;DPYD;SNORA70;tp53;brd4" \
    --odir "./output_emb_genes" \
    --genome hg38 \
    --resolution 1kb


Using cached region embeddings for gene pooling...

Finished!
Note: All gene names were converted to lowercase for matching.
Gene count summary - requested: 7, matched: 7, not found: 0
Matched gene meta saved to: ./output_emb_genes/overlap_genes_meta.tsv
Gene embeddings saved to: ./output_emb_genes/gene_emb_embedding.pkl
Embedding type: general
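The log notes that gene names are lowercased before matching, and the help text says --gene takes symbols or IDs separated by ';'. A minimal sketch of that parsing and of the requested/matched/not-found summary follows. parse_gene_arg, match_genes, and the known set are hypothetical stand-ins; the tool's actual lookup (presumably against the gene metadata behind --chrombert-gene-meta) may differ.

```python
from typing import Iterable, List, Tuple


def parse_gene_arg(value: str) -> List[str]:
    """Split a --gene value on ';', strip whitespace, drop empty items, lowercase."""
    return [g.strip().lower() for g in value.split(";") if g.strip()]


def match_genes(requested: List[str], known: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Split requested genes into (matched, not_found) against known lowercase IDs."""
    known = set(known)
    matched = [g for g in requested if g in known]
    not_found = [g for g in requested if g not in known]
    return matched, not_found


requested = parse_gene_arg("ENSG00000170921;TANC2;ENSG00000200997;DPYD;SNORA70;tp53;brd4")
# Hypothetical reference set mixing symbols and Ensembl IDs, for illustration only
known = {"ensg00000170921", "tanc2", "ensg00000200997", "dpyd",
         "snora70", "tp53", "brd4", "gapdh"}
matched, not_found = match_genes(requested, known)
print(f"requested: {len(requested)}, matched: {len(matched)}, not found: {len(not_found)}")
# requested: 7, matched: 7, not found: 0
```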
[7]:
# gene_emb_embedding.pkl: Python dictionary mapping each matched gene (lowercased) to a 768-dimensional embedding.
import pickle
with open("./output_emb_genes/gene_emb_embedding.pkl", "rb") as f:
    gene_emb_dict = pickle.load(f)
for key, value in gene_emb_dict.items():
    print(key, value.shape)

ensg00000170921 (768,)
tanc2 (768,)
ensg00000200997 (768,)
dpyd (768,)
snora70 (768,)
tp53 (768,)
brd4 (768,)
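A common next step with a gene embedding dictionary is to compare genes by cosine similarity. The sketch below assumes only that each value is a 1-D vector of equal length, as printed above, and that no vector is all zeros. cosine_similarity_matrix is a hypothetical helper, and random stand-in vectors are used so the sketch runs on its own; in the notebook you would pass gene_emb_dict instead of toy.

```python
import numpy as np


def cosine_similarity_matrix(emb_dict):
    """Return (names, S) where S[i, j] is the cosine similarity between
    the embeddings of names[i] and names[j]. Assumes no all-zero vectors."""
    names = list(emb_dict)
    X = np.stack([np.asarray(emb_dict[n], dtype=np.float64) for n in names])
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # L2-normalize each row
    return names, X @ X.T


# Random stand-in vectors; replace `toy` with gene_emb_dict in the notebook
rng = np.random.default_rng(0)
toy = {g: rng.normal(size=768) for g in ["tanc2", "dpyd", "tp53", "brd4"]}
names, S = cosine_similarity_matrix(toy)

# Most similar other gene for each gene
for i, name in enumerate(names):
    scores = S[i].copy()
    scores[i] = -np.inf  # ignore self-similarity
    print(name, "->", names[int(np.argmax(scores))])
```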

Generate cell-type-specific embeddings

Here we use cell-type-specific chromatin accessibility data as input: a peak BED file and a signal BigWig file.

With these files, the embed_region subcommand first fine-tunes a cell-type-specific ChromBERT model and then uses it to generate cell-type-specific embeddings.

[ ]:
# Download myoblast chromatin accessibility data from ENCODE:
# a peak file (ENCFF647RNC) and a signal BigWig file (ENCFF149ERN)
import os
import subprocess

peak_file = "../data/myoblast_ENCFF647RNC_peak.bed"
if not os.path.exists(peak_file):
    subprocess.run(
        f"wget https://www.encodeproject.org/files/ENCFF647RNC/@@download/ENCFF647RNC.bed.gz -O {peak_file}.gz",
        shell=True, check=True,
    )
    subprocess.run(f"gzip -d {peak_file}.gz", shell=True, check=True)

signal_file = "../data/myoblast_ENCFF149ERN_signal.bigwig"
if not os.path.exists(signal_file):
    subprocess.run(
        f"wget https://www.encodeproject.org/files/ENCFF149ERN/@@download/ENCFF149ERN.bigWig -O {signal_file}",
        shell=True, check=True,
    )
[ ]:
# We only use the first 100 lines of the peak file for demonstration
!head -n 100 ../data/myoblast_ENCFF647RNC_peak.bed > ../data/myoblast_ENCFF647RNC_peak_100.bed
[ ]:
%%bash
# --region: Input regions to embed (here, the first 100 myoblast peaks).
# --odir: Output directory.
# --genome: Reference genome (hg38 or mm10).
# --resolution: ChromBERT resolution (bin size) to use.
# --cell-type-bw: Cell-type accessibility signal BigWig file.
# --cell-type-peak: Cell-type accessibility peak BED file.

export CUDA_VISIBLE_DEVICES=0
chrombert-tools embed_region \
    --region ../data/myoblast_ENCFF647RNC_peak_100.bed \
    --odir ./output_cell_specific_emb_train \
    --genome hg38 \
    --resolution 1kb \
    --cell-type-bw ../data/myoblast_ENCFF149ERN_signal.bigwig \
    --cell-type-peak ../data/myoblast_ENCFF647RNC_peak.bed


Preparing dataset ...
Region summary - total: 373422, overlapping with ChromBERT: 368260 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 7920
Total regions: 324690
Fast mode: downsampling to 20k regions
Fine-tuning cell-specific model...

[Attempt 0/2] seed=55
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name  | Type             | Params | Mode
---------------------------------------------------
0 | model | ChromBERTGeneral | 62.8 M | train
---------------------------------------------------
18.9 M    Trainable params
43.9 M    Non-trainable params
62.8 M    Total params
251.095   Total estimated model params size (MB)
153       Modules in train mode
0         Modules in eval mode
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:484: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
Epoch 0:  20%|██        | 800/4000 [02:18<09:15,  5.76it/s, v_num=0]
Metric default_validation/pcc improved. New best score: 0.270
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.26it/s]
Epoch 0:  40%|████      | 1600/4000 [05:02<07:33,  5.30it/s, v_num=0, default_validation/r2=-0.0472, default_validation/pcc=0.270, default_validation/scc=0.196, default_validation/mae=0.222, default_validation/mse=0.121, default_validation/rmse=0.348, default_validation/mean=0.0806, default_validation/median=0.0757]
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.23it/s]
Epoch 0:  60%|██████    | 2400/4000 [07:46<05:11,  5.14it/s, v_num=0, default_validation/r2=0.0133, default_validation/pcc=0.251, default_validation/scc=0.231, default_validation/mae=0.201, default_validation/mse=0.120, default_validation/rmse=0.346, default_validation/mean=0.118, default_validation/median=0.120]
Metric default_validation/pcc improved by 0.161 >= min_delta = 0.01. New best score: 0.432
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.09it/s]
Epoch 0:  80%|████████  | 3200/4000 [10:34<02:38,  5.04it/s, v_num=0, default_validation/r2=0.045, default_validation/pcc=0.432, default_validation/scc=0.404, default_validation/mae=0.184, default_validation/mse=0.117, default_validation/rmse=0.342, default_validation/mean=0.0654, default_validation/median=0.0396]
Metric default_validation/pcc improved by 0.067 >= min_delta = 0.01. New best score: 0.498
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.14it/s]
Epoch 0: 100%|██████████| 4000/4000 [13:18<00:00,  5.01it/s, v_num=0, default_validation/r2=0.206, default_validation/pcc=0.498, default_validation/scc=0.461, default_validation/mae=0.176, default_validation/mse=0.100, default_validation/rmse=0.316, default_validation/mean=0.116, default_validation/median=0.0732]
Metric default_validation/pcc improved by 0.096 >= min_delta = 0.01. New best score: 0.595
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.16it/s]
Epoch 1:  20%|██        | 800/4000 [02:19<09:19,  5.72it/s, v_num=0, default_validation/r2=0.316, default_validation/pcc=0.595, default_validation/scc=0.531, default_validation/mae=0.158, default_validation/mse=0.0776, default_validation/rmse=0.279, default_validation/mean=0.112, default_validation/median=0.0549]
Metric default_validation/pcc improved by 0.082 >= min_delta = 0.01. New best score: 0.676
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.10it/s]
Epoch 1:  40%|████      | 1600/4000 [05:03<07:35,  5.27it/s, v_num=0, default_validation/r2=0.403, default_validation/pcc=0.676, default_validation/scc=0.609, default_validation/mae=0.187, default_validation/mse=0.0683, default_validation/rmse=0.261, default_validation/mean=0.235, default_validation/median=0.202]
Metric default_validation/pcc improved by 0.124 >= min_delta = 0.01. New best score: 0.800
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.21it/s]
Epoch 1:  60%|██████    | 2400/4000 [07:47<05:11,  5.14it/s, v_num=0, default_validation/r2=0.534, default_validation/pcc=0.800, default_validation/scc=0.705, default_validation/mae=0.130, default_validation/mse=0.0541, default_validation/rmse=0.233, default_validation/mean=0.102, default_validation/median=0.040]
Metric default_validation/pcc improved by 0.035 >= min_delta = 0.01. New best score: 0.835
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.21it/s]
Epoch 1:  80%|████████  | 3200/4000 [10:30<02:37,  5.07it/s, v_num=0, default_validation/r2=0.697, default_validation/pcc=0.835, default_validation/scc=0.775, default_validation/mae=0.109, default_validation/mse=0.035, default_validation/rmse=0.187, default_validation/mean=0.172, default_validation/median=0.0654]
Metric default_validation/pcc improved by 0.036 >= min_delta = 0.01. New best score: 0.871
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.02it/s]
Epoch 1: 100%|██████████| 4000/4000 [13:45<00:00,  4.84it/s, v_num=0, default_validation/r2=0.756, default_validation/pcc=0.871, default_validation/scc=0.804, default_validation/mae=0.102, default_validation/mse=0.0306, default_validation/rmse=0.175, default_validation/mean=0.168, default_validation/median=0.0669]
Validation DataLoader 0: 100%|██████████| 250/250 [00:27<00:00,  9.02it/s]
Epoch 2:  20%|██        | 800/4000 [02:17<09:10,  5.81it/s, v_num=0, default_validation/r2=0.746, default_validation/pcc=0.878, default_validation/scc=0.807, default_validation/mae=0.103, default_validation/mse=0.0297, default_validation/rmse=0.172, default_validation/mean=0.178, default_validation/median=0.0845]
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.20it/s]
Epoch 2:  40%|████      | 1600/4000 [05:00<07:30,  5.33it/s, v_num=0, default_validation/r2=0.725, default_validation/pcc=0.856, default_validation/scc=0.800, default_validation/mae=0.100, default_validation/mse=0.0314, default_validation/rmse=0.177, default_validation/mean=0.148, default_validation/median=0.0564]
Validation DataLoader 0: 100%|██████████| 250/250 [00:28<00:00,  8.84it/s]
Epoch 2:  60%|██████    | 2400/4000 [07:46<05:10,  5.15it/s, v_num=0, default_validation/r2=0.708, default_validation/pcc=0.864, default_validation/scc=0.818, default_validation/mae=0.100, default_validation/mse=0.0329, default_validation/rmse=0.181, default_validation/mean=0.133, default_validation/median=0.0444]
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.23it/s]
Epoch 2:  80%|████████  | 3200/4000 [11:07<02:46,  4.80it/s, v_num=0, default_validation/r2=0.750, default_validation/pcc=0.868, default_validation/scc=0.819, default_validation/mae=0.105, default_validation/mse=0.0291, default_validation/rmse=0.171, default_validation/mean=0.191, default_validation/median=0.0898]
Monitored metric default_validation/pcc did not improve in the last 5 records. Best score: 0.871. Signaling Trainer to stop.
Validation DataLoader 0: 100%|██████████| 250/250 [00:24<00:00, 10.21it/s]
Epoch 2:  80%|████████  | 3200/4000 [11:32<02:53,  4.62it/s, v_num=0, default_validation/r2=0.733, default_validation/pcc=0.869, default_validation/scc=0.806, default_validation/mae=0.0985, default_validation/mse=0.0305, default_validation/rmse=0.175, default_validation/mean=0.130, default_validation/median=0.026]
Evaluating the finetuned model performance
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
Loading checkpoint from /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/cli/output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt
Loading from pl module, remove prefix 'model.'
Loading from pl module, replace 'pretrain_model' with 'pretrain_model.chrombert'
Loaded 111/111 parameters
/mnt/Storage/home/chenqianqian/miniconda3/envs/chrombert/lib/python3.9/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.
  warnings.warn(*args, **kwargs)
ft_ckpt: /mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/cli/output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt, test_metrics: {'pearsonr': 0.8807425498962402, 'spearmanr': 0.8116908073425293, 'mse': 0.03028654120862484, 'mae': 0.10750571638345718, 'r2': 0.7522819668127768}
Attempt metrics: pearsonr=0.8807425498962402
Accepted run (pearsonr=0.8807 >= 0.4).

Finished stage 2: obtained a fine-tuned ChromBERT
Best pearsonr=0.8807425498962402, metrics={'pearsonr': 0.8807425498962402, 'spearmanr': 0.8116908073425293, 'mse': 0.03028654120862484, 'mae': 0.10750571638345718, 'r2': 0.7522819668127768, 'ft_ckpt': '/mnt/Storage2/home/chenqianqian/projects/chrombert/chrombert_tools/ChromBERT-tools/examples/cli/output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt'}
Region summary - total: 100, overlapping with ChromBERT: 101 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Your supervised_file does not contain the 'label' column. Please verify whether ground truth column ('label') is required. If it is not needed, you may disregard this message.
100%|██████████| 26/26 [00:02<00:00, 12.38it/s]

Finished!
Focus region summary - total: 100, overlapping with ChromBERT: 101, non-overlapping: 0
Note: It is possible for a single region to overlap multiple ChromBERT regions.
Overlapping regions BED file: ./output_cell_specific_emb_train/overlap_region.bed
Non-overlapping regions BED file: ./output_cell_specific_emb_train/no_overlap_region.bed
Region embeddings saved to: ./output_cell_specific_emb_train/region_emb_embedding.npy
Embedding type: cell-specific
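The stopping behavior in this log follows patience-based early stopping on default_validation/pcc: a validation counts as an improvement only if it beats the best score by more than min_delta = 0.01, and training stops after 5 validations without improvement. That is why the reported best score stays at 0.871 even though a later validation reached 0.878. The sketch below replays that rule on the PCC values copied from the log; patience and min_delta are read from the log messages, and early_stopping_trace is a hypothetical helper, not the tool's code.

```python
from typing import List, Optional, Tuple


def early_stopping_trace(scores: List[float], min_delta: float = 0.01,
                         patience: int = 5) -> Tuple[Optional[int], Optional[float]]:
    """Replay patience-based early stopping on validation scores (higher is better).

    A score is an improvement only if score - min_delta > best. Training stops
    once `patience` consecutive validations fail to improve.
    Returns (stop_index, best_score); stop_index is None if it never stops early.
    """
    best = None
    wait = 0
    for i, score in enumerate(scores):
        if best is None or score - min_delta > best:
            best, wait = score, 0  # new best: reset the patience counter
        else:
            wait += 1
            if wait >= patience:
                return i, best
    return None, best


# Validation PCC values in the order they appear in the training log above
pcc = [0.270, 0.251, 0.432, 0.498, 0.595, 0.676, 0.800, 0.835, 0.871,
       0.878, 0.856, 0.864, 0.868, 0.869]
stop, best = early_stopping_trace(pcc)
print(stop, best)  # 13 0.871  (stops at the 14th validation)
```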
[ ]:
# region_emb_embedding.npy: cell-type-specific region embeddings, one 768-dim row per overlapping ChromBERT bin
import numpy as np
import pandas as pd
overlap_region_emb = np.load("./output_cell_specific_emb_train/region_emb_embedding.npy")
print(overlap_region_emb.shape)

# overlap_region.bed: input regions that overlap ChromBERT bins (a region overlapping two bins appears twice); columns: chrom, start, end, build_region_index

overlap_region = pd.read_csv("./output_cell_specific_emb_train/overlap_region.bed",sep='\t',header=None, names=['chrom','start','end','build_region_index'])
len(overlap_region),overlap_region.head()
(101, 768)
(101,
   chrom   start     end  build_region_index
 0  chr1  180791  180871                  38
 1  chr1  181400  181580                  39
 2  chr1  182681  182820                  40
 3  chr1  191400  191540                  46
 4  chr1  268011  268080                  54)
[8]:
# model_input.tsv: ChromBERT bins overlapping the input regions; columns: chrom, start, end (bin coordinates), build_region_index, start_input, end_input (input region coordinates)
model_input = pd.read_csv("./output_cell_specific_emb_train/model_input.tsv",sep='\t')
len(model_input),model_input.head()
[8]:
(101,
   chrom   start     end  build_region_index  start_input  end_input
 0  chr1  180000  181000                  38       180791     180871
 1  chr1  181000  182000                  39       181400     181580
 2  chr1  182000  183000                  40       182681     182820
 3  chr1  191000  192000                  46       191400     191540
 4  chr1  268000  269000                  54       268011     268080)
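Because one input region can overlap more than one ChromBERT bin, the 101 embedding rows here correspond to bins, not to the 100 input peaks. If you need exactly one vector per input region, one option is to average the rows that share the same input coordinates in model_input.tsv. The sketch below assumes model_input.tsv is row-aligned with region_emb_embedding.npy (both have 101 rows here); mean pooling is a choice made for this illustration, pool_by_input_region is a hypothetical helper, and the demo table is made up.

```python
import numpy as np
import pandas as pd


def pool_by_input_region(emb, model_input, keys=("chrom", "start_input", "end_input")):
    """Average bin-level embeddings that belong to the same input region.

    emb: array of shape (n_bins, dim), row-aligned with model_input.
    Returns (regions, pooled): unique input regions in order of first
    appearance, and an array of shape (n_regions, dim).
    """
    if len(model_input) != emb.shape[0]:
        raise ValueError("emb and model_input must have the same number of rows")
    key_df = model_input[list(keys)].reset_index(drop=True)
    grouped = pd.DataFrame(emb).groupby([key_df[k] for k in keys], sort=False).mean()
    return grouped.index.to_frame(index=False), grouped.to_numpy()


# Made-up example: the first input region overlaps two bins
toy_input = pd.DataFrame({
    "chrom": ["chr1", "chr1", "chr1"],
    "start": [180000, 181000, 191000],
    "end": [181000, 182000, 192000],
    "start_input": [180800, 180800, 191400],
    "end_input": [181200, 181200, 191540],
})
toy_emb = np.array([[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]])
regions, pooled = pool_by_input_region(toy_emb, toy_input)
print(regions)
print(pooled)  # first row: mean of the two bins; second row: the single bin
```

In the notebook, pool_by_input_region(overlap_region_emb, model_input) should collapse the 101 bin-level rows to one row per input region. Max pooling or weighting by overlap length are equally reasonable alternatives.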

Generate cell-type-specific embeddings using a fine-tuned checkpoint

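If you already have a fine-tuned checkpoint (for example, the one produced in the previous step), pass it with --ft-ckpt.

In that case, the embed_region subcommand skips fine-tuning and uses the checkpoint directly to generate cell-type-specific embeddings.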

[9]:
import glob
ft_ckpt_dir = "./output_cell_specific_emb_train/train/**/*.ckpt" # Path to the cell-type-specific model checkpoint generated in the previous step

ft_ckpt = glob.glob(ft_ckpt_dir, recursive=True)[0]
ft_ckpt
[9]:
'./output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt'
[11]:

# --region: Input regions to embed (BED file).
# --odir: Output directory.
# --genome: Reference genome (hg38 or mm10).
# --resolution: ChromBERT resolution (bin size) to use.
# --ft-ckpt: Path to the fine-tuned checkpoint (fine-tuning is skipped).
!chrombert-tools embed_region \
    --region ../data/myoblast_ENCFF647RNC_peak_100.bed \
    --odir ./output_cell_specific_emb_load_ckpt \
    --genome hg38 \
    --resolution 1kb \
    --ft-ckpt {ft_ckpt}
Using provided fine-tuned checkpoint: ./output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt
Load pretrained ckpt /mnt/Storage/home/chenqianqian/.cache/chrombert/data/checkpoint/hg38_6k_1kb_pretrain.ckpt successfully!
Loading checkpoint from ./output_cell_specific_emb_train/train/try_00_seed_55/lightning_logs/lightning_logs/version_0/checkpoints/epoch=1-step=126.ckpt
Loading from pl module, remove prefix 'model.'
Loading from pl module, replace 'pretrain_model' with 'pretrain_model.chrombert'
Loaded 111/111 parameters
Region summary - total: 100, overlapping with ChromBERT: 101 (one region may overlap multiple ChromBERT regions, we keep overlaps with ≥50% coverage of either the ChromBERT bin or the input region), non-overlapping: 0
Your supervised_file does not contain the 'label' column. Please verify whether ground truth column ('label') is required. If it is not needed, you may disregard this message.
100%|███████████████████████████████████████████| 26/26 [00:04<00:00,  5.39it/s]

Finished!
Focus region summary - total: 100, overlapping with ChromBERT: 101, non-overlapping: 0
Note: It is possible for a single region to overlap multiple ChromBERT regions.
Overlapping regions BED file: ./output_cell_specific_emb_load_ckpt/overlap_region.bed
Non-overlapping regions BED file: ./output_cell_specific_emb_load_ckpt/no_overlap_region.bed
Region embeddings saved to: ./output_cell_specific_emb_load_ckpt/region_emb_embedding.npy
Embedding type: cell-specific
[28]:
# region_emb_embedding.npy: one 768-dim vector per overlapping ChromBERT bin
overlap_region_emb2 = np.load("./output_cell_specific_emb_load_ckpt/region_emb_embedding.npy")
print(overlap_region_emb2.shape)

# overlap_region.bed: input regions that overlap ChromBERT bins; columns: chrom, start, end, build_region_index

overlap_region2 = pd.read_csv("./output_cell_specific_emb_load_ckpt/overlap_region.bed", sep='\t', header=None, names=['chrom','start','end','build_region_index'])
len(overlap_region2), overlap_region2.head()
(101, 768)
[28]:
(101,
   chrom   start     end  build_region_index
 0  chr1  180791  180871                  38
 1  chr1  181400  181580                  39
 2  chr1  182681  182820                  40
 3  chr1  191400  191540                  46
 4  chr1  268011  268080                  54)
[ ]:
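# Check that the embeddings match those from the model-building step
# (np.allclose tolerates tiny floating-point differences between GPU runs)
assert overlap_region_emb2.shape == overlap_region_emb.shape
assert np.allclose(overlap_region_emb2, overlap_region_emb)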
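If the check above fails, it helps to see how large the differences are before concluding that the two runs disagree. The sketch below reports the largest absolute element-wise difference and the smallest row-wise cosine similarity between two embedding arrays. compare_embeddings is a hypothetical helper, and the demo uses random data so it runs on its own; in the notebook you would pass overlap_region_emb and overlap_region_emb2.

```python
import numpy as np


def compare_embeddings(a, b):
    """Return (max_abs_diff, min_rowwise_cosine) for two same-shaped embedding arrays."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    max_abs = float(np.max(np.abs(a - b)))
    # Cosine similarity between corresponding rows (assumes no all-zero rows)
    cos = np.sum(a * b, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
    return max_abs, float(np.min(cos))


# Random demo data with the same shape as the embeddings above
rng = np.random.default_rng(1)
x = rng.normal(size=(101, 768))
print(compare_embeddings(x, x))
print(compare_embeddings(x, x + 1e-6 * rng.normal(size=x.shape)))
```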