import io
import logging
import os
from datetime import datetime

import pandas as pd
import torch
from datasets import load_dataset
from flask import Flask, request, render_template, send_file
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
# Load the Hugging Face token from the environment and authenticate.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
if HUGGING_FACE_TOKEN:
    login(token=HUGGING_FACE_TOKEN)
else:
    raise ValueError(
        "Hugging Face token not found. Please set the HUGGING_FACE_TOKEN environment variable."
    )
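# Hypothetical local setup (the token value below is a placeholder):
#   export HUGGING_FACE_TOKEN=hf_xxxxxxxxxxxx
#   python app.py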
# Initialize the Flask application.
app = Flask(__name__)

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Enable verbose logging for debugging.
logging.basicConfig(level=logging.DEBUG)
# Classify the user persona of a message with the selected sequence-classification model.
def classify_persona(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    logits = outputs.logits
    # Convert logits to probabilities.
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    logging.debug(f"Logits: {logits}")
    logging.debug(f"Probabilities: {probabilities}")
    # Map each predicted class index to a persona label.
    predictions = torch.argmax(probabilities, dim=1)
    persona_mapping = {0: 'Persona A', 1: 'Persona B', 2: 'Persona C'}
    predicted_personas = [persona_mapping.get(pred.item(), 'Unknown') for pred in predictions]
    # A single input string yields a batch of one; return its prediction.
    return predicted_personas[0]
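# Example (hypothetical usage; note that a base checkpoint such as 'roberta-base'
# initializes its 3-label classification head randomly, so the labels are only
# meaningful after fine-tuning on persona data):
#   model = AutoModelForSequenceClassification.from_pretrained('roberta-base', num_labels=3).to(device)
#   tokenizer = AutoTokenizer.from_pretrained('roberta-base')
#   classify_persona("Everyone knows this is the worst plan.", model, tokenizer)
#   # -> 'Persona A', 'Persona B', or 'Persona C'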
# Heuristically flag a message as polarized if it contains absolutist keywords.
def is_polarized(message):
    # If the message arrives as a list of utterances, join it into one string.
    if isinstance(message, list):
        message = ' '.join(message)
    polarized_keywords = ["always", "never", "everyone", "nobody", "worst", "best"]
    return any(keyword in message.lower() for keyword in polarized_keywords)
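# Examples:
#   is_polarized("This is the worst idea ever")   # -> True ("worst")
#   is_polarized(["Nobody agrees", "with you"])   # -> True (list is joined first)
#   is_polarized("I somewhat disagree")           # -> False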
# Generate an AI-based nudge for a polarized message using the selected transformer model.
def generate_nudge(message, persona, topic, model, tokenizer, model_type,
                   max_length=50, min_length=30, temperature=0.7, top_p=0.9,
                   repetition_penalty=1.1):
    # Ensure min_length never exceeds max_length.
    min_length = min(min_length, max_length)
    if model_type == "seq2seq":
        prompt = f"As an AI assistant, provide a nudge for this {persona} message in a {topic} discussion: {message}"
        inputs = tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).to(device)
        generated_ids = model.generate(
            **inputs,  # pass the attention mask along with the input IDs
            max_length=max_length,
            min_length=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            num_beams=4,
            early_stopping=True,
        )
        nudge = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    elif model_type == "causal":
        prompt = f"{message} [AI Nudge]:"
        inputs = tokenizer(prompt, return_tensors='pt').to(device)
        generated_ids = model.generate(
            **inputs,
            # For causal models, max_length would count the prompt tokens too,
            # so bound the newly generated tokens instead.
            max_new_tokens=max_length,
            min_new_tokens=min_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token by default
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        new_tokens = generated_ids[0][inputs['input_ids'].shape[1]:]
        nudge = tokenizer.decode(new_tokens, skip_special_tokens=True)
    else:
        nudge = "This model is not suitable for generating text."
    return nudge
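# Example (hypothetical usage with a seq2seq checkpoint):
#   model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn').to(device)
#   tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
#   generate_nudge("Everyone always ignores the facts.", 'Persona A', 'politics',
#                  model, tokenizer, 'seq2seq')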
# Home route: render the form on GET, process an uploaded or online dataset on POST.
@app.route('/', methods=['GET', 'POST'])
def home():
    logging.debug("Home route accessed.")
    if request.method == 'POST':
        logging.debug("POST request received.")
        try:
            # Get the model names from the form.
            persona_model_name = request.form.get('persona_model_name', 'roberta-base')
            nudge_model_name = request.form.get('nudge_model_name', 'facebook/bart-large-cnn')
            logging.debug(f"Selected persona model: {persona_model_name}")
            logging.debug(f"Selected nudge model: {nudge_model_name}")

            # Load the persona classification model.
            persona_model = AutoModelForSequenceClassification.from_pretrained(persona_model_name, num_labels=3).to(device)
            persona_tokenizer = AutoTokenizer.from_pretrained(persona_model_name)

            # Load the nudge generation model, picking the architecture from the name.
            if "bart" in nudge_model_name or "t5" in nudge_model_name:
                model_type = "seq2seq"
                nudge_model = AutoModelForSeq2SeqLM.from_pretrained(nudge_model_name).to(device)
            elif "gpt2" in nudge_model_name:
                model_type = "causal"
                nudge_model = AutoModelForCausalLM.from_pretrained(nudge_model_name).to(device)
            else:
                logging.error("Unsupported model selected.")
                return "Selected model is not supported for text generation tasks.", 400
            nudge_tokenizer = AutoTokenizer.from_pretrained(nudge_model_name)
            logging.debug("Models and tokenizers loaded.")
            use_online_dataset = request.form.get('use_online_dataset') == 'yes'
            if use_online_dataset:
                # Load the selected dataset from the Hugging Face Hub.
                dataset_name = request.form.get('dataset_name')
                logging.debug(f"Selected online dataset: {dataset_name}")
                if dataset_name == 'personachat':
                    # 'personachat' maps to the AlekseyKorshuk/persona-chat dataset.
                    dataset_name = 'AlekseyKorshuk/persona-chat'
                dataset = load_dataset(dataset_name)
                df = pd.DataFrame(dataset['train'])  # use the training split for processing
                df = df.rename(columns=lambda x: x.strip().lower())
                df = df[['utterances', 'personality']]  # adjust to the dataset's structure
                df.columns = ['topic', 'post_reply']  # standardize column names for processing
            else:
                uploaded_file = request.files.get('file')
                if uploaded_file is None or uploaded_file.filename == '':
                    logging.error("No file uploaded.")
                    return "Please upload a CSV file or select an online dataset.", 400
                logging.debug(f"File uploaded: {uploaded_file.filename}")
                df = pd.read_csv(uploaded_file)
                df.columns = df.columns.str.strip().str.lower()
                if 'post_reply' not in df.columns:
                    logging.error("Required column 'post_reply' is missing in the CSV.")
                    return "The uploaded CSV file must contain a 'post_reply' column.", 400
            # Walk through the rows, filling in missing personas and inserting
            # an AI nudge row after each polarized message.
            augmented_rows = []
            for _, row in df.iterrows():
                if 'user_persona' not in row or pd.isna(row['user_persona']):
                    # Classify the user persona when it is not provided.
                    row['user_persona'] = classify_persona(row['post_reply'], persona_model, persona_tokenizer)
                augmented_rows.append(row.to_dict())
                if is_polarized(row['post_reply']):
                    topic = row.get('topic', '')  # uploaded CSVs may lack a 'topic' column
                    nudge = generate_nudge(row['post_reply'], row['user_persona'], topic,
                                           nudge_model, nudge_tokenizer, model_type)
                    augmented_rows.append({
                        'topic': topic,
                        'user_persona': 'AI Nudge',
                        'post_reply': nudge,
                    })
            augmented_df = pd.DataFrame(augmented_rows)
            logging.debug("Processing completed.")
            # Build a descriptive output filename from the selected model names.
            persona_tag = persona_model_name.split('/')[-1].replace('-', '_')
            nudge_tag = nudge_model_name.split('/')[-1].replace('-', '_')
            current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"DepolNudge_{persona_tag}_{nudge_tag}_{current_time}.csv"

            # Write the CSV to an in-memory buffer instead of the filesystem.
            csv_buffer = io.BytesIO()
            augmented_df.to_csv(csv_buffer, index=False)
            csv_buffer.seek(0)  # rewind so send_file reads from the start

            # Stream the file straight to the client as a download.
            return send_file(
                csv_buffer,
                as_attachment=True,
                download_name=output_filename,
                mimetype='text/csv',
            )
        except Exception as e:
            logging.error(f"Error processing the request: {e}", exc_info=True)
            return "There was an error processing your request.", 500

    logging.debug("Rendering index.html")
    return render_template('index.html')


if __name__ == '__main__':
    app.run(debug=True)
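# Example request (hypothetical; assumes the app is running on localhost:5000
# and that conversations.csv contains a 'post_reply' column):
#   curl -X POST http://localhost:5000/ \
#        -F 'persona_model_name=roberta-base' \
#        -F 'nudge_model_name=facebook/bart-large-cnn' \
#        -F 'use_online_dataset=no' \
#        -F 'file=@conversations.csv' \
#        -o nudged_output.csv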