Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,21 +1,34 @@
|
|
1 |
-
# app.py (
|
2 |
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
5 |
import pickle
|
|
|
6 |
|
7 |
# =============================================================================
|
8 |
# 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
|
9 |
# =============================================================================
|
10 |
# Define the path to your model repository
|
11 |
-
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# Load the classification pipeline
|
|
|
14 |
classifier = pipeline("text-classification", model=model_path)
|
|
|
15 |
|
16 |
-
# Load the label encoder from the
|
17 |
-
|
|
|
18 |
label_encoder = pickle.load(f)
|
|
|
19 |
|
20 |
|
21 |
# =============================================================================
|
@@ -29,20 +42,25 @@ def predict_family(sequence):
|
|
29 |
# The model outputs labels like "LABEL_455". We need to extract the number.
|
30 |
results = {}
|
31 |
for p in predictions:
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
41 |
return results
|
42 |
|
43 |
# =============================================================================
|
44 |
-
# 3. CREATE THE GRADIO INTERFACE
|
45 |
# =============================================================================
|
|
|
46 |
iface = gr.Interface(
|
47 |
fn=predict_family,
|
48 |
inputs=gr.Textbox(
|
@@ -58,10 +76,13 @@ iface = gr.Interface(
|
|
58 |
description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
|
59 |
examples=[
|
60 |
["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
|
61 |
-
["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"]
|
|
|
62 |
],
|
63 |
-
allow_flagging="never"
|
64 |
)
|
|
|
65 |
|
66 |
# Launch the interface!
|
|
|
67 |
iface.launch()
|
|
|
1 |
+
# app.py (Final Corrected Version)
|
2 |
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
5 |
import pickle
|
6 |
+
from huggingface_hub import hf_hub_download # Import the download function
|
7 |
|
8 |
# =============================================================================
|
9 |
# 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
|
10 |
# =============================================================================
|
11 |
# Define the path to your model repository
|
12 |
+
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
|
13 |
+
|
14 |
+
# --- FIX FOR LFS ---
|
15 |
+
# Explicitly download the label_encoder.pkl file from the repo.
|
16 |
+
# This ensures the app can find the file even if it's stored with Git LFS.
|
17 |
+
print("Downloading label encoder...")
|
18 |
+
encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder.pkl")
|
19 |
+
print("Download complete.")
|
20 |
+
# --- END FIX ---
|
21 |
|
22 |
# Load the classification pipeline
|
23 |
+
print("Loading classification pipeline...")
|
24 |
classifier = pipeline("text-classification", model=model_path)
|
25 |
+
print("Pipeline loaded.")
|
26 |
|
27 |
+
# Load the label encoder from the path where it was downloaded
|
28 |
+
print("Loading label encoder...")
|
29 |
+
with open(encoder_path, "rb") as f:
|
30 |
label_encoder = pickle.load(f)
|
31 |
+
print("Label encoder loaded.")
|
32 |
|
33 |
|
34 |
# =============================================================================
|
|
|
42 |
# The model outputs labels like "LABEL_455". We need to extract the number.
|
43 |
results = {}
|
44 |
for p in predictions:
|
45 |
+
try:
|
46 |
+
# Extract the number from the label string (e.g., "LABEL_455" -> 455)
|
47 |
+
label_index = int(p['label'].split('_')[1])
|
48 |
+
|
49 |
+
# Use the label_encoder to find the original family name
|
50 |
+
original_label = label_encoder.inverse_transform([label_index])[0]
|
51 |
+
|
52 |
+
# Store the real name and score
|
53 |
+
results[original_label] = p['score']
|
54 |
+
except (ValueError, IndexError):
|
55 |
+
# Handle cases where the label format is unexpected
|
56 |
+
results[p['label']] = p['score']
|
57 |
+
|
58 |
return results
|
59 |
|
60 |
# =============================================================================
|
61 |
+
# 3. CREATE THE GRADIO INTERFACE
|
62 |
# =============================================================================
|
63 |
+
print("Creating Gradio interface...")
|
64 |
iface = gr.Interface(
|
65 |
fn=predict_family,
|
66 |
inputs=gr.Textbox(
|
|
|
76 |
description="This demo uses a fine-tuned ESM-2 model to predict the protein family from its amino acid sequence. Enter a sequence to see the top 5 predictions and their confidence scores.",
|
77 |
examples=[
|
78 |
["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
|
79 |
+
["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"],
|
80 |
+
["MNGTEGPNFYVPFSNKTGVVRSPFEAPQYYLAEPWQFSMLAAYMFLLIMLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVFGGFTTTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLVGWSRYIPEGMQCSCGIDYYTPHEETNNESFVIYMFVVHFIIPLIVIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWLPYAGVAFYIFTHQGSDFGPIFMTIPAFFAKTSAVYNPVIYIMMNKQFRNCMVTTLCCGKNPLGDDEASTTVSKTETSQVAPA"]
|
81 |
],
|
82 |
+
allow_flagging="never" # Disables the "Flag" button for a cleaner interface
|
83 |
)
|
84 |
+
print("Interface created.")
|
85 |
|
86 |
# Launch the interface!
|
87 |
+
print("Launching app...")
|
88 |
iface.launch()
|