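# Hugging Face file-page header captured with the source, kept as comments so the file parses:
# AMP-Classifier / app.py
# nonzeroexit's picture
# Update app.py
# f776418 verified
# raw
# history blame
# 5.27 kB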
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from sklearn.preprocessing import MinMaxScaler
import torch
from transformers import BertTokenizer, BertModel
from lime.lime_tabular import LimeTabularExplainer
from math import expm1
# Load AMP Classifier
# NOTE: joblib.load unpickles these files, which can execute code; only ship
# trusted artifacts alongside the app.
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# Load ProtBert (embeddings for the MIC regressors in predictmic)
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() switches off dropout so embeddings are deterministic at inference.
protbert_model = protbert_model.to(device).eval()

# Selected Features
selected_features = [ ... ] # keep your full selected_features list here
# Fail fast at startup instead of with an obscure KeyError on the first request.
if Ellipsis in selected_features:
    raise RuntimeError(
        "selected_features still contains the `...` placeholder; "
        "paste the full selected_features list into app.py."
    )

# LIME Explainer Setup
# NOTE(review): LIME derives its perturbation statistics from `training_data`.
# Uniform random values in [0, 1) are only a stand-in for the real
# (presumably normalized) training matrix — replace with it if available.
# A fixed seed makes the background data, and therefore the explanations,
# reproducible across restarts.
_LIME_SEED = 42
sample_data = np.random.default_rng(_LIME_SEED).random((100, len(selected_features)))
explainer = LimeTabularExplainer(
    training_data=sample_data,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],
    mode="classification",
    random_state=_LIME_SEED,
)
# Feature Extractor
def extract_features(sequence):
    """Turn a raw peptide sequence into the AMP classifier's feature row.

    Steps:
    1. Upper-case the sequence and drop every character that is not one of
       the 20 standard amino acids.
    2. Compute the propy descriptors: CTD, amino-acid/dipeptide composition,
       total autocorrelation and ``GetAPseudoAAC`` (lamda=9).
    3. Normalize the descriptors with the module-level ``scaler``.
    4. Keep only the columns listed in ``selected_features``.

    Args:
        sequence: Raw amino-acid sequence. Case is ignored, and characters
            outside ``ACDEFGHIKLMNPQRSTVWY`` are silently discarded.

    Returns:
        A 2-D NumPy array with one row and one column per entry of
        ``selected_features``, with NaNs replaced by 0. If fewer than 10
        valid residues remain, returns the string
        ``"Error: Sequence too short."`` instead.
    """
    sequence = ''.join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return "Error: Sequence too short."

    dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
    # NOTE(review): keeps the first 420 descriptors in propy's insertion order,
    # presumably the 20 AAC + 400 dipeptide entries — confirm against propy.
    filtered_dipeptide_features = dict(list(dipeptide_features.items())[:420])

    # The merge order (CTD, composition, autocorrelation, pseudo-AAC) is the
    # column order used when the scaler carries no feature names.
    all_features_dict = {}
    all_features_dict.update(CTD.CalculateCTD(sequence))
    all_features_dict.update(filtered_dipeptide_features)
    all_features_dict.update(Autocorrelation.CalculateAutoTotal(sequence))
    all_features_dict.update(PseudoAAC.GetAPseudoAAC(sequence, lamda=9))
    feature_df_all = pd.DataFrame([all_features_dict])

    # A scaler fitted on a DataFrame records its column order in
    # `feature_names_in_`. Aligning to it means each descriptor is scaled with
    # its own min/max, without relying on dict insertion order happening to
    # match training.
    # Columns the scaler expects but propy did not produce become NaN; they
    # end up as 0 below if they are selected.
    expected_columns = getattr(scaler, "feature_names_in_", None)
    if expected_columns is not None:
        feature_df_all = feature_df_all.reindex(columns=expected_columns)

    normalized_array = scaler.transform(feature_df_all.values)
    normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
    selected_df = normalized_df[selected_features].fillna(0)
    return selected_df.values
# MIC Predictor
# Per-organism regression artifacts. Each one is applied to the ProtBert
# embedding in this order: scaler, then the optional PCA, then the model.
# The keys are printed verbatim in the report.
_MIC_MODEL_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}

# Deserialized artifacts keyed by file path. Each file is read once per
# process instead of on every request. A failed load is not cached, so it is
# retried on the next call.
_MIC_ARTIFACT_CACHE = {}


def _load_mic_artifact(path):
    """Return the joblib object stored at *path*, loading it on first use."""
    if path not in _MIC_ARTIFACT_CACHE:
        _MIC_ARTIFACT_CACHE[path] = joblib.load(path)
    return _MIC_ARTIFACT_CACHE[path]


def predictmic(sequence):
    """Predict MIC values for each configured organism from a peptide sequence.

    Args:
        sequence: Raw amino-acid sequence. Case is ignored, and characters
            outside ``ACDEFGHIKLMNPQRSTVWY`` are discarded.

    Returns:
        If fewer than 10 valid residues remain, returns
        ``{"Error": "Sequence too short or invalid."}``.

        Otherwise returns a dict mapping each organism name to its MIC,
        rounded to 3 decimals. If an organism's pipeline raises, its value is
        the string ``"Error: <message>"`` instead. A failure for one organism
        does not affect the others.
    """
    sequence = ''.join(aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    # ProtBert tokenizes per residue, so residues are separated by spaces.
    seq_spaced = ' '.join(sequence)
    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protbert_model(**tokens)
    # NOTE(review): this mean runs over all 512 positions, padding included,
    # because no attention mask is applied. It is kept as-is because the MIC
    # models were presumably trained on embeddings pooled the same way.
    # Confirm before switching to masked mean pooling.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _MIC_MODEL_CONFIG.items():
        # Best-effort per organism: any failure is reported, not raised.
        try:
            mic_scaler = _load_mic_artifact(cfg["scaler"])
            transformed = mic_scaler.transform(embedding)
            if cfg["pca"]:
                transformed = _load_mic_artifact(cfg["pca"]).transform(transformed)
            regressor = _load_mic_artifact(cfg["model"])
            mic_log = regressor.predict(transformed)[0]
            # expm1 undoes a log1p target transform, presumably applied
            # at training time.
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
# Full Prediction with LIME Explanation
def full_prediction(sequence):
    """Classify a peptide as AMP / Non-AMP and build a plain-text report.

    The report contains:
    - the predicted class and its probability as a percentage;
    - the per-organism MIC predictions, computed only when the sequence is
      classified as AMP;
    - the 10 LIME features with the largest weights for the AMP class.

    Args:
        sequence: Raw amino-acid sequence as entered in the UI.

    Returns:
        The multi-line report string. If the sequence is unusable, returns
        the error string from ``extract_features`` instead.
    """
    features = extract_features(sequence)
    if isinstance(features, str):  # extract_features signals errors with a str
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    # The app treats class label 0 (column 0 of predict_proba) as AMP.
    is_amp = prediction == 0
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    confidence = round(probabilities[0 if is_amp else 1] * 100, 2)

    parts = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]
    if is_amp:
        parts.append("\nPredicted MIC Values (µM):\n")
        parts.extend(f"- {org}: {mic}\n" for org, mic in predictmic(sequence).items())
    else:
        parts.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    # LIME explains label 1 by default, which is Non-AMP here. Explicitly
    # request label 0 so the weights listed under the AMP heading really are
    # the AMP-class weights (otherwise their signs are inverted).
    amp_label = 0
    explanation = explainer.explain_instance(
        data_row=features[0],
        predict_fn=model.predict_proba,
        num_features=10,
        labels=(amp_label,),
    )
    parts.append("\nTop Features Influencing AMP Prediction:\n")
    parts.extend(
        f"- {feat}: {round(weight, 4)}\n"
        for feat, weight in explanation.as_list(label=amp_label)
    )
    return "".join(parts)
# Gradio UI
# A single text box in, a single text report out. full_prediction produces
# the whole report.
iface = gr.Interface(
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Prediction + MIC + LIME"),
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
)

# Launch only when run as a script (e.g. `python app.py`). Importing the
# module, for example from tests or tooling, then does not start a server.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link.
    iface.launch(share=True)