nonzeroexit committed on
Commit
b206439
·
verified ·
1 Parent(s): 63d3a19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -35
app.py CHANGED
@@ -19,35 +19,10 @@ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
- # Define selected features
23
- selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
- "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
25
- "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
26
- "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
27
- "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
28
- "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
29
- "_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
30
- "_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
31
- "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
32
- "LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
33
- "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
34
- "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
35
- "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
36
- "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
37
- "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
38
- "GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
39
- "GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
40
- "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
41
- "GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
42
- "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
43
- "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
44
- "GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
45
- "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
46
- "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
47
- "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
48
- "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
49
-
50
- # Create dummy data for LIME initialization
51
  sample_data = np.random.rand(100, len(selected_features))
52
  explainer = LimeTabularExplainer(
53
  training_data=sample_data,
@@ -56,7 +31,7 @@ explainer = LimeTabularExplainer(
56
  mode="classification"
57
  )
58
 
59
- # Feature extraction from sequence
60
  def extract_features(sequence):
61
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
62
  if len(sequence) < 10:
@@ -87,7 +62,7 @@ def extract_features(sequence):
87
  except Exception as e:
88
  return f"Error in feature extraction: {str(e)}"
89
 
90
- # MIC prediction for AMP sequences
91
  def predictmic(sequence):
92
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
93
  if len(sequence) < 10:
@@ -123,7 +98,7 @@ def predictmic(sequence):
123
 
124
  return mic_results
125
 
126
- # Full prediction pipeline
127
  def full_prediction(sequence):
128
  features = extract_features(sequence)
129
  if isinstance(features, str):
@@ -131,9 +106,14 @@ def full_prediction(sequence):
131
 
132
  prediction = model.predict(features)[0]
133
  probabilities = model.predict_proba(features)[0]
134
- amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
135
- confidence = round(probabilities[prediction] * 100, 2)
136
 
 
 
 
 
 
 
 
137
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
138
 
139
  if prediction == 0:
@@ -156,7 +136,7 @@ def full_prediction(sequence):
156
 
157
  return result
158
 
159
- # Gradio interface
160
  iface = gr.Interface(
161
  fn=full_prediction,
162
  inputs=gr.Textbox(label="Enter Protein Sequence"),
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
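+ # Define selected features. NOTE: the `[ ... ]` below is only a placeholder (a one-element list holding Ellipsis), so len(selected_features) is 1 until the complete feature list is put here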
23
+ selected_features = [ ... ] # Replace with your full list
24
+
25
+ # Dummy data for LIME
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  sample_data = np.random.rand(100, len(selected_features))
27
  explainer = LimeTabularExplainer(
28
  training_data=sample_data,
 
31
  mode="classification"
32
  )
33
 
34
+ # Feature extraction function
35
  def extract_features(sequence):
36
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
37
  if len(sequence) < 10:
 
62
  except Exception as e:
63
  return f"Error in feature extraction: {str(e)}"
64
 
65
+ # MIC prediction function
66
  def predictmic(sequence):
67
  sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
68
  if len(sequence) < 10:
 
98
 
99
  return mic_results
100
 
101
+ # Main prediction function
102
  def full_prediction(sequence):
103
  features = extract_features(sequence)
104
  if isinstance(features, str):
 
106
 
107
  prediction = model.predict(features)[0]
108
  probabilities = model.predict_proba(features)[0]
 
 
109
 
110
+ try:
111
+ class_index = list(model.classes_).index(prediction)
112
+ confidence = round(probabilities[class_index] * 100, 2)
113
+ except Exception:
114
+ confidence = "Unknown"
115
+
116
+ amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
117
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
118
 
119
  if prediction == 0:
 
136
 
137
  return result
138
 
139
+ # Gradio UI
140
  iface = gr.Interface(
141
  fn=full_prediction,
142
  inputs=gr.Textbox(label="Enter Protein Sequence"),