qhuang20 and ryankeivanfar committed
Commit 7f61943 · verified · 1 Parent(s): be788c6

create scripts/datasets.py (#2)


- create scripts/datasets.py (06ac44905efe4681adf6d4de95ce3ef99b9134d1)


Co-authored-by: Ryan Keivanfar <[email protected]>

Files changed (1)
  1. scripts/datasets.py +557 -0
scripts/datasets.py ADDED
# datasets.py

import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import pyfaidx
import kipoiseq.transforms.functional
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator

# --- Global Config ---
# Enformer typically uses a 196,608 bp input sequence.
# We use a shorter input (1/4 of the usual length) to speed up training.
ENFORMER_INPUT_SEQ_LENGTH = 49_152

# Relative paths from the project root directory
GENOME_FASTA_PATH = "data/hg38.fa"
TSS_REGIONS_CSV_PATH = "data/Enformer_genomic_regions_TSSCenteredGenes_FixedOverlapRemoval.csv"

# Path to pseudobulk target data, matching the provided dummy file
PSEUDOBULK_TARGET_DATA_PATH = "data/pseudobulk_dummy.csv"

# ----------------------


class GenomeOneHotEncoder:
    """
    Encodes DNA sequences into one-hot format using kipoiseq.
    """
    def __init__(self, sequence_length: int = ENFORMER_INPUT_SEQ_LENGTH):
        self.sequence_length = sequence_length

    @staticmethod
    def _one_hot_encode(sequence: str) -> np.ndarray:
        # One-hot encodes DNA using the same code as the original Enformer paper,
        # ensuring the encoding is consistent with the representations Enformer
        # has already learned.
        return kipoiseq.transforms.functional.one_hot_dna(sequence).astype(np.float32)

    def encode(self, seq: str) -> np.ndarray:
        """
        One-hot encodes a DNA sequence using kipoiseq.

        Args:
            seq (str): The DNA sequence string. The FastaReader should ensure this
                sequence is already uppercase and of length ENFORMER_INPUT_SEQ_LENGTH.

        Returns:
            np.ndarray: A numpy array, typically (L, 4) for DNA, with the one-hot
                encoded sequence.
        """
        return GenomeOneHotEncoder._one_hot_encode(seq)
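
# Quick illustrative check of the encoder (kipoiseq orders the channels A, C, G, T):
#
#   encoder = GenomeOneHotEncoder()
#   one_hot = encoder.encode("ACGT")
#   # one_hot.shape == (4, 4), one_hot.dtype == np.float32,
#   # and one_hot[0] is the one-hot row for 'A'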


class FastaReader:
    """
    Reads sequences from a FASTA file using pyfaidx.
    Handles chromosome boundary conditions by padding with 'N'.
    """
    def __init__(self, fasta_path: str):
        self.fasta_path = fasta_path
        self.genome = None
        try:
            self.genome = pyfaidx.Fasta(self.fasta_path, sequence_always_upper=True)
            print(f"Successfully loaded and indexed genome using pyfaidx from: {self.fasta_path}")
        except pyfaidx.FastaIndexingError as e:
            print(f"Error: Could not index FASTA file at {self.fasta_path}.")
            print("Ensure it's a valid FASTA file and the .fai index can be created/read in its directory.")
            print(f"pyfaidx error: {e}")
            raise
        except FileNotFoundError:
            print(f"Error: FASTA file not found at {self.fasta_path}.")
            raise

    def get_sequence(self, chrom: str, start_0based: int, end_0based_exclusive: int) -> str:
        """
        Fetches a DNA sequence for the given 0-based genomic interval.
        Pads with 'N' if the interval extends beyond chromosome boundaries.

        Args:
            chrom (str): Chromosome name (e.g., 'chr1').
            start_0based (int): 0-based start coordinate (inclusive).
            end_0based_exclusive (int): 0-based end coordinate (exclusive).

        Returns:
            str: The DNA sequence, padded with 'N's to match the requested length
                (end_0based_exclusive - start_0based).
        """
        if self.genome is None:
            raise RuntimeError("FastaReader not properly initialized (pyfaidx missing or genome loading failed).")

        # Sanitize chromosome name (e.g., '1' vs 'chr1')
        true_chrom_name = chrom
        if chrom not in self.genome:
            alternative_chrom_name = 'chr' + chrom if not chrom.startswith('chr') else chrom.replace('chr', '', 1)
            if alternative_chrom_name in self.genome:
                true_chrom_name = alternative_chrom_name
            else:
                available_chroms_sample = list(self.genome.keys())[:5]
                raise ValueError(
                    f"Chromosome '{chrom}' (and alternative '{alternative_chrom_name}') not found in FASTA file. "
                    f"Available chromosomes sample: {available_chroms_sample}..."
                )

        chrom_len = len(self.genome[true_chrom_name])
        seq_len_requested = end_0based_exclusive - start_0based

        # Collect sequence parts; out-of-bounds stretches are padded with 'N'
        sequence_parts = []

        # Handle padding at the beginning
        padding_start_len = 0
        if start_0based < 0:
            padding_start_len = abs(start_0based)
            sequence_parts.append('N' * padding_start_len)
            effective_start = 0
        else:
            effective_start = start_0based

        # Determine the part of the sequence to fetch from the FASTA
        fetch_len = min(end_0based_exclusive, chrom_len) - effective_start

        if fetch_len > 0:
            sequence_parts.append(self.genome[true_chrom_name][effective_start : effective_start + fetch_len].seq)
        elif effective_start >= chrom_len:  # Requested start is beyond the chromosome end
            pass  # No sequence to fetch, only padding needed

        # Handle padding at the end
        current_len = sum(len(p) for p in sequence_parts)
        padding_end_len = seq_len_requested - current_len
        if padding_end_len > 0:
            sequence_parts.append('N' * padding_end_len)

        final_sequence = "".join(sequence_parts)

        # Final length check; this should be guaranteed by the logic above
        if len(final_sequence) != seq_len_requested:
            # This indicates a logic error in padding/fetching
            raise RuntimeError(
                f"Internal error: Final sequence length {len(final_sequence)} for {true_chrom_name}:{start_0based}-{end_0based_exclusive} "
                f"does not match requested {seq_len_requested}."
            )
        return final_sequence
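
# Illustrative sketch of the boundary handling (assumes the hg38 FASTA exists at
# the default path): a window overhanging the chromosome start comes back
# 'N'-padded to the exact requested length.
#
#   reader = FastaReader(GENOME_FASTA_PATH)
#   seq = reader.get_sequence("chr1", -5, 10)  # 15 bp requested
#   # len(seq) == 15 and seq[:5] == "NNNNN"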

# --- Main Dataset Classes ---

class TahoeDataset(Dataset):
    """
    PyTorch Dataset for loading Tahoe data for Enformer fine-tuning.
    - Reads genomic regions from a regions CSV.
    - Reads pseudobulk conditions and expression values from a pseudobulk CSV.
    - Merges these two data sources based on gene identifiers.
    - Each sample represents a unique gene-condition pair.
    - Fetches the DNA sequence for the gene, resized to `enformer_input_seq_length`.
    - One-hot encodes the sequence and returns it with the specific expression value.
    """
    ORIGINAL_ENFORMER_WINDOW_SIZE = 196_608

    def __init__(self,
                 tss_regions_csv_path: str,
                 genome_fasta_path: str,
                 pseudobulk_data_path: str,
                 enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH,
                 regions_csv_gene_col: str = 'gene_name',   # Gene ID column in tss_regions_csv
                 pseudobulk_csv_gene_col: str = 'gene_id',  # Gene ID column in pseudobulk_data_csv
                 regions_csv_chr_col: str = 'seqnames',     # Chromosome column in tss_regions_csv
                 regions_csv_start_col: str = 'starts',     # 0-based start col in tss_regions_csv
                 regions_csv_end_col: str = 'ends'):        # 0-based exclusive end col in tss_regions_csv
        super().__init__()

        self.enformer_input_seq_length = enformer_input_seq_length
        # Store column names for clarity
        self.regions_gene_col = regions_csv_gene_col
        self.pseudobulk_gene_col = pseudobulk_csv_gene_col
        self.regions_chr_col = regions_csv_chr_col
        self.regions_start_col = regions_csv_start_col
        self.regions_end_col = regions_csv_end_col

        print("Initializing TahoeDataset...")
        print(f"  Target model input sequence length: {self.enformer_input_seq_length} bp")
        print(f"  Genomic regions are assumed to define a {self.ORIGINAL_ENFORMER_WINDOW_SIZE} bp window for centering.")

        # Load genomic regions data
        print(f"  Loading TSS regions from: {tss_regions_csv_path}")
        try:
            regions_df = pd.read_csv(tss_regions_csv_path)
            print(f"  Successfully loaded regions CSV with {len(regions_df)} gene region entries.")
            expected_region_cols = [self.regions_chr_col, self.regions_gene_col,
                                    self.regions_start_col, self.regions_end_col]
            missing_region_cols = [col for col in expected_region_cols if col not in regions_df.columns]
            if missing_region_cols:
                raise ValueError(f"Missing columns in regions CSV ('{tss_regions_csv_path}'): {missing_region_cols}. Expected: {expected_region_cols}")
        except FileNotFoundError:
            print(f"FATAL ERROR: Regions CSV file not found at {tss_regions_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or validating regions CSV: {e}")
            raise

        # Load pseudobulk target data
        print(f"  Loading pseudobulk targets from: {pseudobulk_data_path}")
        try:
            pseudobulk_df = pd.read_csv(pseudobulk_data_path)
            print(f"  Successfully loaded pseudobulk CSV with {len(pseudobulk_df)} condition entries.")
            expected_pb_cols = [self.pseudobulk_gene_col, 'cell_line', 'drug_id', 'drug_dose', 'expression']
            missing_pb_cols = [col for col in expected_pb_cols if col not in pseudobulk_df.columns]
            if missing_pb_cols:
                raise ValueError(f"Missing columns in pseudobulk CSV ('{pseudobulk_data_path}'): {missing_pb_cols}. Expected: {expected_pb_cols}")
        except FileNotFoundError:
            print(f"FATAL ERROR: Pseudobulk CSV file not found at {pseudobulk_data_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or validating pseudobulk CSV: {e}")
            raise

        # Merge regions with pseudobulk data
        print("  Merging genomic regions with pseudobulk target data...")
        print(f"  Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pseudobulk_gene_col}'")

        regions_df[self.regions_gene_col] = regions_df[self.regions_gene_col].astype(str)
        pseudobulk_df[self.pseudobulk_gene_col] = pseudobulk_df[self.pseudobulk_gene_col].astype(str)

        self.samples_df = pd.merge(
            regions_df,
            pseudobulk_df,
            left_on=self.regions_gene_col,
            right_on=self.pseudobulk_gene_col,
            how='inner'  # Keeps only genes present in both DataFrames
        )
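
        # For intuition: if the regions CSV contains genes {A, B} and the
        # pseudobulk CSV contains rows (A, drug1), (A, drug2), (C, drug1),
        # the inner merge yields two samples, (A, drug1) and (A, drug2);
        # B and C are dropped (and reported below).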

        if len(self.samples_df) == 0:
            print("WARNING: The merge operation resulted in an empty DataFrame.")
            print(f"  No common genes found between column '{self.regions_gene_col}' in the regions CSV")
            print(f"  and column '{self.pseudobulk_gene_col}' in the pseudobulk CSV.")
            print("  Please check that gene identifiers match and are of the same type in both files.")
            # Example gene IDs for debugging:
            if not regions_df.empty:
                print(f"  Sample gene IDs from regions CSV: {regions_df[self.regions_gene_col].unique()[:5].tolist()}")
            if not pseudobulk_df.empty:
                print(f"  Sample gene IDs from pseudobulk CSV: {pseudobulk_df[self.pseudobulk_gene_col].unique()[:5].tolist()}")
        else:
            print(f"  Successfully merged data: {len(self.samples_df)} total samples (gene-condition pairs).")

        # Check for genes in regions_df not found in pseudobulk_df (and thus dropped)
        original_region_genes = set(regions_df[self.regions_gene_col].unique())
        merged_region_genes = set(self.samples_df[self.regions_gene_col].unique())
        dropped_region_genes = original_region_genes - merged_region_genes
        if dropped_region_genes:
            print(f"  WARNING: {len(dropped_region_genes)} unique gene IDs from the regions CSV ('{self.regions_gene_col}') were not found in the pseudobulk CSV ('{self.pseudobulk_gene_col}') and were dropped.")
            print(f"    Examples of dropped region gene IDs: {list(dropped_region_genes)[:min(5, len(dropped_region_genes))]}")

        # Check for genes in pseudobulk_df not found in regions_df (and thus dropped)
        original_pseudobulk_genes = set(pseudobulk_df[self.pseudobulk_gene_col].unique())
        final_merged_keys_from_pseudobulk_perspective = set(self.samples_df[self.pseudobulk_gene_col].unique())
        dropped_pseudobulk_genes = original_pseudobulk_genes - final_merged_keys_from_pseudobulk_perspective

        if dropped_pseudobulk_genes:
            print(f"  WARNING: {len(dropped_pseudobulk_genes)} unique gene IDs from the pseudobulk CSV ('{self.pseudobulk_gene_col}') were not found in the regions CSV ('{self.regions_gene_col}') and were dropped.")
            print(f"    Examples of dropped pseudobulk gene IDs: {list(dropped_pseudobulk_genes)[:min(5, len(dropped_pseudobulk_genes))]}")

        if 'expression' in self.samples_df and self.samples_df['expression'].isnull().any():
            print("WARNING: NA values found in 'expression' column after merge. These samples might cause errors or yield NaN targets.")
            print("  Consider handling these (e.g., fill with a default or drop rows via dropna(subset=['expression'])).")
            # self.samples_df.dropna(subset=['expression'], inplace=True)  # Example: drop rows with NA expression

        print(f"  Initializing FASTA reader for genome: {genome_fasta_path}")
        self.fasta_reader = FastaReader(genome_fasta_path)
        self.encoder = GenomeOneHotEncoder(sequence_length=self.enformer_input_seq_length)
        print("TahoeDataset initialized successfully.")

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx: int):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if not (0 <= idx < len(self.samples_df)):
            raise IndexError(f"Index {idx} out of bounds for dataset of length {len(self.samples_df)}")

        sample_info = self.samples_df.iloc[idx]
        gene_name_for_logging = "<unknown>"  # Set before the try so the except blocks can always reference it

        try:
            chrom = str(sample_info[self.regions_chr_col])
            # Gene name from the regions CSV (used for the merge, should be consistent)
            gene_name_for_logging = str(sample_info[self.regions_gene_col])

            csv_region_start = int(sample_info[self.regions_start_col])
            csv_region_end = int(sample_info[self.regions_end_col])

            expression_value = float(sample_info['expression'])  # 'expression' is the target column
        except KeyError as e:
            print(f"FATAL ERROR in __getitem__ (idx {idx}): Missing expected column {e} in merged samples_df.")
            print(f"  Available columns: {self.samples_df.columns.tolist()}")
            print(f"  Sample info for this index: {sample_info.to_dict() if isinstance(sample_info, pd.Series) else sample_info}")
            raise
        except ValueError as e:
            print(f"FATAL ERROR in __getitem__ (idx {idx}): Could not convert data for gene {gene_name_for_logging}. Error: {e}")
            print(f"  Expression value was: '{sample_info.get('expression', 'N/A')}'")
            raise
        except Exception as e:  # Catch any other unexpected error for this item
            print(f"FATAL ERROR in __getitem__ (idx {idx}) for gene {gene_name_for_logging}: An unexpected error occurred: {type(e).__name__} - {e}")
            raise

        # --- Sequence window calculation ---
        actual_csv_window_len = csv_region_end - csv_region_start
        if actual_csv_window_len != self.ORIGINAL_ENFORMER_WINDOW_SIZE:
            # Warn if the input CSV regions are not consistently 196 kb.
            # The centering logic below will still work based on csv_region_end.
            print(f"WARNING for gene {gene_name_for_logging} (idx {idx}): Region {chrom}:{csv_region_start}-{csv_region_end} from CSV "
                  f"has length {actual_csv_window_len} bp, but expected {self.ORIGINAL_ENFORMER_WINDOW_SIZE} bp "
                  f"for the original window definition used for centering. Sequence extraction might be affected if assumptions are wrong.")

        # Initialize final sequence coordinates with those from the CSV.
        # These are used as-is if no resizing is needed.
        final_seq_start_0based = csv_region_start
        final_seq_end_0based_exclusive = csv_region_end

        # If the target model input sequence length differs from the original Enformer
        # window size, recalculate the start and end positions by centering the target
        # length within the original window.
        if self.enformer_input_seq_length != self.ORIGINAL_ENFORMER_WINDOW_SIZE:
            # Calculate the center of the ORIGINAL_ENFORMER_WINDOW_SIZE window.
            # Assumes 'csv_region_end' is the exclusive end of this original window.
            original_window_center = csv_region_end - (self.ORIGINAL_ENFORMER_WINDOW_SIZE // 2)

            half_target_seq_len = self.enformer_input_seq_length // 2
            final_seq_start_0based = original_window_center - half_target_seq_len
            # Ensure the end is exclusive and maintains the correct target sequence length
            final_seq_end_0based_exclusive = final_seq_start_0based + self.enformer_input_seq_length
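
        # Worked example of the centering above: for a CSV window of
        # chr1:0-196,608 and a 49,152 bp model input, the window center is
        # 196,608 - 98,304 = 98,304, so the fetched interval becomes
        # 98,304 - 24,576 = 73,728 to 73,728 + 49,152 = 122,880.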

        # Fetch and encode DNA sequence
        dna_sequence = self.fasta_reader.get_sequence(chrom, final_seq_start_0based, final_seq_end_0based_exclusive)
        one_hot_sequence = self.encoder.encode(dna_sequence)
        one_hot_sequence_tensor = torch.tensor(one_hot_sequence, dtype=torch.float32)

        # Target is the specific expression value for this gene-condition pair
        target_tensor = torch.tensor([expression_value], dtype=torch.float32)

        return one_hot_sequence_tensor, target_tensor
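
# Minimal usage sketch for TahoeDataset with a PyTorch DataLoader (paths are the
# module-level defaults above; the batch size is arbitrary):
#
#   from torch.utils.data import DataLoader
#   dataset = TahoeDataset(
#       tss_regions_csv_path=TSS_REGIONS_CSV_PATH,
#       genome_fasta_path=GENOME_FASTA_PATH,
#       pseudobulk_data_path=PSEUDOBULK_TARGET_DATA_PATH,
#   )
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   seqs, targets = next(iter(loader))
#   # seqs: (8, ENFORMER_INPUT_SEQ_LENGTH, 4); targets: (8, 1)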


# --- Extended Dataset for SMILES ---
class TahoeSMILESDataset(Dataset):
    """
    Extends TahoeDataset to also return:
    - Morgan fingerprints for the drug
    - drug dose
    - target expression
    """
    def __init__(self,
                 regions_csv_path: str,      # Renamed from tss_regions_csv_path for clarity with config
                 pbulk_parquet_path: str,    # Renamed from pseudobulk_data_path for clarity with config
                 drug_meta_csv_path: str,    # Renamed from drug_metadata_path for clarity with config
                 fasta_file_path: str,       # Renamed from genome_fasta_path for clarity with config
                 enformer_input_seq_length: int = ENFORMER_INPUT_SEQ_LENGTH,
                 # Morgan fingerprint parameters (from data_config)
                 morgan_fp_radius: int = 2,
                 morgan_fp_nbits: int = 2048,
                 # Column names from regions_csv (from data_config)
                 regions_gene_col: str = 'gene_name',
                 regions_chr_col: str = 'seqnames',
                 regions_start_col: str = 'starts',
                 regions_end_col: str = 'ends',
                 # Column names from pbulk_parquet (from data_config)
                 pbulk_gene_col: str = 'gene_id',
                 pbulk_drug_col: str = 'drug_id',
                 pbulk_dose_col: str = 'drug_dose',
                 pbulk_expr_col: str = 'expression',
                 pbulk_cell_line_col: str = 'cell_line',
                 # Column names from drug_meta_csv (from data_config)
                 drug_meta_id_col: str = 'drug',
                 drug_meta_smiles_col: str = 'canonical_smiles',
                 filter_drugs_by_ids: list = None,   # Added from dataset_args
                 regions_strand_col: str = None,     # Added from dataset_args, though not used in the current __getitem__
                 regions_set_col: str = 'set',       # Name of the column in regions_csv used for data splitting
                 target_set: str = None):            # Specific set to load (e.g., "train", "valid", "test")
        super().__init__()

        # Store config
        self.seq_len = enformer_input_seq_length
        self.morgan_fp_radius = morgan_fp_radius
        self.morgan_fp_nbits = morgan_fp_nbits

        self.regions_gene_col = regions_gene_col
        self.regions_chr_col = regions_chr_col
        self.regions_start_col = regions_start_col
        self.regions_end_col = regions_end_col
        self.regions_set_col = regions_set_col  # Name of the set column

        self.pbulk_gene_col = pbulk_gene_col
        self.pbulk_drug_col = pbulk_drug_col
        self.pbulk_dose_col = pbulk_dose_col
        self.pbulk_expr_col = pbulk_expr_col
        self.pbulk_cell_line_col = pbulk_cell_line_col

        self.drug_meta_id_col = drug_meta_id_col
        self.drug_meta_smiles_col = drug_meta_smiles_col

        self.target_set = target_set  # Specific set value for this instance

        # --- Morgan fingerprint generator ---
        self._morgan_gen = rdFingerprintGenerator.GetMorganGenerator(
            radius=self.morgan_fp_radius,
            fpSize=self.morgan_fp_nbits
        )

        # Load & merge regions + pseudobulk
        print(f"  Loading TSS regions from: {regions_csv_path}")
        try:
            regs = pd.read_csv(regions_csv_path)
            print(f"  Successfully loaded regions CSV with {len(regs)} gene region entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Regions CSV file not found at {regions_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading regions CSV: {e}")
            raise

        print(f"  Loading pseudobulk targets from: {pbulk_parquet_path} (expected Parquet format)")
        try:
            pb = pd.read_parquet(pbulk_parquet_path)
            print(f"  Successfully loaded pseudobulk Parquet file with {len(pb)} entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Pseudobulk Parquet file not found at {pbulk_parquet_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading or parsing pseudobulk Parquet file: {e}")
            print("  Ensure the file is a valid Parquet file and a Parquet engine like 'pyarrow' or 'fastparquet' is installed.")
            raise

        # Ensure gene ID columns are strings for merging
        regs[self.regions_gene_col] = regs[self.regions_gene_col].astype(str)
        pb[self.pbulk_gene_col] = pb[self.pbulk_gene_col].astype(str)

        print("  Merging genomic regions with pseudobulk target data...")
        print(f"  Regions gene column: '{self.regions_gene_col}', Pseudobulk gene column: '{self.pbulk_gene_col}'")
        self.samples_df = regs.merge(
            pb,
            left_on=self.regions_gene_col,
            right_on=self.pbulk_gene_col,
            how='inner'
        )

        if filter_drugs_by_ids and self.pbulk_drug_col in self.samples_df.columns:
            print(f"  Filtering samples to include only drugs: {filter_drugs_by_ids}")
            initial_count = len(self.samples_df)
            self.samples_df = self.samples_df[self.samples_df[self.pbulk_drug_col].isin(filter_drugs_by_ids)]
            print(f"  Retained {len(self.samples_df)} samples after drug filtering (from {initial_count}).")
            if len(self.samples_df) == 0:
                print("WARNING: No samples remaining after filtering by drug IDs. Check your filter_drugs_by_ids list and the drug IDs in the pbulk data.")

        # Filter by target_set if specified
        if self.target_set:
            if self.regions_set_col in self.samples_df.columns:
                print(f"  Filtering samples for set: '{self.target_set}' using column '{self.regions_set_col}'.")
                initial_count_set_filter = len(self.samples_df)
                self.samples_df = self.samples_df[self.samples_df[self.regions_set_col] == self.target_set].copy()
                print(f"  Retained {len(self.samples_df)} samples after filtering for set '{self.target_set}' (from {initial_count_set_filter}).")
                if len(self.samples_df) == 0:
                    print(f"WARNING: No samples remaining for this dataset instance (target_set='{self.target_set}') after filtering. Check the '{self.regions_set_col}' column in '{regions_csv_path}' for entries matching '{self.target_set}' and their overlap with the pseudobulk data.")
            else:
                print(f"WARNING: target_set '{self.target_set}' was specified, but the column '{self.regions_set_col}' was not found in the merged DataFrame. No set-specific filtering was applied; this instance will contain all data that matched the other criteria.")

        # Load drug metadata
        print(f"  Loading drug metadata from: {drug_meta_csv_path}")
        try:
            dm = pd.read_csv(drug_meta_csv_path)
            print(f"  Successfully loaded drug metadata with {len(dm)} entries.")
        except FileNotFoundError:
            print(f"FATAL ERROR: Drug metadata CSV not found at {drug_meta_csv_path}")
            raise
        except Exception as e:
            print(f"FATAL ERROR loading drug metadata CSV: {e}")
            raise

        # Ensure the SMILES and ID columns are present, and fill NA SMILES with the empty string
        if self.drug_meta_smiles_col not in dm.columns:
            raise ValueError(f"SMILES column '{self.drug_meta_smiles_col}' not found in drug metadata.")
        if self.drug_meta_id_col not in dm.columns:
            raise ValueError(f"Drug ID column '{self.drug_meta_id_col}' not found in drug metadata.")
        dm[self.drug_meta_smiles_col] = dm[self.drug_meta_smiles_col].fillna('').astype(str)
        self.drug_meta = dm.set_index(self.drug_meta_id_col)

        # FASTA reader & one-hot encoder
        self.fasta_reader = FastaReader(fasta_file_path)
        self.encoder = GenomeOneHotEncoder(sequence_length=self.seq_len)
        print("TahoeSMILESDataset initialized.")

    def _generate_morgan_fingerprint(self, smiles_string: str) -> np.ndarray:
        """Generates a Morgan fingerprint from a SMILES string using RDKit's fingerprint generator API."""
        if not smiles_string:
            return np.zeros(self.morgan_fp_nbits, dtype=np.float32)
        try:
            mol = Chem.MolFromSmiles(smiles_string)
            if mol:
                # Use the generator's NumPy helper:
                fp_array = self._morgan_gen.GetFingerprintAsNumPy(mol)
                return fp_array.astype(np.float32)
            else:
                return np.zeros(self.morgan_fp_nbits, dtype=np.float32)
        except Exception:
            # Fall back to a zero vector if RDKit fails on this SMILES
            return np.zeros(self.morgan_fp_nbits, dtype=np.float32)
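
    # Illustrative sketch of the fingerprint helper (aspirin SMILES used as an
    # arbitrary example): a parseable SMILES yields a 0/1 vector of length
    # morgan_fp_nbits; an empty or unparsable SMILES yields all zeros.
    #
    #   fp = dataset._generate_morgan_fingerprint("CC(=O)OC1=CC=CC=C1C(=O)O")
    #   # fp.shape == (2048,) with the default nbits; fp.dtype == np.float32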

    def __len__(self):
        return len(self.samples_df)

    def __getitem__(self, idx):
        row = self.samples_df.iloc[idx]

        # --- DNA sequence ---
        chrom = str(row[self.regions_chr_col])
        start = int(row[self.regions_start_col])
        end = int(row[self.regions_end_col])
        orig = end - start
        if self.seq_len != orig:
            # Center the target window inside the original region; anchoring the
            # end at start + seq_len keeps the length exact even for odd seq_len.
            center = end - orig // 2
            half = self.seq_len // 2
            start = center - half
            end = start + self.seq_len

        seq = self.fasta_reader.get_sequence(chrom, start, end)
        oh = self.encoder.encode(seq)
        seq_tensor = torch.tensor(oh, dtype=torch.float32)

        # --- Morgan fingerprint ---
        drug_id_for_fp = row[self.pbulk_drug_col]
        smiles_string = ''
        if drug_id_for_fp in self.drug_meta.index:
            smiles_string = self.drug_meta.loc[drug_id_for_fp, self.drug_meta_smiles_col]
            # If there are multiple entries for a drug_id, .loc returns a Series; take the first one.
            if isinstance(smiles_string, pd.Series):
                smiles_string = smiles_string.iloc[0]
        else:
            # print(f"Warning: Drug ID {drug_id_for_fp} not found in drug_meta. Using empty SMILES for fingerprint.")
            pass  # The SMILES string remains empty, which yields a zero vector

        morgan_fp = self._generate_morgan_fingerprint(str(smiles_string))  # Ensure it's a string
        morgan_fp_tensor = torch.tensor(morgan_fp, dtype=torch.float32)

        # --- Dose & target ---
        dose_val = float(row[self.pbulk_dose_col])
        expression_val = float(row[self.pbulk_expr_col])

        dose_tensor = torch.tensor([dose_val], dtype=torch.float32)
        tgt_tensor = torch.tensor([expression_val], dtype=torch.float32)

        # --- Metadata for logging ---
        gene_id_meta = str(row[self.pbulk_gene_col])
        drug_id_meta = str(row[self.pbulk_drug_col])
        cell_line_meta = str(row[self.pbulk_cell_line_col])

        return seq_tensor, morgan_fp_tensor, dose_tensor, tgt_tensor, gene_id_meta, drug_id_meta, cell_line_meta, chrom, start, end
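
# Minimal usage sketch for TahoeSMILESDataset (the Parquet and drug-metadata
# paths below are placeholders, not files shipped with this script):
#
#   train_ds = TahoeSMILESDataset(
#       regions_csv_path=TSS_REGIONS_CSV_PATH,
#       pbulk_parquet_path="data/pseudobulk.parquet",   # placeholder path
#       drug_meta_csv_path="data/drug_metadata.csv",    # placeholder path
#       fasta_file_path=GENOME_FASTA_PATH,
#       target_set="train",
#   )
#   (seq, fp, dose, tgt,
#    gene_id, drug_id, cell_line, chrom, start, end) = train_ds[0]
#   # seq: (seq_len, 4) one-hot DNA; fp: (2048,) Morgan bits; dose/tgt: (1,) floats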