Tarive committed on
Commit
baa4839
·
verified ·
1 Parent(s): 2f2d3f5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (Final, Robust Version)
2
+
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
+ import pickle
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ # =============================================================================
10
+ # 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
11
+ # =============================================================================
12
# Hub repository that holds the fine-tuned weights, the tokenizer files and
# the pickled label encoder used below.
model_path = "Tarive/esm2_t12_35M_UR50D-5k-families-balanced-augmented-weighted_optimized"

# --- Tokenizer ---------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# --- Classification model ----------------------------------------------------
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Prefer the GPU when one is visible to torch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Model loaded on device: {device}")

# --- Label encoder -----------------------------------------------------------
# The encoder maps the model's integer class ids back to family names.
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(
    repo_id=model_path,
    filename="label_encoder_5k-2.pkl",
)
# NOTE(review): pickle.load runs whatever code the file contains, so this is
# only acceptable while the repository above is fully trusted.
with open(encoder_path, "rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)
print("Label encoder loaded.")
31
+
32
+
33
+ # =============================================================================
34
+ # 2. DEFINE THE LOW-LEVEL PREDICTION FUNCTION
35
+ # =============================================================================
36
+ # This function manually replicates the training data processing steps.
37
def predict_family(sequence):
    """Predict the most likely protein families for an amino-acid sequence.

    Args:
        sequence: Amino-acid sequence as typed or pasted into the Gradio
            textbox. Whitespace (spaces, tabs, line breaks) is removed
            before tokenization.

    Returns:
        A dict mapping family name to softmax probability for the top
        predictions (at most 5), inserted in descending probability order.

    Raises:
        gr.Error: If the sequence is empty once whitespace is removed, so
            the UI shows a readable message instead of a meaningless
            prediction.
    """
    # The input box is multi-line, so pasted sequences routinely carry line
    # breaks or spaces that are not part of the protein itself.
    cleaned = "".join((sequence or "").split())
    if not cleaned:
        raise gr.Error("Please enter a protein amino acid sequence.")

    # 1. Tokenize with the same settings as training
    inputs = tokenizer(
        cleaned,
        return_tensors="pt",  # Return PyTorch tensors
        truncation=True,
        padding=True,
        max_length=256,  # Ensure this matches your training max_length
    ).to(device)  # Keep the inputs on the same device as the model

    # 2. Forward pass without gradient tracking (inference only)
    with torch.no_grad():
        logits = model(**inputs).logits

    # 3. A single sequence was tokenized, so row 0 holds its class scores.
    #    Indexing explicitly (instead of squeeze()) keeps the shape
    #    predictable whatever the number of labels is.
    probabilities = torch.nn.functional.softmax(logits[0], dim=-1)

    # 4. Take the top predictions; never ask for more classes than exist
    top_k = min(5, probabilities.shape[-1])
    top = torch.topk(probabilities, top_k)
    top_indices = top.indices.tolist()
    top_scores = top.values.tolist()

    # 5. Decode all class ids back to family names in one call
    family_names = label_encoder.inverse_transform(top_indices)

    return {str(name): score for name, score in zip(family_names, top_scores)}
65
+
66
+ # =============================================================================
67
+ # 3. CREATE THE GRADIO INTERFACE (No changes here)
68
+ # =============================================================================
69
print("Creating Gradio interface...")

# Input widget: a multi-line box for the raw amino-acid sequence.
sequence_input = gr.Textbox(
    lines=10,
    label="Protein Amino Acid Sequence",
    placeholder="Paste your protein sequence here...",
)

# Output widget: ranked labels with their confidence bars.
families_output = gr.Label(
    num_top_classes=5,
    label="Predicted Families",
)

# Clickable sample sequences shown under the input box (one input per row).
example_sequences = [
    ["LAAARMRPQDIDRFVPHQANARIFDAVGRNLGIADEAIVKTIAEYGNSSAATIPLSLSLAHRAAPFRPGEKVLLAAAGAGLSGGALVVGI"],
    ["MSLPDMRLPIQNAIFYPEMVNYTFNRLDLTSISCLTFEKPKRDLFRAIDVCEWVASMGNPYVSVLLGADDKAVELFLEGKIGFLDIPVLIESVLSSVNFHIEENLEDILRAV"],
    ["VSYISSQYPHHPDVFSVVRQACVRSLSCEVCPGREGPIFFGDEHRSHVFSHTFFLKDSQARGFQRWYSIVMVMMDKVFLLNSWPFLVKQIRNFIDQLQAKANKVYFSEQTDCPQRALRLKSSFTMTPANFRRQRSNISVRGLYELTNDKQVFYTAHVWFTWILKAC"],
]

# Wire the prediction function to the widgets.
iface = gr.Interface(
    fn=predict_family,
    inputs=sequence_input,
    outputs=families_output,
    title="Protein Family Classifier",
    description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
    examples=example_sequences,
    allow_flagging="never",
)

# Start serving the app.
print("Launching app...")
iface.launch()