"""HTRflow MCP server: Gradio app exposing HTR processing, visualization, and export tools."""

import base64
import io
import json
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Literal, Optional

import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageFont

from htrflow.pipeline.pipeline import Pipeline
from htrflow.volume.volume import Collection

PIPELINE_CONFIGS = {
    "letter_english": {
        "steps": [
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
                    "generation_settings": {"batch_size": 8},
                },
            },
            {
                "step": "TextRecognition",
                "settings": {
                    "model": "TrOCR",
                    "model_settings": {"model": "microsoft/trocr-base-handwritten"},
                    "generation_settings": {"batch_size": 16},
                },
            },
            {"step": "OrderLines"},
        ]
    },
    "letter_swedish": {
        "steps": [
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
                    "generation_settings": {"batch_size": 8},
                },
            },
            {
                "step": "TextRecognition",
                "settings": {
                    "model": "TrOCR",
                    "model_settings": {"model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"},
                    "generation_settings": {"batch_size": 16},
                },
            },
            {"step": "OrderLines"},
        ]
    },
    "spread_english": {
        "steps": [
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-regions-1"},
                    "generation_settings": {"batch_size": 4},
                },
            },
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
                    "generation_settings": {"batch_size": 8},
                },
            },
            {
                "step": "TextRecognition",
                "settings": {
                    "model": "TrOCR",
                    "model_settings": {"model": "microsoft/trocr-base-handwritten"},
                    "generation_settings": {"batch_size": 16},
                },
            },
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    },
    "spread_swedish": {
        "steps": [
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-regions-1"},
                    "generation_settings": {"batch_size": 4},
                },
            },
            {
                "step": "Segmentation",
                "settings": {
                    "model": "yolo",
                    "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"},
                    "generation_settings": {"batch_size": 8},
                },
            },
            {
                "step": "TextRecognition",
                "settings": {
                    "model": "TrOCR",
                    "model_settings": {"model": "Riksarkivet/trocr-base-handwritten-hist-swe-2"},
                    "generation_settings": {"batch_size": 16},
                },
            },
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    },
}
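# Illustrative only: the custom_settings argument accepted by process_htr (defined below)
# is a JSON string with the same "steps" shape as the presets above, e.g.
#   '{"steps": [
#       {"step": "Segmentation", "settings": {"model": "yolo",
#        "model_settings": {"model": "Riksarkivet/yolov9-lines-within-regions-1"}}},
#       {"step": "TextRecognition", "settings": {"model": "TrOCR",
#        "model_settings": {"model": "microsoft/trocr-base-handwritten"}}}]}'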
@spaces.GPU
def process_htr(
    image: Image.Image,
    document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish",
    confidence_threshold: float = 0.8,
    custom_settings: Optional[str] = None,
) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines."""
    try:
        if image is None:
            return {"success": False, "error": "No image provided", "results": None}

        # HTRflow loads images by path, so persist the uploaded PIL image to a temporary file.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            image.save(temp_file.name, "PNG")
            temp_image_path = temp_file.name

        try:
            # A custom pipeline config (JSON string) overrides the preset for the chosen document type.
            if custom_settings:
                try:
                    config = json.loads(custom_settings)
                except json.JSONDecodeError:
                    return {"success": False, "error": "Invalid JSON in custom_settings parameter", "results": None}
            else:
                config = PIPELINE_CONFIGS[document_type]

            collection = Collection([temp_image_path])
            pipeline = Pipeline.from_config(config)
            processed_collection = pipeline.run(collection)

            # Keep a base64 copy of the input image so later tool calls can work from the state alone.
            img_buffer = io.BytesIO()
            image.save(img_buffer, format="PNG")
            image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

            results = extract_text_results(processed_collection, confidence_threshold)

            processing_state = {
                "collection_data": serialize_collection_data(processed_collection),
                "image_base64": image_base64,
                "image_size": image.size,
                "document_type": document_type,
                "confidence_threshold": confidence_threshold,
                "timestamp": datetime.now().isoformat(),
            }

            return {
                "success": True,
                "results": results,
                "processing_state": json.dumps(processing_state),
                "metadata": {
                    "total_lines": len(results.get("text_lines", [])),
                    "average_confidence": results.get("average_confidence", 0),
                    "document_type": document_type,
                    "image_dimensions": image.size,
                },
            }
        finally:
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)
    except Exception as e:
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
def visualize_results(
    processing_state: str,
    visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay",
    show_confidence: bool = True,
    highlight_low_confidence: bool = True,
    image: Optional[Image.Image] = None,
) -> Dict:
    """Generate interactive visualizations of HTR processing results."""
    try:
        state = json.loads(processing_state)
        collection_data = state["collection_data"]

        # Prefer an explicitly supplied image; otherwise fall back to the copy stored in the state.
        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        viz_image = create_visualization(
            original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence
        )

        # Return the rendered visualization as a base64-encoded PNG.
        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
        }
    except Exception as e:
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
def export_results(
    processing_state: str,
    output_formats: List[Literal["txt", "json", "alto", "page"]] = ["txt"],
    confidence_filter: float = 0.0,
) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality."""
    try:
        state = json.loads(processing_state)

        # The processed collection itself is not serialized in the state, so rebuild it by
        # re-running the pipeline on the image stored alongside the results.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            image_data = base64.b64decode(state["image_base64"])
            image = Image.open(io.BytesIO(image_data))
            image.save(temp_file.name, "PNG")
            temp_image_path = temp_file.name

        try:
            collection = Collection([temp_image_path])
            pipeline = Pipeline.from_config(PIPELINE_CONFIGS[state["document_type"]])
            processed_collection = pipeline.run(collection)

            temp_dir = Path(tempfile.mkdtemp())
            exports = {}

            # Let HTRflow serialize the collection once per requested format, then read the
            # generated files back into memory for the JSON response.
            for fmt in output_formats:
                export_dir = temp_dir / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)

                export_files = []
                for root, _, files in os.walk(export_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        with open(file_path, "r", encoding="utf-8") as f:
                            content = f.read()
                        export_files.append({"filename": file, "content": content})

                exports[fmt] = export_files

            shutil.rmtree(temp_dir)

            return {
                "success": True,
                "exports": exports,
                "export_metadata": {
                    # Note: confidence_filter is currently only echoed back here; it is not
                    # applied to the exported content.
                    "formats_generated": output_formats,
                    "confidence_filter": confidence_filter,
                    "timestamp": datetime.now().isoformat(),
                },
            }
        finally:
            if os.path.exists(temp_image_path):
                os.unlink(temp_image_path)

    except Exception as e:
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}
def extract_text_results(collection: Collection, confidence_threshold: float) -> Dict:
    """Collect recognized text lines at or above the confidence threshold."""
    results = {"extracted_text": "", "text_lines": [], "confidence_scores": []}
    for page in collection.pages:
        for node in page.traverse():
            if hasattr(node, "text") and node.text and hasattr(node, "confidence") and node.confidence >= confidence_threshold:
                results["text_lines"].append({
                    "text": node.text,
                    "confidence": node.confidence,
                    "bbox": getattr(node, "bbox", None),
                })
                results["extracted_text"] += node.text + "\n"
                results["confidence_scores"].append(node.confidence)

    results["average_confidence"] = (
        sum(results["confidence_scores"]) / len(results["confidence_scores"])
        if results["confidence_scores"]
        else 0
    )
    return results
def serialize_collection_data(collection: Collection) -> Dict:
    """Flatten the HTRflow collection into a JSON-serializable list of text elements."""
    text_elements = []
    for page in collection.pages:
        for node in page.traverse():
            if hasattr(node, "text") and node.text:
                text_elements.append({
                    "text": node.text,
                    "confidence": getattr(node, "confidence", 1.0),
                    "bbox": getattr(node, "bbox", None),
                })
    return {"text_elements": text_elements}
def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
    """Draw bounding boxes, confidence labels, or a confidence heatmap onto a copy of the image."""
    viz_image = image.copy()
    draw = ImageDraw.Draw(viz_image)

    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except Exception:
        # Fall back to PIL's built-in font when Arial is not available (e.g. on Linux servers).
        font = ImageFont.load_default()

    for element in collection_data.get("text_elements", []):
        if element.get("bbox"):
            bbox = element["bbox"]
            confidence = element.get("confidence", 1.0)

            if visualization_type == "overlay":
                # Green boxes by default; orange for low-confidence lines when highlighting is on.
                color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
                draw.rectangle(bbox, outline=color, width=2)
                if show_confidence:
                    draw.text((bbox[0], bbox[1] - 15), f"{confidence:.2f}", fill=color, font=font)

            elif visualization_type == "confidence_heatmap":
                # Translucent red/yellow/green fill depending on confidence.
                if confidence < 0.5:
                    color = (255, 0, 0, 100)
                elif confidence < 0.8:
                    color = (255, 255, 0, 100)
                else:
                    color = (0, 255, 0, 100)
                overlay = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
                overlay_draw = ImageDraw.Draw(overlay)
                overlay_draw.rectangle(bbox, fill=color)
                viz_image = Image.alpha_composite(viz_image.convert("RGBA"), overlay)

            elif visualization_type == "text_regions":
                # Cycle through a small palette so neighbouring regions get different colors.
                colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
                color = colors[hash(str(bbox)) % len(colors)]
                draw.rectangle(bbox, outline=color, width=3)

    return viz_image.convert("RGB") if visualization_type == "confidence_heatmap" else viz_image
def create_htrflow_mcp_server():
    demo = gr.TabbedInterface(
        [
            gr.Interface(
                fn=process_htr,
                inputs=[
                    gr.Image(type="pil", label="Upload Image"),
                    gr.Dropdown(
                        choices=["letter_english", "letter_swedish", "spread_english", "spread_swedish"],
                        value="letter_english",
                        label="Document Type",
                    ),
                    gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
                    gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
                ],
                outputs=gr.JSON(label="Processing Results"),
                title="HTR Processing Tool",
                description="Process handwritten text using configurable HTRflow pipelines",
                api_name="process_htr",
            ),
            gr.Interface(
                fn=visualize_results,
                inputs=[
                    gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
                    gr.Dropdown(
                        choices=["overlay", "confidence_heatmap", "text_regions"],
                        value="overlay",
                        label="Visualization Type",
                    ),
                    gr.Checkbox(value=True, label="Show Confidence Scores"),
                    gr.Checkbox(value=True, label="Highlight Low Confidence"),
                    gr.Image(type="pil", label="Image (optional)"),
                ],
                outputs=gr.JSON(label="Visualization Results"),
                title="Results Visualization Tool",
                description="Generate interactive visualizations of HTR results",
                api_name="visualize_results",
            ),
            gr.Interface(
                fn=export_results,
                inputs=[
                    gr.Textbox(label="Processing State (JSON)", placeholder="Paste processing results from HTR tool"),
                    gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
                    gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
                ],
                outputs=gr.JSON(label="Export Results"),
                title="Export Tool",
                description="Export HTR results to multiple formats",
                api_name="export_results",
            ),
        ],
        ["HTR Processing", "Results Visualization", "Export Results"],
        title="HTRflow MCP Server",
    )
    return demo
if __name__ == "__main__": |
|
demo = create_htrflow_mcp_server() |
|
demo.launch(mcp_server=True) |