create scripts/datasets.py (#2)
- create scripts/datasets.py (06ac44905efe4681adf6d4de95ce3ef99b9134d1)
Co-authored-by: Ryan Keivanfar <[email protected]>
- scripts/datasets.py +557 -0
scripts/datasets.py
ADDED
@@ -0,0 +1,557 @@
# datasets.py

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

import kipoiseq.transforms.functional
import pyfaidx
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

# --- Global Config ---
# Enformer typically uses a 196,608 bp input sequence.
# We use a shorter input (1/4 of the usual length) to speed up training.
ENFORMER_INPUT_SEQ_LENGTH = 49_152

# Relative paths from the project root directory
GENOME_FASTA_PATH = "data/hg38.fa"
TSS_REGIONS_CSV_PATH = "data/Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv"

# Path to pseudobulk target data, matching the provided dummy file
PSEUDOBULK_TARGET_DATA_PATH = "data/pseudobulk_dummy.csv"

# ----------------------

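# Quick sanity check (illustrative sketch, not part of the committed module):
# the shortened input is exactly one quarter of the canonical Enformer window,
# and both lengths are multiples of 128 bp, Enformer's output bin size.
def _example_config_check() -> None:
    assert ENFORMER_INPUT_SEQ_LENGTH * 4 == 196_608
    assert ENFORMER_INPUT_SEQ_LENGTH % 128 == 0
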
class GenomeOneHotEncoder:
    """
    Encodes DNA sequences into one-hot format using kipoiseq.
    """

    def __init__(self, sequence_length: int = ENFORMER_INPUT_SEQ_LENGTH):
        self.sequence_length = sequence_length

    @staticmethod
    def _one_hot_encode(sequence: str) -> np.ndarray:
        # One-hot encodes DNA using the same code as the original Enformer
        # paper, so the encoding is consistent with the representations
        # Enformer has already learned.
        return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)

    def encode(self, seq: str) -> np.ndarray:
        """
        One-hot encodes a DNA sequence using kipoiseq.

        Args:
            seq (str): The DNA sequence string. The FastaReader should ensure this
                sequence is already uppercase and of length ENFORMER_INPUT_SEQ_LENGTH.

        Returns:
            np.ndarray: An array of shape (L, 4) with the one-hot encoded sequence.
        """
        return GenomeOneHotEncoder._one_hot_encode(seq)

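# Example (illustrative sketch, not part of the committed module): a minimal
# sanity check of the encoder. The exact neutral encoding kipoiseq assigns to
# 'N' bases depends on its defaults, so we only check shape and dtype here.
def _example_encoder_shapes() -> None:
    encoder = GenomeOneHotEncoder(sequence_length=8)
    one_hot = encoder.encode("ACGTNACG")
    assert one_hot.shape == (8, 4)       # (L, 4) for DNA
    assert one_hot.dtype == np.float32   # matches the model's expected input dtype
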
class FastaReader:
    """
    Reads sequences from a FASTA file using pyfaidx.
    Handles chromosome boundary conditions by padding with 'N'.
    """

    def __init__(self, fasta_path: str):
        self.fasta_path = fasta_path
        self.genome = None
        try:
            self.genome = pyfaidx.Fasta(self.fasta_path, sequence_always_upper=True)
            print(f"Successfully loaded and indexed genome using pyfaidx from: {self.fasta_path}")
        except pyfaidx.FastaIndexingError as e:
            print(f"Error: Could not index FASTA file at {self.fasta_path}.")
            print("Ensure it's a valid FASTA file and the .fai index can be created/read in its directory.")
            print(f"pyfaidx error: {e}")
            raise
        except FileNotFoundError:
            print(f"Error: FASTA file not found at {self.fasta_path}.")
            raise

    def get_sequence(self, chrom: str, start_0based: int, end_0based_exclusive: int) -> str:
        """
        Fetches a DNA sequence for the given 0-based genomic interval.
        Pads with 'N' if the interval extends beyond chromosome boundaries.

        Args:
            chrom (str): Chromosome name (e.g., 'chr1').
            start_0based (int): 0-based start coordinate (inclusive).
            end_0based_exclusive (int): 0-based end coordinate (exclusive).

        Returns:
            str: The DNA sequence, padded with 'N's to match the requested length
                (end_0based_exclusive - start_0based).
        """
        if self.genome is None:
            raise RuntimeError("FastaReader not properly initialized (pyfaidx missing or genome loading failed).")

        # Sanitize chromosome name (e.g., '1' vs 'chr1')
        true_chrom_name = chrom
        if chrom not in self.genome:
            alternative_chrom_name = 'chr' + chrom if not chrom.startswith('chr') else chrom.replace('chr', '', 1)
            if alternative_chrom_name in self.genome:
                true_chrom_name = alternative_chrom_name
            else:
                available_chroms_sample = list(self.genome.keys())[:5]
                raise ValueError(
                    f"Chromosome '{chrom}' (and alternative '{alternative_chrom_name}') not found in FASTA file. "
                    f"Available chromosomes sample: {available_chroms_sample}..."
                )

        chrom_len = len(self.genome[true_chrom_name])
        seq_len_requested = end_0based_exclusive - start_0based

        # Build the sequence in parts, padding with 'N's where needed.
        sequence_parts = []

        # Handle padding at the beginning.
        if start_0based < 0:
            sequence_parts.append('N' * abs(start_0based))
            effective_start = 0
        else:
            effective_start = start_0based

        # Fetch the portion of the interval that lies within the chromosome.
        fetch_len = min(end_0based_exclusive, chrom_len) - effective_start
        if fetch_len > 0:
            sequence_parts.append(self.genome[true_chrom_name][effective_start : effective_start + fetch_len].seq)
        # If effective_start >= chrom_len, the requested interval lies entirely
        # beyond the chromosome end and only padding is needed.

        # Handle padding at the end.
        current_len = sum(len(p) for p in sequence_parts)
        padding_end_len = seq_len_requested - current_len
        if padding_end_len > 0:
            sequence_parts.append('N' * padding_end_len)

        final_sequence = "".join(sequence_parts)

        # Final length check; the logic above should guarantee this.
        if len(final_sequence) != seq_len_requested:
            raise RuntimeError(
                f"Internal error: Final sequence length {len(final_sequence)} for {true_chrom_name}:{start_0based}-{end_0based_exclusive} "
                f"does not match requested {seq_len_requested}."
            )
        return final_sequence

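# Example (illustrative sketch, not part of the committed module): the
# boundary-padding behavior, assuming a genome FASTA is present at data/hg38.fa.
def _example_boundary_padding() -> None:
    reader = FastaReader("data/hg38.fa")
    # Request 10 bp starting 5 bp before the chromosome start: the first
    # 5 characters come back as 'N' padding, the rest as real sequence.
    seq = reader.get_sequence("chr1", -5, 5)
    assert len(seq) == 10
    assert seq.startswith("NNNNN")
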
# --- Main Dataset Classes ---

class TahoeDataset(Dataset):
    """
    PyTorch Dataset for loading Tahoe data for Enformer fine-tuning.
    - Reads genomic regions from a regions CSV.
    - Reads pseudobulk conditions and expression values from a pseudobulk CSV.
    - Merges these two data sources based on gene identifiers.
    - Each sample represents a unique gene-condition pair.
    - Fetches DNA sequence for the gene, resized to `enformer_input_seq_length`.
    - One-hot encodes the sequence and returns it with the specific expression value.
    """
    ORIGINAL_ENFORMER_WINDOW_SIZE = 196_608

    def __init__(self,
                 tss_regions_csv_path: str,
                 genome_fasta_path: str,
                 pseudobulk_data_path: str,
                 enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH,
                 regions_csv_gene_col: str = 'gene_name',   # Gene ID column in tss_regions_csv
                 pseudobulk_csv_gene_col: str = 'gene_id',  # Gene ID column in pseudobulk_data_csv
                 regions_csv_chr_col: str = 'seqnames',     # Chromosome column in tss_regions_csv
                 regions_csv_start_col: str = 'starts',     # 0-based start col in tss_regions_csv
                 regions_csv_end_col: str = 'ends'):        # 0-based exclusive end col in tss_regions_csv
        super().__init__()

        self.enformer_input_seq_length = enformer_input_seq_length
        # Store column names for clarity
        self.regions_gene_col = regions_csv_gene_col
        self.pseudobulk_gene_col = pseudobulk_csv_gene_col
        self.regions_chr_col = regions_csv_chr_col
        self.regions_start_col = regions_csv_start_col
        self.regions_end_col = regions_csv_end_col

        print("Initializing TahoeDataset...")
        print(f"  Target model input sequence length: {self.enformer_input_seq_length} bp")
        print(f"  Genomic regions are assumed to define a {self.ORIGINAL_ENFORMER_WINDOW_SIZE} bp window for centering.")

        # Load genomic regions data
        print(f"  Loading TSS regions from: {tss_regions_csv_path}")
        try:
            regions_df = pd.read_csv(tss_regions_csv_path)
            print(f"  Successfully loaded regions CSV with {len(regions_df)} gene region entries.")
            expected_region_cols = [self.regions_chr_col, self.regions_gene_col,
                                    self.regions_start_col, self.regions_end_col]
            missing_region_cols = [col for col in expected_region_cols if col not in regions_df.columns]
            if missing_region_cols:
                raise ValueError(f"Missing columns in regions CSV ('{tss_regions_csv_path}'): {missing_region_cols}. Expected: {expected_region_cols}")
        except FileNotFoundError:
            print(f"FATAL ERROR: Regions CSV file not found at {tss_regions_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or validating regions CSV: {e}")
            raise

        # Load pseudobulk target data
        print(f"  Loading pseudobulk targets from: {pseudobulk_data_path}")
        try:
            pseudobulk_df = pd.read_csv(pseudobulk_data_path)
            print(f"  Successfully loaded pseudobulk CSV with {len(pseudobulk_df)} condition entries.")
            expected_pb_cols = [self.pseudobulk_gene_col, 'cell_line', 'drug_id', 'drug_dose', 'expression']
            missing_pb_cols = [col for col in expected_pb_cols if col not in pseudobulk_df.columns]
            if missing_pb_cols:
                raise ValueError(f"Missing columns in pseudobulk CSV ('{pseudobulk_data_path}'): {missing_pb_cols}. Expected: {expected_pb_cols}")
        except FileNotFoundError:
            print(f"FATAL ERROR: Pseudobulk CSV file not found at {pseudobulk_data_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or validating pseudobulk CSV: {e}")
            raise

        # Merge regions with pseudobulk data
        print("  Merging genomic regions with pseudobulk target data...")
        print(f"    Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pseudobulk_gene_col}'")

        regions_df[self.regions_gene_col] = regions_df[self.regions_gene_col].astype(str)
        pseudobulk_df[self.pseudobulk_gene_col] = pseudobulk_df[self.pseudobulk_gene_col].astype(str)

        self.samples_df = pd.merge(
            regions_df,
            pseudobulk_df,
            left_on=self.regions_gene_col,
            right_on=self.pseudobulk_gene_col,
            how='inner'  # Keeps only genes present in both DataFrames
        )

        if len(self.samples_df) == 0:
            print("WARNING: The merge operation resulted in an empty DataFrame.")
            print(f"  No common genes found between column '{self.regions_gene_col}' in the regions CSV")
            print(f"  and column '{self.pseudobulk_gene_col}' in the pseudobulk CSV.")
            print("  Please check that gene identifiers match and are of the same type in both files.")
            # Example gene IDs for debugging:
            if not regions_df.empty:
                print(f"    Sample gene IDs from regions CSV: {regions_df[self.regions_gene_col].unique()[:5].tolist()}")
            if not pseudobulk_df.empty:
                print(f"    Sample gene IDs from pseudobulk CSV: {pseudobulk_df[self.pseudobulk_gene_col].unique()[:5].tolist()}")
        else:
            print(f"  Successfully merged data: {len(self.samples_df)} total samples (gene-condition pairs).")

            # Check for genes in regions_df not found in pseudobulk_df (and thus dropped)
            original_region_genes = set(regions_df[self.regions_gene_col].unique())
            merged_region_genes = set(self.samples_df[self.regions_gene_col].unique())
            dropped_region_genes = original_region_genes - merged_region_genes
            if dropped_region_genes:
                print(f"  WARNING: {len(dropped_region_genes)} unique gene IDs from the regions CSV ('{self.regions_gene_col}') were not found in the pseudobulk CSV ('{self.pseudobulk_gene_col}') and were dropped.")
                print(f"    Examples of dropped region gene IDs: {list(dropped_region_genes)[:min(5, len(dropped_region_genes))]}")

            # Check for genes in pseudobulk_df not found in regions_df (and thus dropped)
            original_pseudobulk_genes = set(pseudobulk_df[self.pseudobulk_gene_col].unique())
            final_merged_keys_from_pseudobulk_perspective = set(self.samples_df[self.pseudobulk_gene_col].unique())
            dropped_pseudobulk_genes = original_pseudobulk_genes - final_merged_keys_from_pseudobulk_perspective
            if dropped_pseudobulk_genes:
                print(f"  WARNING: {len(dropped_pseudobulk_genes)} unique gene IDs from the pseudobulk CSV ('{self.pseudobulk_gene_col}') were not found in the regions CSV ('{self.regions_gene_col}') and were dropped.")
                print(f"    Examples of dropped pseudobulk gene IDs: {list(dropped_pseudobulk_genes)[:min(5, len(dropped_pseudobulk_genes))]}")

        if 'expression' in self.samples_df and self.samples_df['expression'].isnull().any():
            print("WARNING: NA values found in 'expression' column after merge. These samples might cause errors or yield NaN targets.")
            print("  Consider handling these (e.g., fill with a default or drop rows with dropna(subset=['expression'])).")
            # self.samples_df.dropna(subset=['expression'], inplace=True)  # Example: drop rows with NA expression

        print(f"  Initializing FASTA reader for genome: {genome_fasta_path}")
        self.fasta_reader = FastaReader(genome_fasta_path)
        self.encoder = GenomeOneHotEncoder(sequence_length=self.enformer_input_seq_length)
        print("TahoeDataset initialized successfully.")

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx: int):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if not (0 <= idx < len(self.samples_df)):
            raise IndexError(f"Index {idx} out of bounds for dataset of length {len(self.samples_df)}")

        sample_info = self.samples_df.iloc[idx]

        gene_name_for_logging = "<unknown>"  # Defined up front so the error handlers below can always reference it
        try:
            chrom = str(sample_info[self.regions_chr_col])
            # Gene name from the regions CSV (used for merge, should be consistent)
            gene_name_for_logging = str(sample_info[self.regions_gene_col])

            csv_region_start = int(sample_info[self.regions_start_col])
            csv_region_end = int(sample_info[self.regions_end_col])

            expression_value = float(sample_info['expression'])  # 'expression' is the target column
        except KeyError as e:
            print(f"FATAL ERROR in __getitem__ (idx {idx}): Missing expected column {e} in merged samples_df.")
            print(f"  Available columns: {self.samples_df.columns.tolist()}")
            print(f"  Sample info for this index: {sample_info.to_dict() if isinstance(sample_info, pd.Series) else sample_info}")
            raise
        except ValueError as e:
            print(f"FATAL ERROR in __getitem__ (idx {idx}): Could not convert data for gene {gene_name_for_logging}. Error: {e}")
            print(f"  Expression value was: '{sample_info.get('expression', 'N/A')}'")
            raise
        except Exception as e:  # Catch any other unexpected error for this item
            print(f"FATAL ERROR in __getitem__ (idx {idx}) for gene {gene_name_for_logging}: An unexpected error occurred: {type(e).__name__} - {e}")
            raise

        # --- Sequence window calculation ---
        actual_csv_window_len = csv_region_end - csv_region_start
        if actual_csv_window_len != self.ORIGINAL_ENFORMER_WINDOW_SIZE:
            # Warn if the input CSV regions are not consistently 196,608 bp.
            # The centering logic below will still proceed based on csv_region_end.
            print(f"WARNING for gene {gene_name_for_logging} (idx {idx}): Region {chrom}:{csv_region_start}-{csv_region_end} from CSV "
                  f"has length {actual_csv_window_len} bp, but expected {self.ORIGINAL_ENFORMER_WINDOW_SIZE} bp "
                  f"for the original window definition used for centering. Sequence extraction might be affected if assumptions are wrong.")

        # Initialize final sequence coordinates with those from the CSV.
        # These are used as-is if no resizing is needed.
        final_seq_start_0based = csv_region_start
        final_seq_end_0based_exclusive = csv_region_end

        # If the target model input sequence length differs from the original
        # Enformer window size, recalculate start and end by centering the
        # target length within the original window.
        if self.enformer_input_seq_length != self.ORIGINAL_ENFORMER_WINDOW_SIZE:
            # Calculate the center of the ORIGINAL_ENFORMER_WINDOW_SIZE window.
            # Assumes 'csv_region_end' is the exclusive end of this original window.
            original_window_center = csv_region_end - (self.ORIGINAL_ENFORMER_WINDOW_SIZE // 2)

            half_target_seq_len = self.enformer_input_seq_length // 2
            final_seq_start_0based = original_window_center - half_target_seq_len
            # The end is exclusive and preserves the exact target sequence length.
            final_seq_end_0based_exclusive = final_seq_start_0based + self.enformer_input_seq_length

        # Fetch and encode the DNA sequence
        dna_sequence = self.fasta_reader.get_sequence(chrom, final_seq_start_0based, final_seq_end_0based_exclusive)
        one_hot_sequence = self.encoder.encode(dna_sequence)
        one_hot_sequence_tensor = torch.tensor(one_hot_sequence, dtype=torch.float32)

        # Target is the specific expression value for this gene-condition pair
        target_tensor = torch.tensor([expression_value], dtype=torch.float32)

        return one_hot_sequence_tensor, target_tensor

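# Example (illustrative sketch, not part of the committed module): wiring
# TahoeDataset into a DataLoader, assuming the data files referenced by the
# module-level path constants exist.
def _example_tahoe_dataloader() -> None:
    from torch.utils.data import DataLoader

    dataset = TahoeDataset(
        tss_regions_csv_path=TSS_REGIONS_CSV_PATH,
        genome_fasta_path=GENOME_FASTA_PATH,
        pseudobulk_data_path=PSEUDOBULK_TARGET_DATA_PATH,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    seqs, targets = next(iter(loader))
    # seqs: (batch, ENFORMER_INPUT_SEQ_LENGTH, 4); targets: (batch, 1)
    print(seqs.shape, targets.shape)
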
# --- Extended Dataset for SMILES ---
class TahoeSMILESDataset(Dataset):
    """
    Extends TahoeDataset to also return:
    - Morgan fingerprints for the drug
    - drug dose
    - target expression
    """

    def __init__(self,
                 regions_csv_path: str,    # Renamed from tss_regions_csv_path for clarity with config
                 pbulk_parquet_path: str,  # Renamed from pseudobulk_data_path for clarity with config
                 drug_meta_csv_path: str,  # Renamed from drug_metadata_path for clarity with config
                 fasta_file_path: str,     # Renamed from genome_fasta_path for clarity with config
                 enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH,
                 # Morgan fingerprint parameters (from data_config)
                 morgan_fp_radius: int = 2,
                 morgan_fp_nbits: int = 2048,
                 # Column names from regions_csv (from data_config)
                 regions_gene_col: str = 'gene_name',
                 regions_chr_col: str = 'seqnames',
                 regions_start_col: str = 'starts',
                 regions_end_col: str = 'ends',
                 # Column names from pbulk_parquet (from data_config)
                 pbulk_gene_col: str = 'gene_id',
                 pbulk_drug_col: str = 'drug_id',
                 pbulk_dose_col: str = 'drug_dose',
                 pbulk_expr_col: str = 'expression',
                 pbulk_cell_line_col: str = 'cell_line',
                 # Column names from drug_meta_csv (from data_config)
                 drug_meta_id_col: str = 'drug',
                 drug_meta_smiles_col: str = 'canonical_smiles',
                 filter_drugs_by_ids: list = None,  # From dataset_args
                 regions_strand_col: str = None,    # From dataset_args; not used in the current __getitem__
                 regions_set_col: str = 'set',      # Column in regions_csv used for data splitting
                 target_set: str = None             # Specific set to load (e.g., "train", "valid", "test")
                 ):
        super().__init__()

        # Store config
        self.seq_len = enformer_input_seq_length
        self.morgan_fp_radius = morgan_fp_radius
        self.morgan_fp_nbits = morgan_fp_nbits

        self.regions_gene_col = regions_gene_col
        self.regions_chr_col = regions_chr_col
        self.regions_start_col = regions_start_col
        self.regions_end_col = regions_end_col
        self.regions_set_col = regions_set_col  # Name of the set column

        self.pbulk_gene_col = pbulk_gene_col
        self.pbulk_drug_col = pbulk_drug_col
        self.pbulk_dose_col = pbulk_dose_col
        self.pbulk_expr_col = pbulk_expr_col
        self.pbulk_cell_line_col = pbulk_cell_line_col

        self.drug_meta_id_col = drug_meta_id_col
        self.drug_meta_smiles_col = drug_meta_smiles_col

        self.target_set = target_set  # Specific set value for this instance

        # --- Morgan fingerprint generator (new rdFingerprintGenerator API) ---
        self._morgan_gen = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.morgan_fp_radius,
            fpSize=self.morgan_fp_nbits
        )

        # Load & merge regions + pseudobulk
        print(f"  Loading TSS regions from: {regions_csv_path}")
        try:
            regs = pd.read_csv(regions_csv_path)
            print(f"  Successfully loaded regions CSV with {len(regs)} gene region entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Regions CSV file not found at {regions_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading regions CSV: {e}")
            raise

        print(f"  Loading pseudobulk targets from: {pbulk_parquet_path} (expected Parquet format)")
        try:
            pb = pd.read_parquet(pbulk_parquet_path)
            print(f"  Successfully loaded pseudobulk Parquet file with {len(pb)} entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Pseudobulk Parquet file not found at {pbulk_parquet_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or parsing pseudobulk Parquet file: {e}")
            print("  Ensure the file is a valid Parquet file and a Parquet engine like 'pyarrow' or 'fastparquet' is installed.")
            raise

        # Ensure gene ID columns are strings for merging
        regs[self.regions_gene_col] = regs[self.regions_gene_col].astype(str)
        pb[self.pbulk_gene_col] = pb[self.pbulk_gene_col].astype(str)

        print("  Merging genomic regions with pseudobulk target data...")
        print(f"    Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pbulk_gene_col}'")
        self.samples_df = regs.merge(
            pb,
            left_on=self.regions_gene_col,
            right_on=self.pbulk_gene_col,
            how='inner'
        )

        if filter_drugs_by_ids and self.pbulk_drug_col in self.samples_df.columns:
            print(f"  Filtering samples to include only drugs: {filter_drugs_by_ids}")
            initial_count = len(self.samples_df)
            self.samples_df = self.samples_df[self.samples_df[self.pbulk_drug_col].isin(filter_drugs_by_ids)]
            print(f"  Retained {len(self.samples_df)} samples after drug filtering (from {initial_count}).")
            if len(self.samples_df) == 0:
                print("WARNING: No samples remaining after filtering by drug IDs. Check your filter_drugs_by_ids list and the drug IDs in the pbulk data.")

        # Filter by target_set if specified
        if self.target_set:
            if self.regions_set_col in self.samples_df.columns:
                print(f"  Filtering samples for set: '{self.target_set}' using column '{self.regions_set_col}'.")
                initial_count_set_filter = len(self.samples_df)
                self.samples_df = self.samples_df[self.samples_df[self.regions_set_col] == self.target_set].copy()
                print(f"  Retained {len(self.samples_df)} samples after filtering for set '{self.target_set}' (from {initial_count_set_filter}).")
                if len(self.samples_df) == 0:
                    print(f"WARNING: No samples remaining for this dataset instance (target_set='{self.target_set}') after filtering. Check the '{self.regions_set_col}' column in '{regions_csv_path}' for entries matching '{self.target_set}' and their overlap with the pseudobulk data.")
            else:
                print(f"WARNING: target_set '{self.target_set}' was specified, but the column '{self.regions_set_col}' was not found in the merged DataFrame. No set-specific filtering was applied; this instance will contain all data that matched the other criteria.")

        # Load drug metadata
        print(f"  Loading drug metadata from: {drug_meta_csv_path}")
        try:
            dm = pd.read_csv(drug_meta_csv_path)
            print(f"  Successfully loaded drug metadata with {len(dm)} entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Drug metadata CSV not found at {drug_meta_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading drug metadata CSV: {e}")
            raise

        # Ensure SMILES and ID columns are present and fill NA SMILES with an empty string
        if self.drug_meta_smiles_col not in dm.columns:
            raise ValueError(f"SMILES column '{self.drug_meta_smiles_col}' not found in drug metadata.")
        if self.drug_meta_id_col not in dm.columns:
            raise ValueError(f"Drug ID column '{self.drug_meta_id_col}' not found in drug metadata.")
        dm[self.drug_meta_smiles_col] = dm[self.drug_meta_smiles_col].fillna('').astype(str)
        self.drug_meta = dm.set_index(self.drug_meta_id_col)

        # FASTA reader & one-hot encoder
        self.fasta_reader = FastaReader(fasta_file_path)
        self.encoder = GenomeOneHotEncoder(sequence_length=self.seq_len)
        print("TahoeSMILESDataset initialized.")

    def _generate_morgan_fingerprint(self, smiles_string: str) -> np.ndarray:
        """Generates a Morgan fingerprint from a SMILES string using the new generator API."""
        if not smiles_string:
            return np.zeros(self.morgan_fp_nbits, dtype=np.float32)
        try:
            mol = Chem.MolFromSmiles(smiles_string)
            if mol:
                # Use the generator's NumPy helper:
                fp_array = self._morgan_gen.GetFingerprintAsNumPy(mol)
                return fp_array.astype(np.float32)
            # Unparseable SMILES: fall back to a zero vector.
            return np.zeros(self.morgan_fp_nbits, dtype=np.float32)
        except Exception:
            return np.zeros(self.morgan_fp_nbits, dtype=np.float32)

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx):
        row = self.samples_df.iloc[idx]

        # --- DNA sequence ---
        chrom = str(row[self.regions_chr_col])
        start = int(row[self.regions_start_col])
        end = int(row[self.regions_end_col])
        orig = end - start
        if self.seq_len != orig:
            # Center the target window within the original region. Deriving the
            # end from the start guarantees the exact requested length, even
            # when self.seq_len is odd.
            center = end - orig // 2
            start = center - self.seq_len // 2
            end = start + self.seq_len

        seq = self.fasta_reader.get_sequence(chrom, start, end)
        oh = self.encoder.encode(seq)
        seq_tensor = torch.tensor(oh, dtype=torch.float32)

        # --- Morgan fingerprint ---
        drug_id_for_fp = row[self.pbulk_drug_col]
        smiles_string = ''
        if drug_id_for_fp in self.drug_meta.index:
            smiles_string = self.drug_meta.loc[drug_id_for_fp, self.drug_meta_smiles_col]
            # If there are multiple entries for a drug_id, .loc returns a Series; take the first one.
            if isinstance(smiles_string, pd.Series):
                smiles_string = smiles_string.iloc[0]
        # Otherwise the SMILES string stays empty and yields a zero vector.

        morgan_fp = self._generate_morgan_fingerprint(str(smiles_string))  # Ensure it's a string
        morgan_fp_tensor = torch.tensor(morgan_fp, dtype=torch.float32)

        # --- Dose & target ---
        dose_val = float(row[self.pbulk_dose_col])
        expression_val = float(row[self.pbulk_expr_col])

        dose_tensor = torch.tensor([dose_val], dtype=torch.float32)
        tgt_tensor = torch.tensor([expression_val], dtype=torch.float32)

        # --- Metadata for logging ---
        gene_id_meta = str(row[self.pbulk_gene_col])
        drug_id_meta = str(row[self.pbulk_drug_col])
        cell_line_meta = str(row[self.pbulk_cell_line_col])

        return seq_tensor, morgan_fp_tensor, dose_tensor, tgt_tensor, gene_id_meta, drug_id_meta, cell_line_meta, chrom, start, end
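

# Example (illustrative sketch, not part of the committed module): inspecting
# the TahoeSMILESDataset output shapes. The Parquet and drug-metadata paths
# below are hypothetical; point them at your own files.
def _example_smiles_dataset() -> None:
    from torch.utils.data import DataLoader

    dataset = TahoeSMILESDataset(
        regions_csv_path=TSS_REGIONS_CSV_PATH,
        pbulk_parquet_path="data/pseudobulk_dummy.parquet",  # hypothetical path
        drug_meta_csv_path="data/drug_metadata.csv",         # hypothetical path
        fasta_file_path=GENOME_FASTA_PATH,
        target_set="train",
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    seq, fp, dose, tgt, gene_id, drug_id, cell_line, chrom, start, end = next(iter(loader))
    # seq: (batch, seq_len, 4); fp: (batch, 2048); dose/tgt: (batch, 1).
    # String metadata (gene_id, drug_id, cell_line, chrom) comes back as lists
    # under the default collate function.
    print(seq.shape, fp.shape, dose.shape, tgt.shape)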