nonzeroexit commited on
Commit
f776418
·
verified ·
1 Parent(s): e9f8ebb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -50
app.py CHANGED
@@ -6,55 +6,32 @@ from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
6
  from sklearn.preprocessing import MinMaxScaler
7
  import torch
8
  from transformers import BertTokenizer, BertModel
 
9
  from math import expm1
10
 
11
  # Load AMP Classifier
12
  model = joblib.load("RF.joblib")
13
  scaler = joblib.load("norm (4).joblib")
14
 
15
- # Load ProtBert Globally
16
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
17
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  protbert_model = protbert_model.to(device).eval()
20
 
21
  # Selected Features
22
- selected_features = [
23
- "_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
- "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001",
25
- "_PolarizabilityD2001", "_PolarizabilityD3001", "_SolventAccessibilityD1001",
26
- "_SolventAccessibilityD2001", "_SolventAccessibilityD3001", "_SecondaryStrD1001",
27
- "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
28
- "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001",
29
- "_PolarityD1050", "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001",
30
- "_NormalizedVDWVD2001", "_NormalizedVDWVD2025", "_NormalizedVDWVD2050", "_NormalizedVDWVD3001",
31
- "_HydrophobicityD1001", "_HydrophobicityD2001", "_HydrophobicityD3001", "_HydrophobicityD3025",
32
- "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
33
- "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL",
34
- "HC", "IA", "IL", "IV", "LA", "LC", "LE", "LI", "LT", "LV", "KC", "MA",
35
- "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
36
- "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
37
- "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
38
- "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
39
- "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
40
- "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29",
41
- "GearyAuto_AvFlexibility30", "GearyAuto_Polarizability22", "GearyAuto_Polarizability24",
42
- "GearyAuto_Polarizability25", "GearyAuto_Polarizability27", "GearyAuto_Polarizability28",
43
- "GearyAuto_Polarizability29", "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24",
44
- "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30", "GearyAuto_ResidueASA21",
45
- "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
46
- "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24",
47
- "GearyAuto_ResidueVol25", "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28",
48
- "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30", "GearyAuto_Steric18",
49
- "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
50
- "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
51
- "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28",
52
- "GearyAuto_Mutability29", "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5",
53
- "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13", "APAAC15", "APAAC18", "APAAC19",
54
- "APAAC24"
55
- ]
56
 
57
- # AMP Feature Extractor
 
 
 
 
 
 
 
 
 
58
  def extract_features(sequence):
59
  all_features_dict = {}
60
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
@@ -79,7 +56,7 @@ def extract_features(sequence):
79
  def predictmic(sequence):
80
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
81
  if len(sequence) < 10:
82
- return {"Error": "Sequence too short or invalid. Must contain at least 10 valid amino acids."}
83
  seq_spaced = ' '.join(list(sequence))
84
  tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
85
  tokens = {k: v.to(device) for k, v in tokens.items()}
@@ -106,37 +83,46 @@ def predictmic(sequence):
106
  mic_results[bacterium] = f"Error: {str(e)}"
107
  return mic_results
108
 
109
- # Combined Output as Single String
110
  def full_prediction(sequence):
111
  features = extract_features(sequence)
112
- if isinstance(features, str): # error message returned
113
  return features
 
114
  prediction = model.predict(features)[0]
115
  probabilities = model.predict_proba(features)[0]
116
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
117
  confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
118
 
119
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
120
-
121
- if prediction == 0: # only predict MIC if AMP
122
  mic_values = predictmic(sequence)
123
  result += "\nPredicted MIC Values (µM):\n"
124
- for organism, mic in mic_values.items():
125
- result += f"- {organism}: {mic}\n"
126
  else:
127
- result += "\nMIC prediction is not available because sequence is Non-AMP."
128
 
129
- return result
 
 
 
 
 
 
 
 
130
 
 
131
 
132
- # Gradio Interface (Single Label Output)
133
  iface = gr.Interface(
134
  fn=full_prediction,
135
  inputs=gr.Textbox(label="Enter Protein Sequence"),
136
- outputs=gr.Textbox(label="AMP & MIC Prediction Summary"),
137
- title="AMP & MIC Predictor",
138
- description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values."
139
  )
140
 
141
  iface.launch(share=True)
142
-
 
6
  from sklearn.preprocessing import MinMaxScaler
7
  import torch
8
  from transformers import BertTokenizer, BertModel
9
+ from lime.lime_tabular import LimeTabularExplainer
10
  from math import expm1
11
 
12
  # Load AMP Classifier
13
  model = joblib.load("RF.joblib")
14
  scaler = joblib.load("norm (4).joblib")
15
 
16
+ # Load ProtBert
17
  tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
18
  protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
  # Selected Features
23
+ selected_features = [ ... ] # keep your full selected_features list here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # LIME Explainer Setup
26
+ sample_data = np.random.rand(100, len(selected_features))
27
+ explainer = LimeTabularExplainer(
28
+ training_data=sample_data,
29
+ feature_names=selected_features,
30
+ class_names=["AMP", "Non-AMP"],
31
+ mode="classification"
32
+ )
33
+
34
+ # Feature Extractor
35
  def extract_features(sequence):
36
  all_features_dict = {}
37
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
 
56
  def predictmic(sequence):
57
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
58
  if len(sequence) < 10:
59
+ return {"Error": "Sequence too short or invalid."}
60
  seq_spaced = ' '.join(list(sequence))
61
  tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
62
  tokens = {k: v.to(device) for k, v in tokens.items()}
 
83
  mic_results[bacterium] = f"Error: {str(e)}"
84
  return mic_results
85
 
86
+ # Full Prediction with LIME Explanation
87
  def full_prediction(sequence):
88
  features = extract_features(sequence)
89
+ if isinstance(features, str): # error
90
  return features
91
+
92
  prediction = model.predict(features)[0]
93
  probabilities = model.predict_proba(features)[0]
94
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
95
  confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
96
 
97
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
98
+
99
+ if prediction == 0:
100
  mic_values = predictmic(sequence)
101
  result += "\nPredicted MIC Values (µM):\n"
102
+ for org, mic in mic_values.items():
103
+ result += f"- {org}: {mic}\n"
104
  else:
105
+ result += "\nMIC prediction skipped for Non-AMP sequences.\n"
106
 
107
+ # LIME explanation
108
+ explanation = explainer.explain_instance(
109
+ data_row=features[0],
110
+ predict_fn=model.predict_proba,
111
+ num_features=10
112
+ )
113
+ result += "\nTop Features Influencing AMP Prediction:\n"
114
+ for feat, weight in explanation.as_list():
115
+ result += f"- {feat}: {round(weight, 4)}\n"
116
 
117
+ return result
118
 
119
+ # Gradio UI
120
  iface = gr.Interface(
121
  fn=full_prediction,
122
  inputs=gr.Textbox(label="Enter Protein Sequence"),
123
+ outputs=gr.Textbox(label="Prediction + MIC + LIME"),
124
+ title="AMP & MIC Predictor + LIME Explanation",
125
+ description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
126
  )
127
 
128
  iface.launch(share=True)