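"""Content Trigger Analysis Space.

Scans script text for potentially sensitive topics by prompting
google/flan-t5-base over overlapping text chunks and reporting the trigger
categories that clear a per-chunk frequency threshold.
"""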
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from datetime import datetime
import gradio as gr
from typing import Dict, List, Union, Optional
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ContentAnalyzer:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.batch_size = 4
        self.trigger_categories = {
            "Violence": {
                "mapped_name": "Violence",
                "description": (
                    "Any act involving physical force or aggression intended to cause harm, injury, or death to a person, animal, or object. "
                    "Includes direct physical confrontations (e.g., fights, beatings, or assaults), implied violence (e.g., graphic threats or descriptions of injuries), "
                    "or large-scale events like wars, riots, or violent protests."
                )
            },
            "Death": {
                "mapped_name": "Death References",
                "description": (
                    "Any mention, implication, or depiction of the loss of life, including the deaths of characters, mentions of deceased individuals, "
                    "or abstract references to mortality (e.g., 'facing the end' or 'gone forever'). This also covers depictions of funerals, mourning, "
                    "grieving, or any dialogue that centers on death. Metaphors that do not refer to an actual death should not be counted."
                )
            },
            "Substance_Use": {
                "mapped_name": "Substance Use",
                "description": (
                    "Any explicit reference to the consumption, misuse, or abuse of drugs, alcohol, or other intoxicating substances. "
                    "This includes scenes of drug use, drinking, or smoking, as well as discussions of heavy substance abuse or substance-related paraphernalia."
                )
            },
            "Gore": {
                "mapped_name": "Gore",
                "description": (
                    "Extremely detailed and graphic depictions of severe physical injuries, mutilation, or extreme bodily harm, often accompanied by descriptions of heavy blood loss, exposed organs, "
                    "or dismemberment. This includes war scenes with severe casualties, horror scenarios involving grotesque creatures, or medical procedures depicted in excessive detail."
                )
            },
            "Sexual_Content": {
                "mapped_name": "Sexual Content",
                "description": (
                    "Any depiction of sexual activity, intimacy, or sexual behavior, ranging from implied scenes to explicit descriptions. "
                    "This includes physical descriptions of characters in a sexual context, sexual dialogue, or references to sexual themes."
                )
            },
            "Sexual_Abuse": {
                "mapped_name": "Sexual Abuse",
                "description": (
                    "Any form of non-consensual sexual act, behavior, or interaction involving coercion, manipulation, or physical force. "
                    "This includes incidents of sexual assault, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will."
                )
            },
            "Self_Harm": {
                "mapped_name": "Self-Harm",
                "description": (
                    "Any mention or depiction of behaviors where an individual intentionally causes harm to themselves. This includes cutting, burning, or other forms of physical injury, "
                    "as well as suicidal ideation, suicide attempts, or discussions of self-destructive thoughts and actions."
                )
            },
            "Mental_Health": {
                "mapped_name": "Mental Health Issues",
                "description": (
                    "Any reference to extreme mental health struggles, disorders, or psychological distress. This includes depictions of depression, anxiety, PTSD, bipolar disorder, "
                    "or other conditions, as well as abusive dynamics such as gaslighting and other forms of psychological abuse."
                )
            }
        }
        logger.info(f"Initialized analyzer with device: {self.device}")
    async def load_model(self, progress=None) -> None:
        """Load the model and tokenizer with progress updates."""
        try:
            if progress:
                progress(0.1, "Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "google/flan-t5-base",
                use_fast=True
            )
            if progress:
                progress(0.3, "Loading model...")
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                "google/flan-t5-base",
                # Half precision on GPU halves memory use; CPU stays in float32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                device_map="auto"
            )
            # Inference mode applies on CPU as well as GPU.
            self.model.eval()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            if progress:
                progress(0.5, "Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise

    def _chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 30) -> List[str]:
        """Split text into overlapping chunks of roughly chunk_size words.

        Sizes are counted in words, not tokens; the tokenizer later truncates
        anything beyond its 512-token limit, so long chunks lose their tails.
        """
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size - overlap):
            chunks.append(' '.join(words[i:i + chunk_size]))
        return chunks
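
    # Example: with the defaults (chunk_size=512, overlap=30), windows start at
    # words 0, 482, 964, ..., so consecutive chunks share a 30-word overlap.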
    def _validate_response(self, response: str) -> str:
        """Normalize a model reply to exactly YES, NO, or MAYBE."""
        valid_responses = {"YES", "NO", "MAYBE"}
        response = response.strip().upper()
        first_word = response.split()[0] if response else "NO"
        return first_word if first_word in valid_responses else "NO"
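
    # e.g. _validate_response(" yes ") -> "YES", while "Yes, clearly" maps to
    # "NO" because the leading word keeps its comma and fails validation.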
    async def analyze_chunks_batch(
        self,
        chunks: List[str],
        progress: Optional[gr.Progress] = None,
        current_progress: float = 0,
        progress_step: float = 0
    ) -> Dict[str, float]:
        """Analyze multiple chunks in batches, returning a score per category."""
        all_triggers = {}
        for category, info in self.trigger_categories.items():
            mapped_name = info["mapped_name"]
            description = info["description"]
            for i in range(0, len(chunks), self.batch_size):
                batch_chunks = chunks[i:i + self.batch_size]
                prompts = []
                for chunk in batch_chunks:
                    prompt = f"""
Task: Analyze if this text contains {mapped_name}.
Context: {description}
Text: "{chunk}"
Rules for analysis:
1. Only answer YES if there is clear, direct evidence
2. Answer NO if the content is ambiguous or metaphorical
3. Consider the severity and context
Answer with ONLY ONE word: YES, NO, or MAYBE
"""
                    prompts.append(prompt)
                try:
                    inputs = self.tokenizer(
                        prompts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512
                    ).to(self.device)
                    with torch.no_grad():
                        outputs = self.model.generate(
                            **inputs,
                            max_new_tokens=20,
                            temperature=0.2,
                            top_p=0.85,
                            num_beams=3,
                            early_stopping=True,
                            # T5 has a dedicated pad token; don't reuse EOS here.
                            pad_token_id=self.tokenizer.pad_token_id,
                            do_sample=True
                        )
                    responses = [
                        self.tokenizer.decode(output, skip_special_tokens=True)
                        for output in outputs
                    ]
                    # Score each chunk: a clear YES counts 1.0, a MAYBE counts 0.5.
                    for response in responses:
                        validated_response = self._validate_response(response)
                        if validated_response == "YES":
                            all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 1
                        elif validated_response == "MAYBE":
                            all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 0.5
                except Exception as e:
                    logger.error(f"Error processing batch for {mapped_name}: {str(e)}")
                    continue
                if progress:
                    current_progress += progress_step
                    progress(min(current_progress, 0.9), f"Analyzing {mapped_name}...")
        return all_triggers
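
    # Scoring recap: analyze_chunks_batch() adds 1.0 per YES chunk and 0.5 per
    # MAYBE chunk; analyze_script() below keeps a category only when its score
    # reaches at least 10% of the chunk count.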
    async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
        """Analyze the entire script and return the detected trigger names."""
        if not self.model or not self.tokenizer:
            await self.load_model(progress)
        chunks = self._chunk_text(script)
        if not chunks:
            # Empty input: nothing to analyze (also avoids dividing by zero below).
            return ["None"]
        identified_triggers = await self.analyze_chunks_batch(
            chunks,
            progress,
            current_progress=0.5,
            progress_step=0.4 / (len(chunks) * len(self.trigger_categories))
        )
        if progress:
            progress(0.95, "Finalizing results...")
        final_triggers = []
        # Require a category to be flagged in at least 10% of chunks (floor of 1).
        chunk_threshold = max(1, len(chunks) * 0.1)
        for mapped_name, count in identified_triggers.items():
            if count >= chunk_threshold:
                final_triggers.append(mapped_name)
        return final_triggers if final_triggers else ["None"]


async def analyze_content(
    script: str,
    # A gr.Progress() default lets Gradio inject its live progress tracker.
    progress: gr.Progress = gr.Progress()
) -> Dict[str, Union[List[str], str]]:
    """Main analysis function for the Gradio interface."""
    logger.info("Starting content analysis")
    analyzer = ContentAnalyzer()
    try:
        triggers = await analyzer.analyze_script(script, progress)
        if progress:
            progress(1.0, "Analysis complete!")
        result = {
            "detected_triggers": triggers,
            "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
            "model": "google/flan-t5-base",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }
        logger.info(f"Analysis complete: {result}")
        return result
    except Exception as e:
        logger.error(f"Analysis error: {str(e)}")
        return {
            "detected_triggers": ["Error occurred during analysis"],
            "confidence": "Error",
            "model": "google/flan-t5-base",
            "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "error": str(e)
        }


if __name__ == "__main__":
    iface = gr.Interface(
        fn=analyze_content,
        inputs=gr.Textbox(lines=8, label="Input Text"),
        outputs=gr.JSON(),
        title="Content Trigger Analysis",
        description="Analyze text content for sensitive topics and trigger warnings"
    )
    iface.launch()
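
# Programmatic usage (a minimal sketch that bypasses the Gradio UI; assumes an
# event loop is free and the model weights can be downloaded):
#
#   import asyncio
#   result = asyncio.run(analyze_content("He raised the knife and she screamed."))
#   print(result["detected_triggers"])  # e.g. ["Violence"] or ["None"]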