# Spaces:
# Sleeping
# Sleeping
# app.py (Updated Version)
import pickle
import re
from pathlib import Path

import gradio as gr
from transformers import pipeline
# =============================================================================
# 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
# =============================================================================
# Hugging Face Hub repository of the fine-tuned ESM-2 classifier.
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"  # Make sure this is correct

# Text-classification pipeline: takes a raw sequence string and returns
# label/score dicts (consumed by predict_family below).
classifier = pipeline("text-classification", model=model_path)

# Resolve the encoder file next to this script instead of relative to the
# current working directory, so the app still starts when launched from
# another folder (e.g. `python path/to/app.py`).
# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# ship a label_encoder.pkl that you produced yourself.
LABEL_ENCODER_PATH = Path(__file__).resolve().parent / "label_encoder.pkl"
with LABEL_ENCODER_PATH.open("rb") as f:
    label_encoder = pickle.load(f)
# =============================================================================
# 2. DEFINE THE PREDICTION FUNCTION WITH LABEL DECODING
# =============================================================================
# Generic class names the pipeline reports for class ids, e.g. "LABEL_455".
_GENERIC_LABEL_RE = re.compile(r"LABEL_(\d+)")


def _clean_sequence(sequence):
    """Normalize pasted text into a single upper-case residue string.

    FASTA header lines (those starting with ">") and all whitespace are
    removed, so a full FASTA record or a line-wrapped sequence is read as one
    chain. Returns "" for None or blank input.
    """
    if not sequence:
        return ""
    residue_lines = [
        line for line in sequence.splitlines()
        if not line.lstrip().startswith(">")
    ]
    # ESM-2 tokenizers expect upper-case residue letters (the example
    # sequences are upper-case); lower-case input would presumably be mapped
    # to unknown tokens.
    return re.sub(r"\s+", "", "".join(residue_lines)).upper()


def _decode_label(label):
    """Map a pipeline label such as "LABEL_455" back to its original name.

    A label that is not of the generic ``LABEL_<n>`` form (e.g. when the model
    config already carries readable id2label names) is returned unchanged
    instead of raising ValueError.
    """
    match = _GENERIC_LABEL_RE.fullmatch(label)
    if match is None:
        return label
    # label_encoder comes from section 1; presumably a scikit-learn
    # LabelEncoder fitted on the training labels -- confirm.
    return label_encoder.inverse_transform([int(match.group(1))])[0]


def predict_family(sequence):
    """Return the top-5 predicted protein families for *sequence*.

    Args:
        sequence: Raw text from the input box: a bare sequence, a
            line-wrapped sequence, or a FASTA record.

    Returns:
        dict mapping decoded family name -> model score, the shape gr.Label
        displays. Empty when the input contains no residues.
    """
    cleaned = _clean_sequence(sequence)
    if not cleaned:
        # Nothing to classify; don't run the model on an empty string.
        return {}
    predictions = classifier(cleaned, top_k=5)
    return {_decode_label(p["label"]): p["score"] for p in predictions}
# =============================================================================
# 3. CREATE THE GRADIO INTERFACE
# =============================================================================
# One text box in, a ranked label widget out. num_top_classes=5 matches the
# top_k=5 that predict_family requests from the classifier.
iface = gr.Interface(
    fn=predict_family,
    inputs=gr.Textbox(
        lines=10,
        label="Protein Amino Acid Sequence",
        placeholder="Paste your protein sequence here...",
    ),
    outputs=gr.Label(
        num_top_classes=5,
        label="Predicted Families",
    ),
    title="Protein Family Classifier",
    description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
    # Clickable example inputs shown under the text box.
    examples=[
        ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
        ["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"],
    ],
    allow_flagging="never",
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module -- for tests or reuse -- does not start a web server.
if __name__ == "__main__":
    iface.launch()