Tarive committed on
Commit
2132cad
·
verified ·
1 Parent(s): 7707dbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -20
app.py CHANGED
@@ -1,40 +1,48 @@
1
- # app.py
2
 
3
- # Import the necessary libraries
4
  import gradio as gr
5
  from transformers import pipeline
 
6
 
7
  # =============================================================================
8
- # 1. LOAD YOUR MODEL
9
  # =============================================================================
10
- # Use a pipeline for easy text classification. This will automatically
11
- # load your fine-tuned model and tokenizer from the repository.
12
- #
13
- # IMPORTANT: Replace "your-username/your-model-repo-name" with your actual model path.
14
- # For example: "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
15
- #
16
- model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k"
17
  classifier = pipeline("text-classification", model=model_path)
18
 
 
 
 
 
19
 
20
  # =============================================================================
21
- # 2. DEFINE THE PREDICTION FUNCTION
22
  # =============================================================================
23
- # This function takes a sequence as input and returns a formatted result.
24
  def predict_family(sequence):
25
- # The pipeline returns a list of dictionaries.
26
- # Example: [{'label': 'PF00042', 'score': 0.979}]
27
- predictions = classifier(sequence, top_k=5) # Get the top 5 predictions
28
 
29
- # Format the results into a more readable dictionary for display.
30
- results = {p['label']: p['score'] for p in predictions}
 
 
 
 
 
 
 
 
 
31
 
32
  return results
33
 
34
  # =============================================================================
35
- # 3. CREATE THE GRADIO INTERFACE
36
  # =============================================================================
37
- # This creates the actual web page interface.
38
  iface = gr.Interface(
39
  fn=predict_family,
40
  inputs=gr.Textbox(
@@ -52,7 +60,7 @@ iface = gr.Interface(
52
  ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
53
  ["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"]
54
  ],
55
- allow_flagging="never" # Disables the "Flag" button for a cleaner interface
56
  )
57
 
58
  # Launch the interface!
 
1
+ # app.py (Updated Version)
2
 
 
3
  import gradio as gr
4
  from transformers import pipeline
5
+ import pickle
6
 
7
  # =============================================================================
8
+ # 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
9
  # =============================================================================
10
+ # Define the path to your model repository
11
+ model_path = "Tarive/esm2_t12_35M_UR50D-finetuned-pfam-1k" # Make sure this is correct
12
+
13
+ # Load the classification pipeline
 
 
 
14
  classifier = pipeline("text-classification", model=model_path)
15
 
16
+ # Load the label encoder from the file you uploaded
17
+ with open("label_encoder.pkl", "rb") as f:
18
+ label_encoder = pickle.load(f)
19
+
20
 
21
  # =============================================================================
22
+ # 2. DEFINE THE PREDICTION FUNCTION WITH LABEL DECODING
23
  # =============================================================================
24
+ # This function now decodes the labels before displaying them.
25
  def predict_family(sequence):
26
+ # Get the top 5 predictions from the model
27
+ predictions = classifier(sequence, top_k=5)
 
28
 
29
+ # The model outputs labels like "LABEL_455". We need to extract the number.
30
+ results = {}
31
+ for p in predictions:
32
+ # Extract the number from the label string (e.g., "LABEL_455" -> 455)
33
+ label_index = int(p['label'].split('_')[1])
34
+
35
+ # Use the label_encoder to find the original family name
36
+ original_label = label_encoder.inverse_transform([label_index])[0]
37
+
38
+ # Store the real name and score
39
+ results[original_label] = p['score']
40
 
41
  return results
42
 
43
  # =============================================================================
44
+ # 3. CREATE THE GRADIO INTERFACE (No changes here)
45
  # =============================================================================
 
46
  iface = gr.Interface(
47
  fn=predict_family,
48
  inputs=gr.Textbox(
 
60
  ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
61
  ["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"]
62
  ],
63
+ allow_flagging="never"
64
  )
65
 
66
  # Launch the interface!