File size: 3,848 Bytes
baa4839
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# app.py (Final, Robust Version)

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import pickle
from huggingface_hub import hf_hub_download

# =============================================================================
# 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
# =============================================================================
# Hugging Face Hub repository id; the tokenizer, the model weights and the
# pickled label encoder are all fetched from this same repository.
model_path = "Tarive/esm2_t12_35M_UR50D-5k-families-balanced-augmented-weighted_optimized"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Model loaded on device: {device}")

# The label encoder maps class indices back to family names.
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder_5k-2.pkl")
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# this is only as trustworthy as the repository above — confirm it is owned
# and controlled by the app author.
with open(encoder_path, "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)
print("Label encoder loaded.")


# =============================================================================
# 2. DEFINE THE LOW-LEVEL PREDICTION FUNCTION
# =============================================================================
# Tokenization settings below are intended to mirror those used in training
# (in particular max_length) — keep them in sync with the training script.
def predict_family(sequence):
    """Predict the most likely protein families for an amino-acid sequence.

    The raw textbox contents are normalized before tokenization so common
    paste formats work: FASTA header lines (starting with ">") are dropped,
    all whitespace (including line breaks) is removed, and letters are
    upper-cased to match the tokenizer's uppercase amino-acid vocabulary.

    Args:
        sequence: Raw text from the Gradio textbox (may be None or empty).

    Returns:
        dict mapping family name -> softmax probability for the (up to) 5
        highest-scoring classes, the format ``gr.Label`` expects.

    Raises:
        gr.Error: if no sequence characters remain after normalization, so the
            UI shows a message instead of scoring an empty sequence.
    """
    # 0. Normalize the raw input text.
    kept_lines = [
        line
        for line in (sequence or "").splitlines()
        if not line.lstrip().startswith(">")  # skip FASTA headers
    ]
    cleaned = "".join("".join(line.split()) for line in kept_lines).upper()
    if not cleaned:
        raise gr.Error("Please enter a protein amino acid sequence.")

    # 1. Tokenize the normalized sequence with the training-time settings.
    inputs = tokenizer(
        cleaned,
        return_tensors="pt",  # Return PyTorch tensors
        truncation=True,
        padding=True,
        max_length=256,  # Ensure this matches your training max_length
    ).to(device)  # Move tokenized inputs to the same device as the model

    # 2. Get model predictions (logits) without tracking gradients.
    with torch.no_grad():
        logits = model(**inputs).logits

    # 3. Probabilities for the single sequence in the batch. Indexing row 0
    # (instead of squeeze()) always yields a 1-D tensor of per-class scores.
    probabilities = torch.nn.functional.softmax(logits[0], dim=-1)

    # 4. Top-k classes; softmax is monotonic, so ranking by probability equals
    # ranking by logits. Clamp k in case the model has fewer than 5 labels.
    k = min(5, probabilities.size(-1))
    top = torch.topk(probabilities, k)

    # 5. Decode all indices in one call and pair them with their scores.
    family_names = label_encoder.inverse_transform(top.indices.tolist())
    return {
        family_name: confidence_score
        for family_name, confidence_score in zip(family_names, top.values.tolist())
    }

# =============================================================================
# 3. CREATE THE GRADIO INTERFACE
# =============================================================================
print("Creating Gradio interface...")

# Input: a multi-line box for pasting the raw amino-acid sequence.
sequence_input = gr.Textbox(
    lines=10,
    label="Protein Amino Acid Sequence",
    placeholder="Paste your protein sequence here...",
)

# Output: the top-5 family labels with their confidence scores.
prediction_output = gr.Label(num_top_classes=5, label="Predicted Families")

# Example sequences offered in the UI; Gradio expects one list per example row.
example_sequences = [
    "LAAARMRPQDIDRFVPHQANARIFDAVGRNLGIADEAIVKTIAEYGNSSAATIPLSLSLAHRAAPFRPGEKVLLAAAGAGLSGGALVVGI",
    "MSLPDMRLPIQNAIFYPEMVNYTFNRLDLTSISCLTFEKPKRDLFRAIDVCEWVASMGNPYVSVLLGADDKAVELFLEGKIGFLDIPVLIESVLSSVNFHIEENLEDILRAV",
    "VSYISSQYPHHPDVFSVVRQACVRSLSCEVCPGREGPIFFGDEHRSHVFSHTFFLKDSQARGFQRWYSIVMVMMDKVFLLNSWPFLVKQIRNFIDQLQAKANKVYFSEQTDCPQRALRLKSSFTMTPANFRRQRSNISVRGLYELTNDKQVFYTAHVWFTWILKAC",
]

iface = gr.Interface(
    fn=predict_family,
    inputs=sequence_input,
    outputs=prediction_output,
    title="Protein Family Classifier",
    description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
    examples=[[example] for example in example_sequences],
    allow_flagging="never",
)

# Start the web server.
print("Launching app...")
iface.launch()