import os
import pickle
import re
import string
import zipfile

import gradio as gr
import nltk
import tensorflow as tf
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Unzip the local nltk_data.zip if not already unzipped
nltk_data_path = os.path.join(os.path.dirname(__file__), 'nltk_data')
if not os.path.exists(nltk_data_path):
    with zipfile.ZipFile('nltk_data.zip', 'r') as zip_ref:
        zip_ref.extractall(nltk_data_path)

# Tell NLTK to use the local data path
nltk.data.path.append(nltk_data_path)
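
# Fallback (assumption: network access is available at startup): if the local
# bundle is missing a corpus this app relies on, fetch it into the same folder.
# Classic NLTK needs punkt/stopwords/wordnet here; NLTK >= 3.8.2 may also want
# 'punkt_tab' for word_tokenize.
for pkg, loc in [('punkt', 'tokenizers/punkt'),
                 ('stopwords', 'corpora/stopwords'),
                 ('wordnet', 'corpora/wordnet')]:
    try:
        nltk.data.find(loc)
    except LookupError:
        nltk.download(pkg, download_dir=nltk_data_path)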

# ============ Load Models and Tokenizers ============
with open("logreg_model.pkl", "rb") as f:
    logreg_model = pickle.load(f)

with open("nb_model.pkl", "rb") as f:
    nb_model = pickle.load(f)

with open("tfidf_vectorizer.pkl", "rb") as f:
    tfidf_vectorizer = pickle.load(f)

with open("glove_tokenizer.pkl", "rb") as f:
    glove_tokenizer = pickle.load(f)

model_glove = tf.keras.models.load_model("glove_model.h5")
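
# Note: the relative artifact paths above resolve against the current working
# directory, so launch the app from the folder containing these files.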

# ============ Constants ============
MAX_LENGTH = 300  # padded/truncated sequence length expected by the GloVe model
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

# ============ Preprocessing ============
def clean_text(text):
    """Lowercase, strip noise, then lemmatize and drop stopwords/short tokens."""
    text = str(text).lower()
    text = re.sub(r'\[.*?\]', '', text)  # bracketed spans, e.g. [Reuters]
    text = re.sub(r'https?://\S+|www\.\S+', '', text)  # URLs
    text = re.sub(r'<.*?>+', '', text)  # HTML tags
    text = re.sub(f"[{re.escape(string.punctuation)}]", '', text)  # ASCII punctuation
    text = re.sub(r'\n', ' ', text)  # newlines
    text = re.sub(r'\w*\d\w*', '', text)  # tokens containing digits
    # Curly quotes are not in string.punctuation, so handle them separately
    text = text.replace('“', '').replace('”', '').replace('’', "'").replace('‘', "'")
    text = re.sub(r"'s\b", '', text)  # possessive 's left by curly apostrophes
    tokens = word_tokenize(text)
    tokens = [lemmatizer.lemmatize(word) for word in tokens
              if word not in stop_words and len(word) > 2]
    return ' '.join(tokens)
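
# Example of the pipeline above on a hypothetical input (not from any dataset):
#   clean_text("Breaking: <b>Stocks</b> fell 5% today! https://t.co/xyz")
#   -> "breaking stock fell today"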

# ============ Prediction ============
def predict_ensemble(text):
    cleaned = clean_text(text)

    # Reject inputs that are too short to classify meaningfully
    if len(cleaned.strip()) <= 10:
        return "Input too short to analyze."

    # TF-IDF-based predictions (probability of the "real" class)
    tfidf_vec = tfidf_vectorizer.transform([cleaned])
    prob_nb = nb_model.predict_proba(tfidf_vec)[0][1]
    prob_logreg = logreg_model.predict_proba(tfidf_vec)[0][1]

    # GloVe-based deep learning prediction
    glove_seq = glove_tokenizer.texts_to_sequences([cleaned])
    glove_pad = pad_sequences(glove_seq, maxlen=MAX_LENGTH, padding='post', truncating='post')
    prob_glove = model_glove.predict(glove_pad)[0][0]

    # Weighted ensemble: Naive Bayes 0.50, Logistic Regression 0.10, GloVe 0.40
    ensemble_score = 0.50 * prob_nb + 0.10 * prob_logreg + 0.40 * prob_glove
    label = "✅ Real News" if ensemble_score >= 0.47 else "❌ Fake News"

    explanation = f"""
**Naive Bayes:** {prob_nb:.4f}
**Logistic Regression:** {prob_logreg:.4f}
**GloVe Model:** {prob_glove:.4f}

**Ensemble Score:** {ensemble_score:.4f}
**Final Prediction:** {label}
"""
    return explanation

# ============ Gradio Interface ============
interface = gr.Interface(
    fn=predict_ensemble,
    inputs=gr.Textbox(lines=8, placeholder="Paste your news article here...", label="News Article"),
    outputs=gr.Markdown(label="Prediction"),
    title="📰 Fake News Detector",
    description="Classifies a news article as real or fake using a weighted ensemble of three models: Naive Bayes, Logistic Regression, and a GloVe-based deep learning model.",
    allow_flagging="never"
)
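
# Quick check without the UI (hypothetical sample text): importing this module
# loads the models but does not launch the interface, e.g.
#   python -c "import app; print(app.predict_ensemble('Paste an article here...'))"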

if __name__ == "__main__":
    interface.launch()