File size: 3,744 Bytes
0dee19d
7707dbe
 
0dee19d
 
2132cad
0dee19d
7707dbe
 
0dee19d
7707dbe
2132cad
2897cc2
c780aad
0dee19d
 
2132cad
0dee19d
 
 
 
 
 
7707dbe
0dee19d
 
2897cc2
c780aad
2132cad
c780aad
2132cad
7707dbe
 
0dee19d
7707dbe
0dee19d
7707dbe
0dee19d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7707dbe
0dee19d
 
c780aad
0dee19d
 
 
 
 
 
 
7707dbe
 
 
0dee19d
7707dbe
c780aad
7707dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
43cf177
 
 
7707dbe
0dee19d
7707dbe
 
 
c780aad
7707dbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# app.py - Gradio demo that predicts the Pfam family of a protein sequence with a fine-tuned ESM-2 model.

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pickle
from huggingface_hub import hf_hub_download

# =============================================================================
# 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
# =============================================================================
# Hugging Face Hub repository that holds the fine-tuned checkpoint, its
# tokenizer, and the pickled label encoder used below.
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Run inference on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
print(f"Model loaded on device: {device}")

# Fetch the fitted encoder that maps class indices back to family names.
# NOTE(review): pickle.load executes arbitrary code embedded in the file. This
# is only acceptable because the file comes from the same repository as the
# model (repo_id=model_path); never point model_path at an untrusted repo.
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder.pkl")
with open(encoder_path, "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)
print("Label encoder loaded.")


# =============================================================================
# 2. DEFINE THE LOW-LEVEL PREDICTION FUNCTION
# =============================================================================
def _clean_sequence(sequence):
    """Normalize raw user input into a bare, upper-case amino-acid string.

    FASTA header lines (lines starting with ``>``) are dropped, all whitespace
    (including line breaks from multi-line pastes) is removed, and letters are
    upper-cased. ``None`` or an empty string yields ``""``.
    """
    if not sequence:
        return ""
    body_lines = [
        line for line in sequence.splitlines()
        if not line.lstrip().startswith(">")
    ]
    # split() with no argument splits on any whitespace run, so joining the
    # pieces strips spaces, tabs and newlines in one pass.
    return "".join(chunk for line in body_lines for chunk in line.split()).upper()


def predict_family(sequence, *, top_k=5):
    """Predict the most likely protein families for an amino-acid sequence.

    The input is tokenized with the same settings used during training, run
    through the classifier, and the ``top_k`` highest-probability classes are
    decoded back to family names with the label encoder.

    Args:
        sequence: Raw text from the input box. It may be lower-case, span
            several lines, or include a FASTA ``>`` header line; it is
            normalized before tokenization.
        top_k: Number of families to return (keyword-only). It is clamped to
            the number of classes the model outputs.

    Returns:
        A dict mapping family name to softmax probability, inserted in
        descending order of probability. It is empty when ``sequence``
        contains no residues.
    """
    cleaned = _clean_sequence(sequence)
    if not cleaned:
        # Nothing to classify; avoid reporting predictions for an empty input.
        return {}

    # 1. Tokenize with the exact same settings as training.
    inputs = tokenizer(
        cleaned,
        return_tensors="pt",  # Return PyTorch tensors
        truncation=True,
        padding=True,
        max_length=256,  # Must match training max_length (counts special tokens)
    ).to(device)  # Move tokenized inputs to the same device as the model

    # 2. Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # 3. Probabilities for the single sequence in the batch. Indexing row 0
    #    (instead of squeeze()) keeps 1-D shapes even when k == 1.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    k = min(top_k, probabilities.shape[-1])
    # Softmax is monotonic, so ranking probabilities ranks logits identically.
    top = torch.topk(probabilities, k)

    # 4. Decode numerical labels back to family names.
    results = {}
    for index, confidence_score in zip(top.indices.tolist(), top.values.tolist()):
        family_name = label_encoder.inverse_transform([index])[0]
        results[family_name] = confidence_score
    return results

# =============================================================================
# 3. CREATE THE GRADIO INTERFACE
# =============================================================================
def _build_interface():
    """Assemble the Gradio UI: a sequence textbox wired to predict_family,
    showing the five top-scoring families in a label widget."""
    sequence_box = gr.Textbox(
        label="Protein Amino Acid Sequence",
        placeholder="Paste your protein sequence here...",
        lines=10,
    )
    prediction_label = gr.Label(
        label="Predicted Families",
        num_top_classes=5,
    )
    # One-click sample inputs; each inner list is one row of inputs.
    sample_rows = [
        ["MLLVLKISRNAITTFSKEQLDSF"],
        ["SNYRPFVFKENDEVLALMAVWEFDDFIYVEHLAVDSKLRGKGVGSELIKNYLNRCDKRVFLEVEPPNCEISKKRVSFYEKLGFSF"],
        ["KRAIDLLLTLGSAILVLPLVLAIAAWIRMDSPGSPFFTQRRIGQHGREMHILKFRTMVQNAECVLHDCLAANPALNAEWERDQKLKCDPRVTRAGAFLRKTSLDELPQLWNVLRGEMSLVGPRPIVQDEVEKYGEVFDLYTRVKPGITGLWQVSGRNDVSYPQRVEMDRYYICNWSVWFDIWILAKTVPVVLH"],
    ]
    return gr.Interface(
        fn=predict_family,
        inputs=sequence_box,
        outputs=prediction_label,
        title="Protein Family Classifier",
        description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
        examples=sample_rows,
        allow_flagging="never",
    )


print("Creating Gradio interface...")
iface = _build_interface()

# Start the web server (runs on import, as before).
print("Launching app...")
iface.launch()