# app.py
"""Gradio demo that classifies a protein sequence into a protein family.

Loads a fine-tuned ESM-2 sequence-classification model, its tokenizer and a
pickled label encoder from the Hugging Face Hub, then serves a Gradio
interface showing the top predicted families with their softmax
probabilities.
"""
import pickle

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# =============================================================================
# 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
# =============================================================================
# Hugging Face Hub repository holding the model, tokenizer and label encoder.
model_path = "Tarive/esm2_t12_35M_UR50D-5k-families-balanced-augmented-weighted_optimized"

# Must match the max_length used during training; longer inputs are
# truncated to this many tokens (special tokens included).
MAX_LENGTH = 256
# Number of family predictions returned to / displayed by the UI.
TOP_K = 5

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Move model to GPU if available for faster inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Evaluation mode (disables dropout). from_pretrained normally sets this
# already; being explicit guards against that default ever changing.
model.eval()
print(f"Model loaded on device: {device}")

# Download and load the label encoder that maps class indices to family names.
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder_5k-2.pkl")
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# This is acceptable only because the file comes from the same repository we
# already trust for the model weights; never point this at an untrusted repo.
with open(encoder_path, "rb") as f:
    label_encoder = pickle.load(f)
print("Label encoder loaded.")


# =============================================================================
# 2. PREDICTION
# =============================================================================
def _clean_sequence(sequence):
    """Normalize raw textbox input into a bare, uppercase residue string.

    Drops FASTA header lines (lines starting with ">"), removes all
    whitespace, including line breaks from pasted multi-line sequences, and
    uppercases the result. Uppercasing matters because the ESM-2 tokenizer
    vocabulary uses uppercase residue letters; lowercase residues would
    otherwise be mapped to unknown tokens.

    Returns "" for None/empty input or input consisting only of headers
    and whitespace.
    """
    if not sequence:
        return ""
    residue_lines = (
        line for line in sequence.splitlines() if not line.lstrip().startswith(">")
    )
    return "".join("\n".join(residue_lines).split()).upper()


def predict_family(sequence):
    """Predict the most likely protein families for an amino-acid sequence.

    Args:
        sequence: Raw text from the Gradio textbox. It may be:
            - a bare sequence or FASTA (header lines starting with ">" are
              ignored);
            - in any letter case;
            - spread over several lines.
            Inputs longer than ``MAX_LENGTH`` tokens are truncated.

    Returns:
        dict mapping family name -> softmax probability for the ``TOP_K``
        highest-scoring classes (fewer if the model has fewer labels), in
        descending order of probability. This is the format ``gr.Label``
        expects.

    Raises:
        gr.Error: if the input contains no sequence characters. Gradio shows
            this message in the UI instead of a stack trace.
    """
    cleaned = _clean_sequence(sequence)
    if not cleaned:
        raise gr.Error("Please enter a protein amino acid sequence.")

    # Tokenize with the same settings as training.
    inputs = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)  # inputs must live on the same device as the model

    with torch.no_grad():  # inference only; skip gradient bookkeeping
        logits = model(**inputs).logits

    # Exactly one sequence was tokenized, so take row 0 rather than calling
    # squeeze(): squeeze() would also drop the class axis if k were 1.
    probabilities = torch.softmax(logits[0], dim=-1)
    # Softmax is monotonic, so top-k over probabilities equals top-k over
    # logits. Clamp k so topk cannot fail on a model with few labels.
    k = min(TOP_K, probabilities.shape[-1])
    top = torch.topk(probabilities, k)
    top_scores = top.values.tolist()
    top_indices = top.indices.tolist()

    # Decode all class indices to family names in a single call.
    family_names = label_encoder.inverse_transform(top_indices)
    return {
        str(name): float(score) for name, score in zip(family_names, top_scores)
    }


# =============================================================================
# 3. CREATE THE GRADIO INTERFACE
# =============================================================================
print("Creating Gradio interface...")
iface = gr.Interface(
    fn=predict_family,
    inputs=gr.Textbox(
        lines=10,
        label="Protein Amino Acid Sequence",
        placeholder="Paste your protein sequence here...",
    ),
    outputs=gr.Label(
        num_top_classes=TOP_K,
        label="Predicted Families",
    ),
    title="Protein Family Classifier",
    description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
    examples=[
        ["LAAARMRPQDIDRFVPHQANARIFDAVGRNLGIADEAIVKTIAEYGNSSAATIPLSLSLAHRAAPFRPGEKVLLAAAGAGLSGGALVVGI"],
        ["MSLPDMRLPIQNAIFYPEMVNYTFNRLDLTSISCLTFEKPKRDLFRAIDVCEWVASMGNPYVSVLLGADDKAVELFLEGKIGFLDIPVLIESVLSSVNFHIEENLEDILRAV"],
        ["VSYISSQYPHHPDVFSVVRQACVRSLSCEVCPGREGPIFFGDEHRSHVFSHTFFLKDSQARGFQRWYSIVMVMMDKVFLLNSWPFLVKQIRNFIDQLQAKANKVYFSEQTDCPQRALRLKSSFTMTPANFRRQRSNISVRGLYELTNDKQVFYTAHVWFTWILKAC"],
    ],
    allow_flagging="never",
)

# Launch the interface!
print("Launching app...")
iface.launch()