Kuberwastaken committed on
Commit
04b7192
·
1 Parent(s): 6a425f0

Flan T5 Initial Commit

Browse files
.gitignore CHANGED
@@ -1 +1,4 @@
1
- treat-env
 
 
 
 
1
+ treat-classic
2
+ .gitignore
3
+ .gitattributes
4
+ __pycache__
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: TREAT
3
- emoji: 🍫
4
- colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
- sdk_version: "5.11.0" # Replace with the correct version if different
8
  app_file: gradio_app.py
9
  pinned: true
10
  ---
 
1
  ---
2
+ title: TREAT-Classic
3
+ emoji: 🍨
4
+ colorFrom: pink
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: "5.11.0"
8
  app_file: gradio_app.py
9
  pinned: true
10
  ---
model/__pycache__/analyzer.cpython-310.pyc CHANGED
Binary files a/model/__pycache__/analyzer.cpython-310.pyc and b/model/__pycache__/analyzer.cpython-310.pyc differ
 
model/analyzer.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from datetime import datetime
5
  import gradio as gr
@@ -13,263 +13,231 @@ logger = logging.getLogger(__name__)
13
 
14
  class ContentAnalyzer:
15
  def __init__(self):
16
- self.hf_token = os.getenv("HF_TOKEN")
17
- if not self.hf_token:
18
- raise ValueError("HF_TOKEN environment variable is not set!")
19
-
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
  self.model = None
22
  self.tokenizer = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  logger.info(f"Initialized analyzer with device: {self.device}")
24
 
25
  async def load_model(self, progress=None) -> None:
26
- """Load the model and tokenizer with progress updates and detailed logging."""
27
  try:
28
- print("\n=== Starting Model Loading ===")
29
- print(f"Time: {datetime.now()}")
30
-
31
  if progress:
32
  progress(0.1, "Loading tokenizer...")
33
 
34
- print("Loading tokenizer...")
35
  self.tokenizer = AutoTokenizer.from_pretrained(
36
- "meta-llama/Llama-3.2-3B",
37
  use_fast=True
38
  )
39
-
40
  if progress:
41
  progress(0.3, "Loading model...")
42
 
43
- print(f"Loading model on {self.device}...")
44
- self.model = AutoModelForCausalLM.from_pretrained(
45
- "meta-llama/Llama-3.2-3B",
46
- token=self.hf_token,
47
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
48
  device_map="auto"
49
  )
50
-
 
 
 
 
51
  if progress:
52
  progress(0.5, "Model loaded successfully")
53
-
54
- print("Model and tokenizer loaded successfully")
55
- logger.info(f"Model loaded successfully on {self.device}")
56
  except Exception as e:
57
  logger.error(f"Error loading model: {str(e)}")
58
- print(f"\nERROR DURING MODEL LOADING: {str(e)}")
59
- print("Stack trace:")
60
- traceback.print_exc()
61
  raise
62
 
63
- def _chunk_text(self, text: str, chunk_size: int = 256, overlap: int = 15) -> List[str]:
64
- """Split text into overlapping chunks for processing."""
 
65
  chunks = []
66
- for i in range(0, len(text), chunk_size - overlap):
67
- chunk = text[i:i + chunk_size]
68
  chunks.append(chunk)
69
- print(f"Split text into {len(chunks)} chunks with {overlap} token overlap")
70
  return chunks
71
 
72
- async def analyze_chunk(
 
 
 
 
 
 
 
73
  self,
74
- chunk: str,
75
- trigger_categories: Dict,
76
  progress: Optional[gr.Progress] = None,
77
  current_progress: float = 0,
78
  progress_step: float = 0
79
  ) -> Dict[str, float]:
80
- """Analyze a single chunk of text for triggers with detailed logging."""
81
- chunk_triggers = {}
82
- print(f"\n--- Processing Chunk ---")
83
- print(f"Chunk text (preview): {chunk[:50]}...")
84
 
85
- for category, info in trigger_categories.items():
86
  mapped_name = info["mapped_name"]
87
  description = info["description"]
88
-
89
- print(f"\nAnalyzing for {mapped_name}...")
90
- prompt = f"""
91
- Check this text for any clear indication of {mapped_name} ({description}).
92
- only say yes if you are confident, make sure the text is not metaphorical.
93
- Respond concisely and only with: YES, NO, or MAYBE.
94
- Text: {chunk}
95
- Answer:
96
- """
97
-
98
- try:
99
- print("Sending prompt to model...")
100
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
101
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
102
-
103
- with torch.no_grad():
104
- print("Generating response...")
105
- outputs = self.model.generate(
106
- **inputs,
107
- max_new_tokens=2,
108
- do_sample=True,
109
- temperature=0.3,
110
- top_p=0.9,
111
- pad_token_id=self.tokenizer.eos_token_id
112
- )
113
-
114
- response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip().upper()
115
- first_word = response_text.split("\n")[-1].split()[0] if response_text else "NO"
116
- print(f"Model response for {mapped_name}: {first_word}")
117
-
118
- if first_word == "YES":
119
- print(f"Detected {mapped_name} in this chunk!")
120
- chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 1
121
- elif first_word == "MAYBE":
122
- print(f"Possible {mapped_name} detected, marking for further review.")
123
- chunk_triggers[mapped_name] = chunk_triggers.get(mapped_name, 0) + 0.5
124
- else:
125
- print(f"No {mapped_name} detected in this chunk.")
126
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  if progress:
128
  current_progress += progress_step
129
  progress(min(current_progress, 0.9), f"Analyzing {mapped_name}...")
130
-
131
- except Exception as e:
132
- logger.error(f"Error analyzing chunk for {mapped_name}: {str(e)}")
133
- print(f"Error during analysis of {mapped_name}: {str(e)}")
134
- traceback.print_exc()
135
-
136
- return chunk_triggers
137
 
138
  async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
139
- """Analyze the entire script for triggers with progress updates and detailed logging."""
140
- print("\n=== Starting Script Analysis ===")
141
- print(f"Time: {datetime.now()}")
142
-
143
  if not self.model or not self.tokenizer:
144
  await self.load_model(progress)
145
-
146
- # Initialize trigger categories (kept from your working script)
147
- trigger_categories = {
148
-
149
- "Violence": {
150
- "mapped_name": "Violence",
151
- "description": (
152
- "Any act of physical force meant to cause harm, injury, or death, including fights, threats, and large-scale violence like wars or riots."
153
- )
154
- },
155
-
156
- "Death": {
157
- "mapped_name": "Death References",
158
- "description": (
159
- "Mentions or depictions of death, such as characters dying, references to deceased people, funerals, or mourning."
160
- )
161
- },
162
-
163
- "Substance Use": {
164
- "mapped_name": "Substance Use",
165
- "description": (
166
- "Any reference to using or abusing drugs, alcohol, or other substances, including scenes of drinking, smoking, or drug use."
167
- )
168
- },
169
-
170
- "Gore": {
171
- "mapped_name": "Gore",
172
- "description": (
173
- "Graphic depictions of severe injuries or mutilation, often with detailed blood, exposed organs, or dismemberment."
174
- )
175
- },
176
-
177
- "Vomit": {
178
- "mapped_name": "Vomit",
179
- "description": (
180
- "Any explicit reference to vomiting or related actions. This includes only very specific mentions of nausea or the act of vomiting, with more focus on the direct description, only flag this if you absolutely believe it's present."
181
- )
182
- },
183
-
184
- "Sexual Content": {
185
- "mapped_name": "Sexual Content",
186
- "description": (
187
- "Depictions or mentions of sexual activity, intimacy, or behavior, including sexual themes like harassment or innuendo."
188
- )
189
- },
190
-
191
- "Sexual Abuse": {
192
- "mapped_name": "Sexual Abuse",
193
- "description": (
194
- "Explicit non-consensual sexual acts, including assault, molestation, or harassment, and the emotional or legal consequences of such abuse. A stronger focus on detailed depictions or direct references to coercion or violence."
195
- )
196
- },
197
-
198
- "Self-Harm": {
199
- "mapped_name": "Self-Harm",
200
- "description": (
201
- "Depictions or mentions of intentional self-injury, including acts like cutting, burning, or other self-destructive behavior. Emphasis on more graphic or repeated actions, not implied or casual references."
202
- )
203
- },
204
-
205
- "Gun Use": {
206
- "mapped_name": "Gun Use",
207
- "description": (
208
- "Explicit mentions of firearms in use, including threatening actions or accidents involving guns. Only triggers when the gun use is shown in a clear, violent context."
209
- )
210
- },
211
-
212
- "Animal Cruelty": {
213
- "mapped_name": "Animal Cruelty",
214
- "description": (
215
- "Direct or explicit harm, abuse, or neglect of animals, including physical abuse or suffering, and actions performed for human entertainment or experimentation. Triggers only in clear, violent depictions."
216
- )
217
- },
218
-
219
- "Mental Health Issues": {
220
- "mapped_name": "Mental Health Issues",
221
- "description": (
222
- "References to psychological struggles, such as depression, anxiety, or PTSD, including therapy or coping mechanisms."
223
- )
224
- }
225
- }
226
-
227
  chunks = self._chunk_text(script)
228
- identified_triggers = {}
229
- progress_step = 0.4 / (len(chunks) * len(trigger_categories))
230
- current_progress = 0.5 # Starting after model loading
231
-
232
- for chunk_idx, chunk in enumerate(chunks, 1):
233
- chunk_triggers = await self.analyze_chunk(
234
- chunk,
235
- trigger_categories,
236
- progress,
237
- current_progress,
238
- progress_step
239
- )
240
-
241
- for trigger, count in chunk_triggers.items():
242
- identified_triggers[trigger] = identified_triggers.get(trigger, 0) + count
243
-
244
  if progress:
245
  progress(0.95, "Finalizing results...")
246
 
247
- print("\n=== Analysis Complete ===")
248
- print("Final Results:")
249
  final_triggers = []
250
-
 
251
  for mapped_name, count in identified_triggers.items():
252
- if count > 0.5:
253
  final_triggers.append(mapped_name)
254
- print(f"- {mapped_name}: found in {count} chunks")
255
-
256
- if not final_triggers:
257
- print("No triggers detected")
258
- final_triggers = ["None"]
259
 
260
- return final_triggers
261
 
262
  async def analyze_content(
263
  script: str,
264
  progress: Optional[gr.Progress] = None
265
  ) -> Dict[str, Union[List[str], str]]:
266
- """Main analysis function for the Gradio interface with detailed logging."""
267
- print("\n=== Starting Content Analysis ===")
268
- print(f"Time: {datetime.now()}")
269
 
270
  analyzer = ContentAnalyzer()
271
 
272
  try:
 
273
  triggers = await analyzer.analyze_script(script, progress)
274
 
275
  if progress:
@@ -278,33 +246,29 @@ async def analyze_content(
278
  result = {
279
  "detected_triggers": triggers,
280
  "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
281
- "model": "Llama-3.2-3B",
282
  "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
283
  }
284
 
285
- print("\nFinal Result Dictionary:", result)
286
  return result
287
 
288
  except Exception as e:
289
  logger.error(f"Analysis error: {str(e)}")
290
- print(f"\nERROR OCCURRED: {str(e)}")
291
- print("Stack trace:")
292
- traceback.print_exc()
293
  return {
294
  "detected_triggers": ["Error occurred during analysis"],
295
  "confidence": "Error",
296
- "model": "Llama-3.2-3B",
297
  "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
298
  "error": str(e)
299
  }
300
 
301
  if __name__ == "__main__":
302
- # Gradio interface
303
  iface = gr.Interface(
304
  fn=analyze_content,
305
  inputs=gr.Textbox(lines=8, label="Input Text"),
306
  outputs=gr.JSON(),
307
- title="Content Analysis",
308
- description="Analyze text content for sensitive topics"
309
  )
310
  iface.launch()
 
1
  import os
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
  from datetime import datetime
5
  import gradio as gr
 
13
 
14
  class ContentAnalyzer:
15
  def __init__(self):
 
 
 
 
16
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
  self.model = None
18
  self.tokenizer = None
19
+ self.batch_size = 4
20
+ self.trigger_categories = {
21
+ "Violence": {
22
+ "mapped_name": "Violence",
23
+ "description": (
24
+ "Any act involving physical force or aggression intended to cause harm, injury, or death to a person, animal, or object. "
25
+ "Includes direct physical confrontations (e.g., fights, beatings, or assaults), implied violence (e.g., very graphical threats or descriptions of injuries), "
26
+ "or large-scale events like wars, riots, or violent protests."
27
+ )
28
+ },
29
+ "Death": {
30
+ "mapped_name": "Death References",
31
+ "description": (
32
+ "Any mention, implication, or depiction of the loss of life, including direct deaths of characters, including mentions of deceased individuals, "
33
+ "or abstract references to mortality (e.g., 'facing the end' or 'gone forever'). This also covers depictions of funerals, mourning, "
34
+ "grieving, or any dialogue that centers around death, do not take metaphors into context that don't actually lead to death."
35
+ )
36
+ },
37
+ "Substance_Use": {
38
+ "mapped_name": "Substance Use",
39
+ "description": (
40
+ "Any explicit reference to the consumption, misuse, or abuse of drugs, alcohol, or other intoxicating substances. "
41
+ "This includes scenes of drug use, drinking, smoking, discussions about heavy substance abuse or substance-related paraphernalia."
42
+ )
43
+ },
44
+ "Gore": {
45
+ "mapped_name": "Gore",
46
+ "description": (
47
+ "Extremely detailed and graphic depictions of highly severe physical injuries, mutilation, or extreme bodily harm, often accompanied by descriptions of heavy blood, exposed organs, "
48
+ "or dismemberment. This includes war scenes with severe casualties, horror scenarios involving grotesque creatures, or medical procedures depicted with excessive detail."
49
+ )
50
+ },
51
+ "Sexual_Content": {
52
+ "mapped_name": "Sexual Content",
53
+ "description": (
54
+ "Any depiction of sexual activity, intimacy, or sexual behavior, ranging from implied scenes to explicit descriptions. "
55
+ "This includes physical descriptions of characters in a sexual context, sexual dialogue, or references to sexual themes."
56
+ )
57
+ },
58
+ "Sexual_Abuse": {
59
+ "mapped_name": "Sexual Abuse",
60
+ "description": (
61
+ "Any form of non-consensual sexual act, behavior, or interaction, involving coercion, manipulation, or physical force. "
62
+ "This includes incidents of sexual assault, exploitation, harassment, and any acts where an individual is subjected to sexual acts against their will."
63
+ )
64
+ },
65
+ "Self_Harm": {
66
+ "mapped_name": "Self-Harm",
67
+ "description": (
68
+ "Any mention or depiction of behaviors where an individual intentionally causes harm to themselves. This includes cutting, burning, or other forms of physical injury, "
69
+ "as well as suicidal ideation, suicide attempts, or discussions of self-destructive thoughts and actions."
70
+ )
71
+ },
72
+ "Mental_Health": {
73
+ "mapped_name": "Mental Health Issues",
74
+ "description": (
75
+ "Any reference to extreme mental health struggles, disorders, or psychological distress. This includes depictions of depression, anxiety, PTSD, bipolar disorder, "
76
+ "or other conditions. Also includes toxic traits such as Gaslighting or other psycholgoical horrors"
77
+ )
78
+ }
79
+ }
80
  logger.info(f"Initialized analyzer with device: {self.device}")
81
 
82
  async def load_model(self, progress=None) -> None:
83
+ """Load the model and tokenizer with progress updates."""
84
  try:
 
 
 
85
  if progress:
86
  progress(0.1, "Loading tokenizer...")
87
 
 
88
  self.tokenizer = AutoTokenizer.from_pretrained(
89
+ "google/flan-t5-base",
90
  use_fast=True
91
  )
92
+
93
  if progress:
94
  progress(0.3, "Loading model...")
95
 
96
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
97
+ "google/flan-t5-base",
 
 
98
  torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
99
  device_map="auto"
100
  )
101
+
102
+ if self.device == "cuda":
103
+ self.model.eval()
104
+ torch.cuda.empty_cache()
105
+
106
  if progress:
107
  progress(0.5, "Model loaded successfully")
108
+
 
 
109
  except Exception as e:
110
  logger.error(f"Error loading model: {str(e)}")
 
 
 
111
  raise
112
 
113
+ def _chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 30) -> List[str]:
114
+ """Split text into overlapping chunks."""
115
+ words = text.split()
116
  chunks = []
117
+ for i in range(0, len(words), chunk_size - overlap):
118
+ chunk = ' '.join(words[i:i + chunk_size])
119
  chunks.append(chunk)
 
120
  return chunks
121
 
122
+ def _validate_response(self, response: str) -> str:
123
+ """Validate and clean model response."""
124
+ valid_responses = {"YES", "NO", "MAYBE"}
125
+ response = response.strip().upper()
126
+ first_word = response.split()[0] if response else "NO"
127
+ return first_word if first_word in valid_responses else "NO"
128
+
129
async def analyze_chunks_batch(
    self,
    chunks: List[str],
    progress: Optional[gr.Progress] = None,
    current_progress: float = 0,
    progress_step: float = 0
) -> Dict[str, float]:
    """Score every trigger category against the chunks, in batches.

    For each category in ``self.trigger_categories`` the chunks are grouped
    into slices of ``self.batch_size``, turned into prompts and sent through
    ``self.model.generate``. Each answer is normalised with
    ``_validate_response``.

    Args:
        chunks: Text chunks to analyse.
        progress: Optional callable invoked as ``progress(fraction, message)``.
        current_progress: Progress fraction to start counting from.
        progress_step: Progress increment per (chunk, category) pair.

    Returns:
        A dict keyed by the category's ``mapped_name``. Each YES adds 1 and
        each MAYBE adds 0.5; categories with neither are absent.
    """
    all_triggers = {}

    for info in self.trigger_categories.values():
        mapped_name = info["mapped_name"]
        description = info["description"]

        # The instructions come BEFORE the chunk text. The tokenizer call below
        # truncates to 512 tokens; with the default right-side truncation,
        # anything placed after a long chunk would be cut off, so the text is
        # the part that may be trimmed.
        prompt_header = (
            f"Task: Analyze if this text contains {mapped_name}.\n"
            f"Context: {description}\n"
            "\n"
            "Rules for analysis:\n"
            "1. Only answer YES if there is clear, direct evidence\n"
            "2. Answer NO if the content is ambiguous or metaphorical\n"
            "3. Consider the severity and context\n"
            "\n"
            "Answer with ONLY ONE word: YES, NO, or MAYBE\n"
            "\n"
        )

        for start in range(0, len(chunks), self.batch_size):
            batch_chunks = chunks[start:start + self.batch_size]
            prompts = [f'{prompt_header}Text: "{chunk}"' for chunk in batch_chunks]

            try:
                inputs = self.tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=20,
                        temperature=0.2,
                        top_p=0.85,
                        num_beams=3,
                        early_stopping=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                        do_sample=True
                    )

                # Decode the whole batch first so a decoding failure does not
                # leave the batch half-counted.
                responses = [
                    self.tokenizer.decode(output, skip_special_tokens=True)
                    for output in outputs
                ]

                for response in responses:
                    validated_response = self._validate_response(response)
                    if validated_response == "YES":
                        all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 1
                    elif validated_response == "MAYBE":
                        all_triggers[mapped_name] = all_triggers.get(mapped_name, 0) + 0.5

            except Exception as e:
                # Best effort: a failed batch is logged and skipped, the
                # remaining batches are still analysed.
                logger.error(f"Error processing batch for {mapped_name}: {str(e)}")

            if progress:
                # progress_step is sized per chunk, so advance by the number
                # of chunks in this batch -- also after a failed batch, so
                # the bar does not stall.
                current_progress += progress_step * len(batch_chunks)
                progress(min(current_progress, 0.9), f"Analyzing {mapped_name}...")

    return all_triggers
 
 
 
 
 
204
 
205
  async def analyze_script(self, script: str, progress: Optional[gr.Progress] = None) -> List[str]:
206
+ """Analyze the entire script."""
 
 
 
207
  if not self.model or not self.tokenizer:
208
  await self.load_model(progress)
209
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  chunks = self._chunk_text(script)
211
+ identified_triggers = await self.analyze_chunks_batch(
212
+ chunks,
213
+ progress,
214
+ current_progress=0.5,
215
+ progress_step=0.4 / (len(chunks) * len(self.trigger_categories))
216
+ )
217
+
 
 
 
 
 
 
 
 
 
218
  if progress:
219
  progress(0.95, "Finalizing results...")
220
 
 
 
221
  final_triggers = []
222
+ chunk_threshold = max(1, len(chunks) * 0.1)
223
+
224
  for mapped_name, count in identified_triggers.items():
225
+ if count >= chunk_threshold:
226
  final_triggers.append(mapped_name)
 
 
 
 
 
227
 
228
+ return final_triggers if final_triggers else ["None"]
229
 
230
  async def analyze_content(
231
  script: str,
232
  progress: Optional[gr.Progress] = None
233
  ) -> Dict[str, Union[List[str], str]]:
234
+ """Main analysis function for the Gradio interface."""
235
+ logger.info("Starting content analysis")
 
236
 
237
  analyzer = ContentAnalyzer()
238
 
239
  try:
240
+ # Fix: Use the analyzer instance's method instead of undefined function
241
  triggers = await analyzer.analyze_script(script, progress)
242
 
243
  if progress:
 
246
  result = {
247
  "detected_triggers": triggers,
248
  "confidence": "High - Content detected" if triggers != ["None"] else "High - No concerning content detected",
249
+ "model": "google/flan-t5-base",
250
  "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
251
  }
252
 
253
+ logger.info(f"Analysis complete: {result}")
254
  return result
255
 
256
  except Exception as e:
257
  logger.error(f"Analysis error: {str(e)}")
 
 
 
258
  return {
259
  "detected_triggers": ["Error occurred during analysis"],
260
  "confidence": "Error",
261
+ "model": "google/flan-t5-base",
262
  "analysis_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
263
  "error": str(e)
264
  }
265
 
266
  if __name__ == "__main__":
 
267
  iface = gr.Interface(
268
  fn=analyze_content,
269
  inputs=gr.Textbox(lines=8, label="Input Text"),
270
  outputs=gr.JSON(),
271
+ title="Content Trigger Analysis",
272
+ description="Analyze text content for sensitive topics and trigger warnings"
273
  )
274
  iface.launch()