imperiusrex committed on
Commit a26cb4b · verified · 1 Parent(s): 1c679de

Update app.py

Files changed (1):
  1. app.py +39 -165

app.py CHANGED
@@ -6,16 +6,8 @@ import torch
  import spaces
  from ultralytics import YOLO
  from tqdm import tqdm
- from PIL import Image
- import logging
- import time
- from transformers import MobileViTFeatureExtractor, MobileViTForImageClassification
- from sentence_transformers import util
- import gc
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)

  # Fix for Ultralytics config write error in Hugging Face environment
  os.environ["YOLO_CONFIG_DIR"] = "/tmp"
@@ -23,37 +15,19 @@ os.environ["YOLO_CONFIG_DIR"] = "/tmp"
  # Use GPU if available
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Static knowledge base with prompts and explanations
- KNOWLEDGE_BASE = [
-     {
-         "prompt": "A physics equation on a whiteboard",
-         "explanation": "The board likely contains a physics equation, such as E = mc², which is Einstein's mass-energy equivalence formula. It states that energy (E) equals mass (m) times the speed of light (c) squared, a key concept in relativity."
-     },
-     {
-         "prompt": "A geometry diagram on a whiteboard",
-         "explanation": "The board shows a geometry diagram, possibly related to the Pythagorean theorem (a² + b² = c²), which applies to right-angled triangles to calculate side lengths."
-     },
-     {
-         "prompt": "A chemistry formula on a whiteboard",
-         "explanation": "The board displays a chemistry formula, such as a chemical equation or molecular structure, used to describe reactions or compounds."
-     },
-     {
-         "prompt": "A biology diagram on a whiteboard",
-         "explanation": "The board shows a biology diagram, such as a cell structure or photosynthesis process, illustrating biological concepts."
-     }
- ]

  @spaces.GPU
  def process_video(video_path):
-     # Load YOLO models
-     try:
-         extract_model = YOLO("best.pt").to(device)
-         detect_model = YOLO("yolov8n.pt").to(device)
-     except Exception as e:
-         logger.error(f"Failed to load YOLO models: {str(e)}")
-         raise RuntimeError(f"Failed to load YOLO models: {str(e)}")
-
-     os.makedirs("frames", exist_ok=True)

      # Step 1: Extract board-only frames
      cap = cv2.VideoCapture(video_path)
@@ -66,7 +40,7 @@ def process_video(video_path):
          labels = [extract_model.names[int(c)] for c in results[0].boxes.cls.cpu().numpy()]
          if "board" in labels and "person" not in labels:
              frames.append(frame)
-             cv2.imwrite(f"frames/frame_{idx:04d}.jpg", frame)
              idx += 1
      cap.release()
      if not frames:
@@ -100,7 +74,7 @@ def process_video(video_path):
      # Step 3: Median-fuse
      stack = np.stack(aligned, axis=0).astype(np.float32)
      median_board = np.median(stack, axis=0).astype(np.uint8)
-     cv2.imwrite("clean_board.jpg", median_board)

      # Step 4: Mask persons & selective fuse
      sum_img = np.zeros_like(aligned[0], dtype=np.float32)
@@ -119,132 +93,33 @@ def process_video(video_path):

      count[count == 0] = 1
      selective = (sum_img / count[:, :, None]).astype(np.uint8)
-     cv2.imwrite("fused_board_selective.jpg", selective)

      # Step 5: Sharpen
      blur = cv2.GaussianBlur(selective, (5, 5), 0)
      sharp = cv2.addWeighted(selective, 1.5, blur, -0.5, 0)
-     cv2.imwrite("sharpened_board_color.jpg", sharp)
-
-     # Free YOLO models to save memory
-     extract_model = None
-     detect_model = None
-     gc.collect()
-     if device == "cuda":
-         torch.cuda.empty_cache()
-
-     return sharp
-
- def generate_related_content(image, retries=3):
-     # Load MobileViT model
-     model = None
-     feature_extractor = None
-     try:
-         model = MobileViTForImageClassification.from_pretrained(
-             "apple/mobilevit-xxs",
-             torch_dtype=torch.bfloat16,
-             low_cpu_mem_usage=True
-         ).to(device)
-         feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xxs")
-         logger.info("Successfully loaded MobileViT model and feature extractor")
-     except Exception as e:
-         logger.error(f"Failed to load MobileViT model: {str(e)}")
-         return (
-             "Error: Failed to load MobileViT model due to insufficient memory. "
-             "Consider upgrading to a paid Space with GPU.\n\n"
-             "For further reading:\n"
-             "- Khan Academy: https://www.khanacademy.org\n"
-             "- Wikipedia: https://en.wikipedia.org/wiki/Education\n"
-             "- MIT OpenCourseWare: https://ocw.mit.edu"
-         )
-
-     # Convert OpenCV image to PIL
-     try:
-         image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
-     except Exception as e:
-         logger.error(f"Image conversion failed: {str(e)}")
-         return f"Error converting image: {str(e)}"
-
-     # Load sentence transformer for prompt embeddings
-     try:
-         from sentence_transformers import SentenceTransformer
-         text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
-         logger.info("Successfully loaded sentence transformer")
-     except Exception as e:
-         logger.error(f"Failed to load sentence transformer: {str(e)}")
-         return (
-             "Error: Failed to load text encoder for prompts.\n\n"
-             "For further reading:\n"
-             "- Khan Academy: https://www.khanacademy.org\n"
-             "- Wikipedia: https://en.wikipedia.org/wiki/Education\n"
-             "- MIT OpenCourseWare: https://ocw.mit.edu"
-         )
-
-     # Process image and prompts
-     for attempt in range(retries):
-         try:
-             # Prepare image inputs
-             inputs = feature_extractor(images=image_pil, return_tensors="pt").to(device)
-
-             # Get image features
-             with torch.no_grad():
-                 outputs = model(**inputs, output_hidden_states=True)
-                 # Use the last hidden state as features (approximating CLIP-like embeddings)
-                 image_features = outputs.hidden_states[-1].mean(dim=1)  # Average pooling
-
-             # Encode prompts
-             prompts = [entry["prompt"] for entry in KNOWLEDGE_BASE]
-             text_features = text_encoder.encode(prompts, convert_to_tensor=True, device=device)
-
-             # Compute cosine similarities
-             similarities = util.cos_sim(image_features, text_features)[0]
-             best_match_idx = similarities.argmax()
-             best_score = similarities[best_match_idx].item()
-
-             # Threshold for confidence
-             if best_score < 0.2:
-                 logger.warning("No confident match found for image content")
-                 explanation = "The board content could not be confidently identified."
-                 matched_prompt = "Unknown content"
-             else:
-                 matched_prompt = prompts[best_match_idx]
-                 explanation = next(entry["explanation"] for entry in KNOWLEDGE_BASE if entry["prompt"] == matched_prompt)
-                 logger.info(f"Matched prompt: {matched_prompt} (score: {best_score:.2f})")
-
-             references = (
-                 "For further reading:\n"
-                 "- Khan Academy: https://www.khanacademy.org\n"
-                 "- Wikipedia: https://en.wikipedia.org/wiki/Education\n"
-                 "- MIT OpenCourseWare: https://ocw.mit.edu"
-             )
-             return f"Content: {matched_prompt}\n\nExplanation: {explanation}\n\n{references}"
-         except Exception as e:
-             error_msg = f"MobileViT processing attempt {attempt + 1} failed: {str(e)}"
-             logger.error(error_msg)
-             if attempt == retries - 1:
-                 return f"Error generating content with MobileViT: {error_msg}\n\n{references}"
-             time.sleep(2 ** attempt)
-         finally:
-             # Free model to save memory
-             model = None
-             feature_extractor = None
-             gc.collect()
-             if device == "cuda":
-                 torch.cuda.empty_cache()

- def process_and_generate(video_path):
-     try:
-         # Process video to get sharpened image
-         sharpened_image = process_video(video_path)
-         # Generate related content
-         generated_content = generate_related_content(sharpened_image)
-         return sharpened_image, generated_content
-     except Exception as e:
-         logger.error(f"Processing failed: {str(e)}")
-         return None, f"Error processing video: {str(e)}"


  demo = gr.Interface(
-     fn=process_and_generate,
      inputs=[
          gr.File(
              label="Upload Classroom Video (.mp4)",
@@ -255,19 +130,18 @@ demo = gr.Interface(
      ],
      outputs=[
          gr.Image(label="Sharpened Final Board"),
-         gr.Textbox(label="Generated Content and Explanation")
      ],
-     title="📹 Classroom Board Cleaner & Content Generator",
      description=(
          "Upload your classroom video (.mp4). \n"
-         "Automatic extraction, alignment, masking, fusion & sharpening. \n"
-         "Generates a summary and detailed explanation of the board content using MobileViT-XXS."
      )
  )

  if __name__ == "__main__":
      if device == "cuda":
-         logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
      else:
-         logger.info("Using CPU (GPU not available or not assigned)")
      demo.launch()

  import spaces
  from ultralytics import YOLO
  from tqdm import tqdm
+ import easyocr
+ from transformers import pipeline

  # Fix for Ultralytics config write error in Hugging Face environment
  os.environ["YOLO_CONFIG_DIR"] = "/tmp"

  # Use GPU if available
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Load models onto the appropriate device
+ extract_model = YOLO("best.pt").to(device)
+ detect_model = YOLO("yolov8n.pt").to(device)
+
+ # Initialize EasyOCR reader (English language, GPU if available)
+ reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
+
+ # Initialize text generation model (distilgpt2 for lightweight performance)
+ generator = pipeline("text-generation", model="distilgpt2", device=0 if device == "cuda" else -1)

  @spaces.GPU
  def process_video(video_path):
+     os.makedirs("/tmp/frames", exist_ok=True)

      # Step 1: Extract board-only frames
      cap = cv2.VideoCapture(video_path)

          labels = [extract_model.names[int(c)] for c in results[0].boxes.cls.cpu().numpy()]
          if "board" in labels and "person" not in labels:
              frames.append(frame)
+             cv2.imwrite(f"/tmp/frames/frame_{idx:04d}.jpg", frame)
              idx += 1
      cap.release()
      if not frames:

      # Step 3: Median-fuse
      stack = np.stack(aligned, axis=0).astype(np.float32)
      median_board = np.median(stack, axis=0).astype(np.uint8)
+     cv2.imwrite("/tmp/clean_board.jpg", median_board)

      # Step 4: Mask persons & selective fuse
      sum_img = np.zeros_like(aligned[0], dtype=np.float32)

      count[count == 0] = 1
      selective = (sum_img / count[:, :, None]).astype(np.uint8)
+     cv2.imwrite("/tmp/fused_board_selective.jpg", selective)

      # Step 5: Sharpen
      blur = cv2.GaussianBlur(selective, (5, 5), 0)
      sharp = cv2.addWeighted(selective, 1.5, blur, -0.5, 0)
+     output_image = "/tmp/sharpened_board_color.jpg"
+     cv2.imwrite(output_image, sharp)
+
+     # Step 6: Detect text using EasyOCR (not displayed)
+     results = reader.readtext(output_image)
+     detected_text = " ".join([result[1] for result in results]).strip()
+     if not detected_text:
+         return output_image, "No text detected on the board."
+
+     # Step 7: Generate explanation using distilgpt2
+     prompt = (
+         f"You are an expert teacher. The following content was detected on a classroom board: '{detected_text}'. "
+         "Provide a detailed explanation of the content, including definitions, examples, or step-by-step solutions if applicable. "
+         "If the content is an equation, solve it or explain its significance. If it's a concept, provide context and examples."
+     )
+     explanation = generator(prompt, max_length=200, num_return_sequences=1, truncation=True)[0]['generated_text']
+
+     return output_image, explanation

+ # Update Gradio interface
  demo = gr.Interface(
+     fn=process_video,
      inputs=[
          gr.File(
              label="Upload Classroom Video (.mp4)",

      ],
      outputs=[
          gr.Image(label="Sharpened Final Board"),
+         gr.Textbox(label="Explanation of Board Content")
      ],
+     title="📹 Classroom Board Cleaner & Content Explainer",
      description=(
          "Upload your classroom video (.mp4). \n"
+         "Automatic board extraction, sharpening, and explanation of detected content."
      )
  )

  if __name__ == "__main__":
      if device == "cuda":
+         print(f"[INFO] ✅ Using GPU: {torch.cuda.get_device_name(0)}")
      else:
+         print("[INFO] ⚠️ Using CPU (GPU not available or not assigned)")
      demo.launch()
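
For reference, a minimal standalone sketch of the detect-then-explain step this commit introduces: read text off a saved board image with EasyOCR, then ask the distilgpt2 text-generation pipeline for an explanation. The image path "board.jpg" and the generation settings below are illustrative placeholders, not values from the commit (the app itself runs on the sharpened /tmp image with max_length=200).

# Minimal sketch, assuming easyocr and transformers are installed; "board.jpg" is a placeholder path.
import easyocr
from transformers import pipeline

reader = easyocr.Reader(['en'], gpu=False)                      # CPU is enough for a single image
generator = pipeline("text-generation", model="distilgpt2", device=-1)

ocr_results = reader.readtext("board.jpg")                      # each result is (bbox, text, confidence)
detected_text = " ".join(text for _, text, _ in ocr_results).strip()

if not detected_text:
    print("No text detected on the board.")
else:
    prompt = f"The following content was detected on a classroom board: '{detected_text}'. Explain it for students."
    # max_new_tokens bounds only the generated continuation, independent of the prompt length
    output = generator(prompt, max_new_tokens=120, num_return_sequences=1, truncation=True)
    print(output[0]["generated_text"])

Using max_new_tokens here (instead of the commit's max_length) is just a convenience for the sketch, so a long OCR transcript cannot eat the whole generation budget.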