Spaces:
Sleeping
Sleeping
File size: 3,744 Bytes
0dee19d 7707dbe 0dee19d 2132cad 0dee19d 7707dbe 0dee19d 7707dbe 2132cad 2897cc2 c780aad 0dee19d 2132cad 0dee19d 7707dbe 0dee19d 2897cc2 c780aad 2132cad c780aad 2132cad 7707dbe 0dee19d 7707dbe 0dee19d 7707dbe 0dee19d 7707dbe 0dee19d c780aad 0dee19d 7707dbe 0dee19d 7707dbe c780aad 7707dbe 43cf177 7707dbe 0dee19d 7707dbe c780aad 7707dbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# app.py (Final, Robust Version)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pickle
from huggingface_hub import hf_hub_download
# =============================================================================
# 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
# =============================================================================
# Hugging Face Hub repository that holds the fine-tuned model, its tokenizer
# and the pickled label encoder used to map class indices to family names.
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# Move model to GPU if available for faster inference
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# This app only runs inference. from_pretrained already returns the model in
# eval mode; calling eval() explicitly documents that dropout must stay off.
model.eval()
print(f"Model loaded on device: {device}")
# Download and load the label encoder
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder.pkl")
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. This is only acceptable because the file comes from the model's own
# repository; never point this at a repository you do not control or trust.
with open(encoder_path, "rb") as f:
    label_encoder = pickle.load(f)
print("Label encoder loaded.")
# =============================================================================
# 2. DEFINE THE LOW-LEVEL PREDICTION FUNCTION
# =============================================================================
def _clean_sequence(sequence):
    """Normalise raw textbox input into a bare, upper-case residue string.

    FASTA header lines (lines starting with ">") are dropped, every whitespace
    character (including line breaks from multi-line pastes) is removed, and the
    result is upper-cased. Returns "" for None or empty input.
    """
    if not sequence:
        return ""
    residue_lines = (
        line
        for line in sequence.splitlines()
        if not line.lstrip().startswith(">")  # skip FASTA header lines
    )
    # str.split() with no argument splits on any whitespace run, so joining the
    # pieces removes spaces, tabs and newlines in one pass.
    return "".join("".join(residue_lines).split()).upper()


# This function manually replicates the training data processing steps.
def predict_family(sequence, top_k=5):
    """Predict the most likely protein families for an amino-acid sequence.

    Args:
        sequence: Raw text from the input box. It may contain line breaks,
            lower-case residues or a FASTA header; these are normalised before
            tokenization.
        top_k: Number of highest-scoring families to return. Defaults to 5,
            matching the number of classes shown by the output component.
            It is clamped to the number of labels the model has.

    Returns:
        A dict mapping family name (str) to softmax probability (float) for the
        ``top_k`` highest-scoring families, ordered from most to least likely.
        An empty dict is returned when the input contains no residues.
    """
    cleaned = _clean_sequence(sequence)
    if not cleaned:
        # Nothing to classify: don't run the model on an empty sequence.
        return {}

    # 1. Tokenize with the same settings that were used during training.
    inputs = tokenizer(
        cleaned,
        return_tensors="pt",  # Return PyTorch tensors
        truncation=True,
        padding=True,
        max_length=256,  # Must match the training max_length
    ).to(device)  # Inputs must be on the same device as the model

    # 2. Forward pass without building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # 3. One sequence was tokenized, so take row 0. Softmax is monotonic, so
    #    running top-k on the probabilities yields the same ranking as on the
    #    logits, and the scores come out of the same call.
    probabilities = torch.softmax(logits[0], dim=-1)
    k = min(top_k, probabilities.shape[-1])
    top = torch.topk(probabilities, k)

    # 4. Decode all selected class indices in a single encoder call.
    family_names = label_encoder.inverse_transform(top.indices.tolist())
    return {
        str(name): score
        for name, score in zip(family_names, top.values.tolist())
    }
# =============================================================================
# 3. CREATE THE GRADIO INTERFACE
# =============================================================================
print("Creating Gradio interface...")
iface = gr.Interface(
    fn=predict_family,
    inputs=gr.Textbox(
        lines=10,
        label="Protein Amino Acid Sequence",
        placeholder="Paste your protein sequence here..."
    ),
    outputs=gr.Label(
        num_top_classes=5,
        label="Predicted Families"
    ),
    title="Protein Family Classifier",
    description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
    examples=[
        ["MLLVLKISRNAITTFSKEQLDSF"],
        ["SNYRPFVFKENDEVLALMAVWEFDDFIYVEHLAVDSKLRGKGVGSELIKNYLNRCDKRVFLEVEPPNCEISKKRVSFYEKLGFSF"],
        ["KRAIDLLLTLGSAILVLPLVLAIAAWIRMDSPGSPFFTQRRIGQHGREMHILKFRTMVQNAECVLHDCLAANPALNAEWERDQKLKCDPRVTRAGAFLRKTSLDELPQLWNVLRGEMSLVGPRPIVQDEVEKYGEVFDLYTRVKPGITGLWQVSGRNDVSYPQRVEMDRYYICNWSVWFDIWILAKTVPVVLH"]
    ],
    allow_flagging="never"
)

# Only start the web server when run as a script (e.g. `python app.py`), so the
# module can be imported without launching the app as a side effect.
if __name__ == "__main__":
    # Launch the interface!
    print("Launching app...")
    iface.launch()