Spaces:
Running
Running
| import torch | |
| import selfies as sf | |
| import numpy as np | |
| import pandas as pd | |
| from rdkit import Chem | |
| from transformers import AutoTokenizer, AutoModel | |
| import gc | |
| from torch.utils.data import DataLoader, Dataset | |
| from multiprocessing import Pool, cpu_count | |
| from tqdm import tqdm | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
class SELFIESDataset(Dataset):
    """Map-style dataset serving pre-computed SELFIES strings to a DataLoader."""

    def __init__(self, selfies_list):
        """Keep a reference to the sequence of SELFIES strings."""
        self._items = selfies_list

    def __len__(self):
        """Return how many SELFIES strings are available."""
        return len(self._items)

    def __getitem__(self, idx):
        """Return the SELFIES string stored at position ``idx``."""
        return self._items[idx]
class SELFIES(torch.nn.Module):
    """Encode SMILES strings into SELFIES-TED embeddings.

    Wraps the ``ibm/materials.selfies-ted`` encoder: SMILES are first
    translated to SELFIES strings (in parallel), then tokenized and
    mean-pooled over the encoder's last hidden state. Call :meth:`load`
    before :meth:`encode`.
    """

    def __init__(self):
        super().__init__()
        self.model = None       # set by load()
        self.tokenizer = None   # set by load()
        self.invalid = []       # indices of inputs that failed SELFIES conversion

    def smiles_to_selfies(self, smiles):
        """Convert one SMILES string to a space-separated SELFIES string.

        Falls back to RDKit canonicalization when direct encoding fails.
        Returns ``None`` if the string cannot be encoded at all.
        """
        try:
            return sf.encoder(smiles.strip()).replace('][', '] [')
        # Narrowed from bare `except:` so Ctrl-C / SystemExit still propagate.
        except Exception:
            try:
                # MolFromSmiles returns None on invalid input, which makes
                # MolToSmiles raise — caught by the outer handler below.
                canonical = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip()))
                return sf.encoder(canonical).replace('][', '] [')
            except Exception:
                return None

    def get_selfies(self, smiles_list):
        """Convert a list of SMILES to SELFIES in parallel.

        Records the indices of failed conversions in ``self.invalid`` and
        substitutes the no-op token ``'[nop]'`` so batch shapes stay intact.
        """
        with Pool(cpu_count()) as pool:
            selfies = list(pool.map(self.smiles_to_selfies, smiles_list))
        self.invalid = [i for i, s in enumerate(selfies) if s is None]
        return [s if s is not None else '[nop]' for s in selfies]

    def get_embedding_batch(self, selfies_batch):
        """Embed one batch of SELFIES strings.

        Returns a ``(batch, hidden)`` numpy array: the attention-masked
        mean over the encoder's last hidden state.
        """
        encodings = self.tokenizer(
            selfies_batch,
            return_tensors='pt',
            max_length=128,
            truncation=True,
            padding='max_length'
        )
        encodings = {k: v.to(self.model.device) for k, v in encodings.items()}
        # Inference only: without no_grad() every batch keeps an autograd
        # graph alive, inflating memory until the next gc pass.
        with torch.no_grad():
            outputs = self.model.encoder(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask']
            )
            model_output = outputs.last_hidden_state
            # Mean-pool over tokens, excluding padding via the attention mask.
            input_mask_expanded = encodings['attention_mask'].unsqueeze(-1).expand(model_output.size()).float()
            sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
            # Clamp avoids division by zero for an all-padding row.
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            pooled_output = sum_embeddings / sum_mask
        return pooled_output.cpu().numpy()

    def load(self, checkpoint=None):
        """Download tokenizer and weights from the Hugging Face hub.

        ``checkpoint`` is accepted for interface compatibility but unused.
        """
        self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
        self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
        self.model.eval()

    def encode(self, smiles_list=None, use_gpu=False, return_tensor=False, batch_size=128, num_workers=4):
        """Encode SMILES strings into pooled embeddings.

        Args:
            smiles_list: SMILES strings to encode (``None`` means empty).
            use_gpu: move the model to CUDA when available.
            return_tensor: return ``torch.Tensor`` instead of ``DataFrame``.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.

        Returns:
            ``(n, hidden)`` embeddings; rows that failed SELFIES
            conversion are filled with NaN.
        """
        # None default instead of a mutable [] default argument.
        smiles_list = [] if smiles_list is None else smiles_list
        selfies = self.get_selfies(smiles_list)
        # Empty input: np.vstack([]) would raise, so return an empty result.
        if not selfies:
            empty = np.empty((0, 0))
            return torch.tensor(empty) if return_tensor else pd.DataFrame(empty)
        dataset = SELFIESDataset(selfies)
        device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
        self.model.to(device)
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        embeddings = []
        for batch in tqdm(loader, desc="Encoding"):
            emb = self.get_embedding_batch(batch)
            embeddings.append(emb)
            del emb
            gc.collect()
        emb = np.vstack(embeddings)
        # Blank out rows whose SMILES could not be converted to SELFIES.
        for idx in self.invalid:
            emb[idx] = np.nan
            print(f"Cannot encode {smiles_list[idx]} to selfies. Embedding replaced by NaN.")
        return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)