Tarive committed on
Commit
0dee19d
·
verified ·
1 Parent(s): 254a962

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -40
app.py CHANGED
@@ -1,64 +1,70 @@
1
- # app.py (Final Corrected Version)
2
 
3
  import gradio as gr
4
- from transformers import pipeline
 
5
  import pickle
6
- from huggingface_hub import hf_hub_download # Import the download function
7
 
8
  # =============================================================================
9
- # 1. LOAD YOUR MODEL AND THE SAVED LABEL ENCODER
10
  # =============================================================================
11
  # Define the path to your model repository
12
  model_path = "Tarive/esm2_t12_35M_UR50D-5k-families-balanced-augmented-weighted_optimized"
13
 
14
- # --- FIX FOR LFS ---
15
- # Explicitly download the label_encoder.pkl file from the repo.
16
- # This ensures the app can find the file even if it's stored with Git LFS.
17
- print("Downloading label encoder...")
18
- encoder_path = hf_hub_download(repo_id=model_path, filename="label_encoder_5k-2.pkl")
19
- print("Download complete.")
20
- # --- END FIX ---
21
 
22
- # Load the classification pipeline
23
- print("Loading classification pipeline...")
24
- classifier = pipeline("text-classification", model=model_path)
25
- print("Pipeline loaded.")
 
 
26
 
27
- # Load the label encoder from the path where it was downloaded
28
- print("Loading label encoder...")
 
29
  with open(encoder_path, "rb") as f:
30
  label_encoder = pickle.load(f)
31
  print("Label encoder loaded.")
32
 
33
 
34
  # =============================================================================
35
- # 2. DEFINE THE PREDICTION FUNCTION WITH LABEL DECODING
36
  # =============================================================================
37
- # This function now decodes the labels before displaying them.
38
  def predict_family(sequence):
39
- # Get the top 5 predictions from the model
40
- predictions = classifier(sequence, top_k=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- # The model outputs labels like "LABEL_455". We need to extract the number.
43
- results = {}
44
- for p in predictions:
45
- try:
46
- # Extract the number from the label string (e.g., "LABEL_455" -> 455)
47
- label_index = int(p['label'].split('_')[1])
48
-
49
- # Use the label_encoder to find the original family name
50
- original_label = label_encoder.inverse_transform([label_index])[0]
51
-
52
- # Store the real name and score
53
- results[original_label] = p['score']
54
- except (ValueError, IndexError):
55
- # Handle cases where the label format is unexpected
56
- results[p['label']] = p['score']
57
 
 
 
 
 
 
 
 
58
  return results
59
 
60
  # =============================================================================
61
- # 3. CREATE THE GRADIO INTERFACE
62
  # =============================================================================
63
  print("Creating Gradio interface...")
64
  iface = gr.Interface(
@@ -77,11 +83,10 @@ iface = gr.Interface(
77
  examples=[
78
  ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
79
  ["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"],
80
- ["MNGTEGPNFYVPFSNKTGVVRSPFEAPQYYLAEPWQFSMLAAYMFLLIMLGFPINFLTLYVTVQHKKLRTPLNYILLNLAVADLFMVFGGFTTTLYTSLHGYFVFGPTGCNLEGFFATLGGEIALWSLVVLAIERYVVVCKPMSNFRFGENHAIMGVAFTWVMALACAAPPLVGWSRYIPEGMQCSCGIDYYTPHEETNNESFVIYMFVVHFIIPLIVIFFCYGQLVFTVKEAAAQQQESATTQKAEKEVTRMVIIMVIAFLICWLPYAGVAFYIFTHQGSDFGPIFMTIPAFFAKTSAVYNPVIYIMMNKQFRNCMVTTLCCGKNPLGDDEASTTVSKTETSQVAPA"]
81
  ],
82
- allow_flagging="never" # Disables the "Flag" button for a cleaner interface
83
  )
84
- print("Interface created.")
85
 
86
  # Launch the interface!
87
  print("Launching app...")
 
1
+ # app.py (Final, Robust Version)
2
 
3
  import gradio as gr
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import torch
6
  import pickle
7
+ from huggingface_hub import hf_hub_download
8
 
9
  # =============================================================================
10
+ # 1. LOAD MODEL, TOKENIZER, AND LABEL ENCODER
11
  # =============================================================================
12
  # Define the path to your model repository
13
  model_path = "Tarive/esm2_t12_35M_UR50D-5k-families-balanced-augmented-weighted_optimized"
14
 
15
# ---- Tokenizer -------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# ---- Classification model --------------------------------------------------
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Default to CPU and switch to the GPU only when CUDA is actually usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model = model.to(device)
print(f"Model loaded on device: {device}")

# ---- Label encoder ---------------------------------------------------------
# The encoder lives in the model repository as a pickle; fetch it into the
# local hub cache and deserialize it.
# NOTE(review): pickle.load executes arbitrary code from the file, so this is
# only safe while the repository contents are trusted.
print("Downloading and loading label encoder...")
encoder_path = hf_hub_download(
    repo_id=model_path,
    filename="label_encoder_5k-2.pkl",
)
with open(encoder_path, mode="rb") as encoder_file:
    label_encoder = pickle.load(encoder_file)
print("Label encoder loaded.")
31
 
32
 
33
  # =============================================================================
34
+ # 2. DEFINE THE LOW-LEVEL PREDICTION FUNCTION
35
  # =============================================================================
36
+ # This function manually replicates the training data processing steps.
37
  def predict_family(sequence):
38
+ # 1. Tokenize the input sequence with the exact same settings as training
39
+ inputs = tokenizer(
40
+ sequence,
41
+ return_tensors="pt", # Return PyTorch tensors
42
+ truncation=True,
43
+ padding=True,
44
+ max_length=256 # Ensure this matches your training max_length
45
+ ).to(device) # Move tokenized inputs to the same device as the model
46
+
47
+ # 2. Get model predictions (logits)
48
+ with torch.no_grad(): # Disable gradient calculation for efficiency
49
+ logits = model(**inputs).logits
50
+
51
+ # 3. Get the top 5 predictions
52
+ top_k_indices = torch.topk(logits, 5, dim=-1).indices.squeeze().tolist()
53
 
54
+ # 4. Convert logits to probabilities (softmax)
55
+ probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # 5. Decode the numerical labels back to family names
58
+ results = {}
59
+ for index in top_k_indices:
60
+ family_name = label_encoder.inverse_transform([index])[0]
61
+ confidence_score = probabilities[index]
62
+ results[family_name] = confidence_score
63
+
64
  return results
65
 
66
  # =============================================================================
67
+ # 3. CREATE THE GRADIO INTERFACE
68
  # =============================================================================
69
  print("Creating Gradio interface...")
70
  iface = gr.Interface(
 
83
  examples=[
84
  ["MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"],
85
  ["MTEYKLVVVGAGDVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVEVDCQQCMILDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHQYREQIKRVKDSDDVPMVLVGNKCDLAARTVESRQAQDLARSYGIPYIETSAKTRQGVEDAFYTLVREIRQHKLRKLNPPDESGGCMS"],
86
+ ["MSIKKILVSDKITTLEKFPASVTLDGADFTVHSSWYDTEKVREDIKEKYSHLISESENGFLFKEKDSKRFWRYFNEKDGVSYATGYQINPYFPANKKYEFGYTGAEWYYSYEPKNVARYGNFDETDAAHPCTYTVANYYLRDKSYFDDKYFNVPLYNMFFNDYNYYDFEYQTKNKFYFTNYKENPKYPFETNFENVPSKDTDDYIIKPYPGVKKFGEFDWDEFEGNTFDPGYYKDSYMYYQKKYDDSYKYKEYGVDPDDFSYKDKYDNNPKFNLYYKYVPDKKNN"]
87
  ],
88
+ allow_flagging="never"
89
  )
 
90
 
91
  # Launch the interface!
92
  print("Launching app...")