"""Gradio app: AMP classification, per-organism MIC regression and LIME explanations.

Pipeline as implemented below:

1. ``extract_features`` builds propy descriptors, normalizes them with the
   pre-fitted ``scaler`` and keeps only ``selected_features``.
2. ``model`` (loaded from ``RF.joblib``) classifies the sequence; class label
   ``AMP_CLASS`` (0) is treated as "AMP" throughout this file.
3. For AMP predictions, ``predictmic`` embeds the sequence with ProtBert and
   runs one regressor per organism listed in ``MIC_MODEL_CONFIG``.
4. LIME reports the features that push the classifier toward the AMP class.
"""

from functools import lru_cache
from itertools import islice
from math import expm1

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import torch
from lime.lime_tabular import LimeTabularExplainer
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from sklearn.preprocessing import MinMaxScaler  # noqa: F401  (not referenced directly in this file)
from transformers import BertModel, BertTokenizer

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")
MIN_SEQUENCE_LENGTH = 10
# Classifier label for AMP. It is also used as the predict_proba column / LIME
# label index, which assumes model.classes_ == [0, 1] — TODO confirm.
AMP_CLASS = 0

# Per-organism MIC artifacts. "pca" is an optional extra transform applied
# after scaling (only the K. pneumoniae model uses one).
MIC_MODEL_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
# Load AMP Classifier
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# Load ProtBert
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
protbert_model = protbert_model.to(device).eval()

# Selected Features
selected_features = [ ... ]  # keep your full selected_features list here

# LIME Explainer Setup
# NOTE(review): LIME derives its perturbation statistics from this background
# matrix. It is synthetic uniform noise in [0, 1), not real training data, so
# the explanations are only indicative; replace it with (a sample of) the
# normalized training matrix if one is available. Both the background data
# and the explainer are seeded so a fresh process reproduces the same results.
sample_data = np.random.default_rng(0).random((100, len(selected_features)))
explainer = LimeTabularExplainer(
    training_data=sample_data,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],
    mode="classification",
    random_state=0,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _clean_sequence(sequence):
    """Upper-case *sequence* and drop every character that is not a standard amino acid.

    ``None`` is treated as the empty string.
    """
    return "".join(aa for aa in (sequence or "").upper() if aa in VALID_AMINO_ACIDS)


@lru_cache(maxsize=None)
def _load_artifact(path):
    """Load a joblib artifact once per process and reuse it on later requests.

    The key space is the fixed set of paths in MIC_MODEL_CONFIG, so the cache
    is bounded. A failed load raises and is not cached, so it is retried on
    the next request.
    """
    return joblib.load(path)


def _protbert_embedding(sequence):
    """Return the ProtBert embedding of a cleaned *sequence* as a (1, hidden_size) NumPy array."""
    # ProtBert expects residues separated by spaces.
    seq_spaced = " ".join(sequence)
    tokens = tokenizer(
        seq_spaced,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    tokens = {k: v.to(device) for k, v in tokens.items()}
    with torch.no_grad():
        outputs = protbert_model(**tokens)
    # NOTE(review): this averages over all 512 positions, including padding
    # (no attention mask is applied). It is kept as-is because the MIC
    # regressors were presumably trained on embeddings computed the same way.
    # TODO confirm before switching to masked mean pooling.
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)


# ---------------------------------------------------------------------------
# Feature Extractor
# ---------------------------------------------------------------------------
def extract_features(sequence):
    """Compute the AMP classifier's input features for *sequence*.

    The sequence is upper-cased and stripped of non-standard residues first.

    Returns:
        A NumPy array of shape (1, len(selected_features)) with the normalized
        selected descriptors (missing values filled with 0), or the string
        ``"Error: Sequence too short."`` when fewer than MIN_SEQUENCE_LENGTH
        valid residues remain.
    """
    sequence = _clean_sequence(sequence)
    if len(sequence) < MIN_SEQUENCE_LENGTH:
        return "Error: Sequence too short."

    # propy returns more composition entries than the model uses; keep the
    # first 420 (presumably 20 AAC + 400 dipeptide — TODO confirm ordering).
    dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
    filtered_dipeptide_features = dict(islice(dipeptide_features.items(), 420))

    # Insertion order defines the column order passed to the scaler below,
    # so the update order (CTD, dipeptide, autocorrelation, PseAAC) matters.
    all_features_dict = {}
    all_features_dict.update(CTD.CalculateCTD(sequence))
    all_features_dict.update(filtered_dipeptide_features)
    all_features_dict.update(Autocorrelation.CalculateAutoTotal(sequence))
    all_features_dict.update(PseudoAAC.GetAPseudoAAC(sequence, lamda=9))

    feature_df_all = pd.DataFrame([all_features_dict])
    # Raw .values are passed, so the column order (not names) must match how
    # the scaler was fitted — TODO confirm.
    normalized_array = scaler.transform(feature_df_all.values)
    normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
    return normalized_df[selected_features].fillna(0).values


# ---------------------------------------------------------------------------
# MIC Predictor
# ---------------------------------------------------------------------------
def predictmic(sequence):
    """Predict MIC values for *sequence* against each organism in MIC_MODEL_CONFIG.

    Returns:
        ``{"Error": "Sequence too short or invalid."}`` when fewer than
        MIN_SEQUENCE_LENGTH valid residues remain. Otherwise, a dict mapping
        organism name to the MIC rounded to 3 decimals. If an organism's
        artifacts fail to load or predict, its entry is an ``"Error: ..."``
        string instead.
    """
    sequence = _clean_sequence(sequence)
    if len(sequence) < MIN_SEQUENCE_LENGTH:
        return {"Error": "Sequence too short or invalid."}

    embedding = _protbert_embedding(sequence)

    mic_results = {}
    for bacterium, cfg in MIC_MODEL_CONFIG.items():
        try:
            mic_scaler = _load_artifact(cfg["scaler"])
            transformed = mic_scaler.transform(embedding)
            if cfg["pca"]:
                transformed = _load_artifact(cfg["pca"]).transform(transformed)
            mic_model = _load_artifact(cfg["model"])
            mic_log = mic_model.predict(transformed)[0]
            # Presumably the regressors were trained on log1p(MIC); expm1 inverts that.
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            # Best-effort per organism: one broken artifact must not hide the others.
            mic_results[bacterium] = f"Error: {e}"
    return mic_results


# ---------------------------------------------------------------------------
# Full Prediction with LIME Explanation
# ---------------------------------------------------------------------------
def full_prediction(sequence):
    """Classify *sequence*, predict MICs for AMPs, and explain the result with LIME.

    Returns:
        A human-readable report string. If the cleaned sequence is too short,
        the error string from ``extract_features`` is returned instead.
    """
    features = extract_features(sequence)
    if isinstance(features, str):  # error message from extract_features
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    is_amp = prediction == AMP_CLASS
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    # The predicted class is the most probable one, so its probability is the maximum.
    confidence = round(probabilities.max() * 100, 2)

    parts = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]
    if is_amp:
        parts.append("\nPredicted MIC Values (µM):\n")
        parts.extend(f"- {org}: {mic}\n" for org, mic in predictmic(sequence).items())
    else:
        parts.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    # Explain the AMP column explicitly. LIME's default (labels=(1,)) would
    # explain column 1 ("Non-AMP"), which contradicts the heading below.
    explanation = explainer.explain_instance(
        data_row=features[0],
        predict_fn=model.predict_proba,
        num_features=10,
        labels=(AMP_CLASS,),
    )
    parts.append("\nTop Features Influencing AMP Prediction:\n")
    parts.extend(
        f"- {feat}: {round(weight, 4)}\n"
        for feat, weight in explanation.as_list(label=AMP_CLASS)
    )
    return "".join(parts)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
iface = gr.Interface(
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Prediction + MIC + LIME"),
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)

if __name__ == "__main__":
    iface.launch(share=True)