import gradio as gr import joblib import numpy as np import pandas as pd from propy import AAComposition, Autocorrelation, CTD, PseudoAAC from sklearn.preprocessing import MinMaxScaler import torch from transformers import BertTokenizer, BertModel from lime.lime_tabular import LimeTabularExplainer from math import expm1 # Load AMP Classifier and Scaler model = joblib.load("RF.joblib") scaler = joblib.load("norm (4).joblib") # Load ProtBert tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False) protbert_model = BertModel.from_pretrained("Rostlab/prot_bert") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") protbert_model = protbert_model.to(device).eval() # Define selected features (put your complete list here) selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1", "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001", "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001", "_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001", "_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050", "_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025", "_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001", "_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V", "AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE", "LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV", "MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4", "GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26", "GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29", "GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26", "GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30", "GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25", "GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29", "GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30", "GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24", "GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25", "GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30", "GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28", "GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25", "GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29", "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13", "APAAC15", "APAAC18", "APAAC19", "APAAC24"] # Dummy data for LIME sample_data = np.random.rand(100, len(selected_features)) explainer = LimeTabularExplainer( training_data=sample_data, feature_names=selected_features, class_names=["AMP", "Non-AMP"], mode="classification" ) # Feature extraction function def extract_features(sequence): sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"]) if len(sequence) < 10: return "Error: Sequence too short." try: dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence) filtered_dipeptide_features = {k: dipeptide_features[k] for k in list(dipeptide_features.keys())[:420]} ctd_features = CTD.CalculateCTD(sequence) auto_features = Autocorrelation.CalculateAutoTotal(sequence) pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9) all_features_dict = {} all_features_dict.update(ctd_features) all_features_dict.update(filtered_dipeptide_features) all_features_dict.update(auto_features) all_features_dict.update(pseudo_features) feature_df_all = pd.DataFrame([all_features_dict]) normalized_array = scaler.transform(feature_df_all.values) normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns) if not set(selected_features).issubset(normalized_df.columns): return "Error: Some selected features are missing." selected_df = normalized_df[selected_features].fillna(0) return selected_df.values except Exception as e: return f"Error in feature extraction: {str(e)}" # MIC prediction function def predictmic(sequence): sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"]) if len(sequence) < 10: return {"Error": "Sequence too short or invalid."} seq_spaced = ' '.join(list(sequence)) tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512) tokens = {k: v.to(device) for k, v in tokens.items()} with torch.no_grad(): outputs = protbert_model(**tokens) embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1) bacteria_config = { "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None}, "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None}, "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None}, "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"} } mic_results = {} for bacterium, cfg in bacteria_config.items(): try: scaler = joblib.load(cfg["scaler"]) scaled = scaler.transform(embedding) transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled model = joblib.load(cfg["model"]) mic_log = model.predict(transformed)[0] mic = round(expm1(mic_log), 3) mic_results[bacterium] = mic except Exception as e: mic_results[bacterium] = f"Error: {str(e)}" return mic_results # Main prediction function def full_prediction(sequence): features = extract_features(sequence) if isinstance(features, str): return features prediction = model.predict(features)[0] probabilities = model.predict_proba(features)[0] try: class_index = list(model.classes_).index(prediction) confidence = round(probabilities[class_index] * 100, 2) except Exception: confidence = "Unknown" amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP" result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n" if prediction == 0: mic_values = predictmic(sequence) result += "\nPredicted MIC Values (μM):\n" for org, mic in mic_values.items(): result += f"- {org}: {mic}\n" else: result += "\nMIC prediction skipped for Non-AMP sequences.\n" explanation = explainer.explain_instance( data_row=features[0], predict_fn=model.predict_proba, num_features=10 ) result += "\nTop Features Influencing Prediction:\n" for feat, weight in explanation.as_list(): result += f"- {feat}: {round(weight, 4)}\n" return result # Gradio UI iface = gr.Interface( fn=full_prediction, inputs=gr.Textbox(label="Enter Protein Sequence"), outputs=gr.Textbox(label="Results"), title="AMP & MIC Predictor + LIME Explanation", description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights." ) iface.launch(share=True)