Gabriel committed
Commit d6e55c9 · verified · 1 Parent(s): f094617

Update app.py

Files changed (1)
  1. app.py +112 -102
app.py CHANGED
@@ -117,66 +117,68 @@ PIPELINE_CONFIGS = {
 }
 
 @spaces.GPU
-def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
     """Process handwritten text recognition on uploaded images using HTRflow pipelines."""
-    try:
-        if image is None:
-            return {"success": False, "error": "No image provided", "results": None}
-
-        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
-            image.save(temp_file.name, "PNG")
-            temp_image_path = temp_file.name
-
-        try:
-            if custom_settings:
-                try:
-                    config = json.loads(custom_settings)
-                except json.JSONDecodeError:
-                    return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
-            else:
-                config = PIPELINE_CONFIGS[document_type]
-
-            collection = Collection([temp_image_path])
-            pipeline = Pipeline.from_config(config)
             processed_collection = pipeline.run(collection)
-
-            img_buffer = io.BytesIO()
-            image.save(img_buffer, format="PNG")
-            image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
-
-            results = extract_text_results(processed_collection, confidence_threshold)
-
-            processing_state = {
-                "collection_data": serialize_collection_data(processed_collection),
-                "image_base64": image_base64,
-                "image_size": image.size,
-                "document_type": document_type,
-                "confidence_threshold": confidence_threshold,
-                "timestamp": datetime.now().isoformat(),
-            }
-
-            return {
-                "success": True,
-                "results": results,
-                "processing_state": json.dumps(processing_state),
-                "metadata": {
-                    "total_lines": len(results.get("text_lines", [])),
-                    "average_confidence": results.get("average_confidence", 0),
-                    "document_type": document_type,
-                    "image_dimensions": image.size,
-                },
-            }
-        finally:
-            if os.path.exists(temp_image_path):
-                os.unlink(temp_image_path)
     except Exception as e:
         return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
 
 def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
     """Generate interactive visualizations of HTR processing results."""
     try:
         state = json.loads(processing_state)
-        collection_data = state["collection_data"]
 
         if image is not None:
             original_image = image
@@ -184,22 +186,36 @@ def visualize_results(processing_state: str, visualization_type: Literal["overla
             image_data = base64.b64decode(state["image_base64"])
             original_image = Image.open(io.BytesIO(image_data))
 
-        viz_image = create_visualization(original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence)
 
-        img_buffer = io.BytesIO()
-        viz_image.save(img_buffer, format="PNG")
-        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
 
-        return {
-            "success": True,
-            "visualization": {
-                "image_base64": img_base64,
-                "image_format": "PNG",
-                "visualization_type": visualization_type,
-                "dimensions": viz_image.size,
-            },
-            "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
-        }
     except Exception as e:
         return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
 
@@ -230,9 +246,14 @@ def export_results(processing_state: str, output_formats: List[Literal["txt", "j
         for root, _, files in os.walk(export_dir):
             for file in files:
                 file_path = os.path.join(root, file)
-                with open(file_path, 'r', encoding='utf-8') as f:
-                    content = f.read()
-                export_files.append({"filename": file, "content": content})
 
         exports[fmt] = export_files
 
@@ -270,19 +291,7 @@ def extract_text_results(collection: Collection, confidence_threshold: float) ->
     results["average_confidence"] = sum(results["confidence_scores"]) / len(results["confidence_scores"]) if results["confidence_scores"] else 0
     return results
 
-def serialize_collection_data(collection: Collection) -> Dict:
-    text_elements = []
-    for page in collection.pages:
-        for node in page.traverse():
-            if hasattr(node, "text") and node.text:
-                text_elements.append({
-                    "text": node.text,
-                    "confidence": getattr(node, "confidence", 1.0),
-                    "bbox": getattr(node, "bbox", None),
-                })
-    return {"text_elements": text_elements}
-
-def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
     viz_image = image.copy()
     draw = ImageDraw.Draw(viz_image)
 
@@ -291,33 +300,34 @@ def create_visualization(image, collection_data, visualization_type, show_confid
     except:
         font = ImageFont.load_default()
 
-    for element in collection_data.get("text_elements", []):
-        if element.get("bbox"):
-            bbox = element["bbox"]
-            confidence = element.get("confidence", 1.0)
-
-            if visualization_type == "overlay":
-                color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
-                draw.rectangle(bbox, outline=color, width=2)
-                if show_confidence:
-                    draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)
-
-            elif visualization_type == "confidence_heatmap":
-                if confidence < 0.5:
-                    color = (255, 0, 0, 100)
-                elif confidence < 0.8:
-                    color = (255, 255, 0, 100)
-                else:
-                    color = (0, 255, 0, 100)
-                overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
-                overlay_draw = ImageDraw.Draw(overlay)
-                overlay_draw.rectangle(bbox, fill=color)
-                viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
-
-            elif visualization_type == "text_regions":
-                colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
-                color = colors[hash(str(bbox)) % len(colors)]
-                draw.rectangle(bbox, outline=color, width=3)
 
     return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
 
 }
 
 @spaces.GPU
+def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "letter_english", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
     """Process handwritten text recognition on uploaded images using HTRflow pipelines."""
+    if image is None:
+        return {"success": False, "error": "No image provided", "results": None}
+
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+        image.save(temp_file.name, "PNG")
+        temp_image_path = temp_file.name
+
+    try:
+        if custom_settings:
+            try:
+                config = json.loads(custom_settings)
+            except json.JSONDecodeError:
+                return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
+        else:
+            config = PIPELINE_CONFIGS[document_type]
+
+        collection = Collection([temp_image_path])
+        pipeline = Pipeline.from_config(config)
+
+        try:
             processed_collection = pipeline.run(collection)
+        except Exception as pipeline_error:
+            return {"success": False, "error": f"Pipeline execution failed: {str(pipeline_error)}", "results": None}
+
+        img_buffer = io.BytesIO()
+        image.save(img_buffer, format="PNG")
+        image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
+
+        results = extract_text_results(processed_collection, confidence_threshold)
+
+        processing_state = {
+            "processed_collection": processed_collection,
+            "image_base64": image_base64,
+            "image_size": image.size,
+            "document_type": document_type,
+            "confidence_threshold": confidence_threshold,
+            "timestamp": datetime.now().isoformat(),
+        }
+
+        return {
+            "success": True,
+            "results": results,
+            "processing_state": json.dumps(processing_state, default=str),
+            "metadata": {
+                "total_lines": len(results.get("text_lines", [])),
+                "average_confidence": results.get("average_confidence", 0),
+                "document_type": document_type,
+                "image_dimensions": image.size,
+            },
+        }
     except Exception as e:
         return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
+    finally:
+        if os.path.exists(temp_image_path):
+            os.unlink(temp_image_path)
 
 def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
     """Generate interactive visualizations of HTR processing results."""
     try:
         state = json.loads(processing_state)
 
         if image is not None:
             original_image = image
 
             image_data = base64.b64decode(state["image_base64"])
             original_image = Image.open(io.BytesIO(image_data))
 
+        # Recreate the collection from the stored image
+        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+            original_image.save(temp_file.name, "PNG")
+            temp_image_path = temp_file.name
+
+        try:
+            collection = Collection([temp_image_path])
+            pipeline = Pipeline.from_config(PIPELINE_CONFIGS[state["document_type"]])
+            processed_collection = pipeline.run(collection)
+
+            viz_image = create_visualization(original_image, processed_collection, visualization_type, show_confidence, highlight_low_confidence)
+
+            img_buffer = io.BytesIO()
+            viz_image.save(img_buffer, format="PNG")
+            img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
+
+            return {
+                "success": True,
+                "visualization": {
+                    "image_base64": img_base64,
+                    "image_format": "PNG",
+                    "visualization_type": visualization_type,
+                    "dimensions": viz_image.size,
+                },
+                "metadata": {"visualization_type": visualization_type},
+            }
+        finally:
+            if os.path.exists(temp_image_path):
+                os.unlink(temp_image_path)
+
     except Exception as e:
         return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
 
 
         for root, _, files in os.walk(export_dir):
             for file in files:
                 file_path = os.path.join(root, file)
+                try:
+                    with open(file_path, 'r', encoding='utf-8') as f:
+                        content = f.read()
+                    export_files.append({"filename": file, "content": content})
+                except UnicodeDecodeError:
+                    with open(file_path, 'rb') as f:
+                        content = base64.b64encode(f.read()).decode('utf-8')
+                    export_files.append({"filename": file, "content": content, "encoding": "base64"})
 
         exports[fmt] = export_files
 
 
     results["average_confidence"] = sum(results["confidence_scores"]) / len(results["confidence_scores"]) if results["confidence_scores"] else 0
     return results
 
+def create_visualization(image, collection, visualization_type, show_confidence, highlight_low_confidence):
     viz_image = image.copy()
     draw = ImageDraw.Draw(viz_image)
 
     except:
         font = ImageFont.load_default()
 
+    for page in collection.pages:
+        for node in page.traverse():
+            if hasattr(node, "bbox") and hasattr(node, "text") and node.bbox and node.text:
+                bbox = node.bbox
+                confidence = getattr(node, "confidence", 1.0)
+
+                if visualization_type == "overlay":
+                    color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
+                    draw.rectangle(bbox, outline=color, width=2)
+                    if show_confidence:
+                        draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)
+
+                elif visualization_type == "confidence_heatmap":
+                    if confidence < 0.5:
+                        color = (255, 0, 0, 100)
+                    elif confidence < 0.8:
+                        color = (255, 255, 0, 100)
+                    else:
+                        color = (0, 255, 0, 100)
+                    overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
+                    overlay_draw = ImageDraw.Draw(overlay)
+                    overlay_draw.rectangle(bbox, fill=color)
+                    viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)
+
+                elif visualization_type == "text_regions":
+                    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
+                    color = colors[hash(str(bbox)) % len(colors)]
+                    draw.rectangle(bbox, outline=color, width=3)
 
     return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
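
For context, a minimal sketch of how the updated functions might be called once this commit is applied. The `app` import path and the sample file name are assumptions, not part of the commit, and running the pipelines still requires the Space's HTRflow models and GPU runtime.

# Usage sketch (assumptions: app.py is importable as `app` and sample.png exists locally).
from PIL import Image
import app

image = Image.open("sample.png")

# The commit changes the default document_type to "letter_english" and surfaces
# pipeline failures as {"success": False, "error": "Pipeline execution failed: ..."}.
htr = app.process_htr(image, document_type="letter_swedish", confidence_threshold=0.8)

if htr["success"]:
    # visualize_results now re-runs the pipeline from the stored image instead of
    # reading the removed "collection_data"/serialize_collection_data state.
    viz = app.visualize_results(htr["processing_state"], visualization_type="overlay", image=image)
    print(htr["metadata"]["total_lines"], viz["success"])
else:
    print(htr["error"])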