Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,40 +1,48 @@
|
|
1 |
-
# app.py
|
2 |
|
3 |
-
# Import the necessary libraries
|
4 |
import gradio as gr
|
5 |
from transformers import pipeline
|
|
|
6 |
|
7 |
# =============================================================================
|
8 |
-
# 1. LOAD YOUR MODEL
|
9 |
# =============================================================================
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
#
|
14 |
-
# For example: "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
|
15 |
-
#
|
16 |
-
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
|
17 |
classifier = pipeline("text-classification", model=model_path)
|
18 |
|
|
|
|
|
|
|
|
|
19 |
|
20 |
# =============================================================================
|
21 |
-
# 2. DEFINE THE PREDICTION FUNCTION
|
22 |
# =============================================================================
|
23 |
-
# This function
|
24 |
def predict_family(sequence):
|
25 |
-
#
|
26 |
-
|
27 |
-
predictions = classifier(sequence, top_k=5) # Get the top 5 predictions
|
28 |
|
29 |
-
#
|
30 |
-
results = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
return results
|
33 |
|
34 |
# =============================================================================
|
35 |
-
# 3. CREATE THE GRADIO INTERFACE
|
36 |
# =============================================================================
|
37 |
-
# This creates the actual web page interface.
|
38 |
iface = gr.Interface(
|
39 |
fn=predict_family,
|
40 |
inputs=gr.Textbox(
|
@@ -52,7 +60,7 @@ iface = gr.Interface(
|
|
52 |
["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
|
53 |
["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"]
|
54 |
],
|
55 |
-
allow_flagging="never"
|
56 |
)
|
57 |
|
58 |
# Launch the interface!
|
|
|
1 |
+
# app.py (Updated Version)
|
2 |
|
|
|
3 |
import gradio as gr
|
4 |
from transformers import pipeline
|
5 |
+
import pickle
|
6 |
|
7 |
# =============================================================================
|
8 |
+
# 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
|
9 |
# =============================================================================
|
10 |
+
# Define the path to your model repository
|
11 |
+
model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k" # Make sure this is correct
|
12 |
+
|
13 |
+
# Load the classification pipeline
|
|
|
|
|
|
|
14 |
classifier = pipeline("text-classification", model=model_path)
|
15 |
|
16 |
+
# Load the label encoder from the file you uploaded
|
17 |
+
with open("label_encoder.pkl", "rb") as f:
|
18 |
+
label_encoder = pickle.load(f)
|
19 |
+
|
20 |
|
21 |
# =============================================================================
|
22 |
+
# 2. DEFINE THE PREDICTION FUNCTION WITH LABEL DECODING
|
23 |
# =============================================================================
|
24 |
+
# This function now decodes the labels before displaying them.
|
25 |
def predict_family(sequence):
|
26 |
+
# Get the top 5 predictions from the model
|
27 |
+
predictions = classifier(sequence, top_k=5)
|
|
|
28 |
|
29 |
+
# The model outputs labels like "LABEL_455". We need to extract the number.
|
30 |
+
results = {}
|
31 |
+
for p in predictions:
|
32 |
+
# Extract the number from the label string (e.g., "LABEL_455" -> 455)
|
33 |
+
label_index = int(p['label'].split('_')[1])
|
34 |
+
|
35 |
+
# Use the label_encoder to find the original family name
|
36 |
+
original_label = label_encoder.inverse_transform([label_index])[0]
|
37 |
+
|
38 |
+
# Store the real name and score
|
39 |
+
results[original_label] = p['score']
|
40 |
|
41 |
return results
|
42 |
|
43 |
# =============================================================================
|
44 |
+
# 3. CREATE THE GRADIO INTERFACE (No changes here)
|
45 |
# =============================================================================
|
|
|
46 |
iface = gr.Interface(
|
47 |
fn=predict_family,
|
48 |
inputs=gr.Textbox(
|
|
|
60 |
["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
|
61 |
["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"]
|
62 |
],
|
63 |
+
allow_flagging="never"
|
64 |
)
|
65 |
|
66 |
# Launch the interface!
|