# htrflow_mcp / app.py
# Hugging Face Space source; last commit f094617 "Update app.py" (verified) by Gabriel, 15.8 kB.
import gradio as gr
import json
import base64
import tempfile
import os
from typing import Dict, List, Optional, Literal
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
import io
import spaces
import shutil
from pathlib import Path
from htrflow.volume.volume import Collection
from htrflow.pipeline.pipeline import Pipeline
# Hugging Face model ids shared by the preset pipelines below.
_YOLO_LINES_MODEL = "Riksarkivet/yolov9-lines-within-regions-1"
_YOLO_REGIONS_MODEL = "Riksarkivet/yolov9-regions-1"
_TROCR_ENGLISH_MODEL = "microsoft/trocr-base-handwritten"
_TROCR_SWEDISH_MODEL = "Riksarkivet/trocr-base-handwritten-hist-swe-2"


def _segmentation_step(model: str, batch_size: int) -> Dict:
    """Return a fresh YOLO "Segmentation" step config for ``model``."""
    return {
        "step": "Segmentation",
        "settings": {
            "model": "yolo",
            "model_settings": {"model": model},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _recognition_step(model: str, batch_size: int = 16) -> Dict:
    """Return a fresh TrOCR "TextRecognition" step config for ``model``."""
    return {
        "step": "TextRecognition",
        "settings": {
            "model": "TrOCR",
            "model_settings": {"model": model},
            "generation_settings": {"batch_size": batch_size},
        },
    }


def _letter_pipeline(trocr_model: str) -> Dict:
    """Single-page letter: segment lines, recognise them, then order lines."""
    return {
        "steps": [
            _segmentation_step(_YOLO_LINES_MODEL, 8),
            _recognition_step(trocr_model),
            {"step": "OrderLines"},
        ]
    }


def _spread_pipeline(trocr_model: str) -> Dict:
    """Two-page spread: segment regions, then lines within them, recognise, then apply marginalia-aware reading order."""
    return {
        "steps": [
            _segmentation_step(_YOLO_REGIONS_MODEL, 4),
            _segmentation_step(_YOLO_LINES_MODEL, 8),
            _recognition_step(trocr_model),
            {"step": "ReadingOrderMarginalia", "settings": {"two_page": True}},
        ]
    }


# Preset HTRflow pipeline configs, keyed by the document_type values the tools accept.
# Every entry is built from fresh dicts, so no nested config object is shared between presets.
PIPELINE_CONFIGS = {
    "letter_english": _letter_pipeline(_TROCR_ENGLISH_MODEL),
    "letter_swedish": _letter_pipeline(_TROCR_SWEDISH_MODEL),
    "spread_english": _spread_pipeline(_TROCR_ENGLISH_MODEL),
    "spread_swedish": _spread_pipeline(_TROCR_SWEDISH_MODEL),
}
def _resolve_pipeline_config(document_type: str, custom_settings: Optional[str]) -> Dict:
    """Return the HTRflow pipeline config to run, raising ValueError with a user-facing message.

    A non-blank ``custom_settings`` JSON object takes precedence over the ``document_type`` preset.
    """
    if custom_settings and custom_settings.strip():
        try:
            config = json.loads(custom_settings)
        except json.JSONDecodeError:
            raise ValueError("Invalid JSON in custom_settings parameter") from None
        if not isinstance(config, dict):
            raise ValueError("custom_settings must be a JSON object describing an HTRflow pipeline")
        return config
    try:
        return PIPELINE_CONFIGS[document_type]
    except KeyError:
        raise ValueError(
            f"Unknown document_type {document_type!r}; expected one of {list(PIPELINE_CONFIGS)}"
        ) from None


@spaces.GPU
def process_htr(image: Image.Image, document_type: Literal["letter_english", "letter_swedish", "spread_english", "spread_swedish"] = "spread_swedish", confidence_threshold: float = 0.8, custom_settings: Optional[str] = None) -> Dict:
    """Process handwritten text recognition on uploaded images using HTRflow pipelines.

    Args:
        image: The page image to transcribe.
        document_type: Preset pipeline to run (a key of PIPELINE_CONFIGS). Ignored when
            custom_settings is provided.
        confidence_threshold: Text lines with a confidence below this value are left out of
            the returned results.
        custom_settings: Optional JSON object with a complete HTRflow pipeline config that
            replaces the preset.

    Returns:
        A dict with "success". On success it also has "results" (extracted text and lines),
        "processing_state" (a JSON string to pass to visualize_results / export_results) and
        "metadata". On failure it has "error" and "results": None.
    """
    temp_image_path = None
    try:
        if image is None:
            return {"success": False, "error": "No image provided", "results": None}
        try:
            config = _resolve_pipeline_config(document_type, custom_settings)
        except ValueError as err:
            return {"success": False, "error": str(err), "results": None}

        # Encode once: the same PNG bytes feed HTRflow (via a temp file) and the returned state.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            # Record the path before writing so `finally` removes the file even if the write fails.
            temp_image_path = temp_file.name
            temp_file.write(png_bytes)

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        results = extract_text_results(processed_collection, confidence_threshold)
        processing_state = {
            "collection_data": serialize_collection_data(processed_collection),
            "image_base64": base64.b64encode(png_bytes).decode("utf-8"),
            "image_size": image.size,
            "document_type": document_type,
            "confidence_threshold": confidence_threshold,
            # Lets export_results re-run exactly this pipeline, including custom_settings.
            "pipeline_config": config,
            "timestamp": datetime.now().isoformat(),
        }
        return {
            "success": True,
            "results": results,
            "processing_state": json.dumps(processing_state),
            "metadata": {
                "total_lines": len(results.get("text_lines", [])),
                "average_confidence": results.get("average_confidence", 0),
                "document_type": document_type,
                "image_dimensions": image.size,
            },
        }
    except Exception as e:
        # Top-level tool boundary: report every failure to the client instead of raising.
        return {"success": False, "error": f"HTR processing failed: {str(e)}", "results": None}
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def visualize_results(processing_state: str, visualization_type: Literal["overlay", "confidence_heatmap", "text_regions"] = "overlay", show_confidence: bool = True, highlight_low_confidence: bool = True, image: Optional[Image.Image] = None) -> Dict:
    """Generate interactive visualizations of HTR processing results.

    Args:
        processing_state: The "processing_state" JSON string returned by process_htr. The
            complete process_htr response JSON is accepted too.
        visualization_type: "overlay", "confidence_heatmap" or "text_regions".
        show_confidence: In overlay mode, print each line's confidence above its box.
        highlight_low_confidence: In overlay mode, draw low-confidence boxes in orange.
        image: Optional image to draw on; defaults to the image stored in processing_state.

    Returns:
        A dict with "success". On success it also has "visualization" (base64 PNG plus format,
        type and dimensions) and "metadata". On failure it has "error" and "visualization": None.
    """
    try:
        state = json.loads(processing_state)
        # The full process_htr response nests the state as a JSON string; unwrap it so
        # pasting either form works, as the UI placeholder invites.
        if "collection_data" not in state and isinstance(state.get("processing_state"), str):
            state = json.loads(state["processing_state"])
        collection_data = state["collection_data"]

        if image is not None:
            original_image = image
        else:
            image_data = base64.b64decode(state["image_base64"])
            original_image = Image.open(io.BytesIO(image_data))

        viz_image = create_visualization(original_image, collection_data, visualization_type, show_confidence, highlight_low_confidence)
        img_buffer = io.BytesIO()
        viz_image.save(img_buffer, format="PNG")
        img_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
        return {
            "success": True,
            "visualization": {
                "image_base64": img_base64,
                "image_format": "PNG",
                "visualization_type": visualization_type,
                "dimensions": viz_image.size,
            },
            "metadata": {"total_elements": len(collection_data.get("text_elements", []))},
        }
    except Exception as e:
        # Top-level tool boundary: report every failure to the client instead of raising.
        return {"success": False, "error": f"Visualization generation failed: {str(e)}", "visualization": None}
@spaces.GPU
def export_results(processing_state: str, output_formats: Optional[List[Literal["txt", "json", "alto", "page"]]] = None, confidence_filter: float = 0.0) -> Dict:
    """Export HTR results to multiple formats using HTRflow's native export functionality.

    The pipeline is re-run on the image stored in processing_state. The resulting collection
    is then written once per requested format with HTRflow's serializers.

    Args:
        processing_state: The "processing_state" JSON string returned by process_htr. The
            complete process_htr response JSON is accepted too.
        output_formats: Serializer names to export ("txt", "json", "alto", "page").
            Defaults to ["txt"] when omitted.
        confidence_filter: Echoed back in the export metadata.
            NOTE(review): not applied to the exported content yet.

    Returns:
        A dict with "success". On success it also has "exports" (format -> list of
        {"filename", "content"}) and "export_metadata". On failure it has "error" and
        "exports": None.
    """
    # A None default avoids one shared mutable list across calls. The copy (de-duplicated,
    # order kept) also stops the returned metadata from aliasing the caller's list.
    formats = ["txt"] if output_formats is None else list(dict.fromkeys(output_formats))
    temp_image_path = None
    try:
        state = json.loads(processing_state)
        # The full process_htr response nests the state as a JSON string; unwrap it.
        if "image_base64" not in state and isinstance(state.get("processing_state"), str):
            state = json.loads(state["processing_state"])
        # States written by the current process_htr record the exact pipeline (custom_settings
        # included). Older states only carry document_type, so fall back to its preset.
        config = state.get("pipeline_config") or PIPELINE_CONFIGS[state["document_type"]]

        image = Image.open(io.BytesIO(base64.b64decode(state["image_base64"])))
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            # Record the path before saving so `finally` removes the file even if the save fails.
            temp_image_path = temp_file.name
            image.save(temp_file, format="PNG")

        collection = Collection([temp_image_path])
        pipeline = Pipeline.from_config(config)
        processed_collection = pipeline.run(collection)

        exports = {}
        # TemporaryDirectory removes the serializer output even if one format fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            for fmt in formats:
                export_dir = Path(temp_dir) / fmt
                processed_collection.save(directory=str(export_dir), serializer=fmt)
                export_files = []
                for root, _, files in os.walk(export_dir):
                    for file_name in sorted(files):
                        with open(os.path.join(root, file_name), "r", encoding="utf-8") as f:
                            export_files.append({"filename": file_name, "content": f.read()})
                exports[fmt] = export_files

        return {
            "success": True,
            "exports": exports,
            "export_metadata": {
                "formats_generated": formats,
                "confidence_filter": confidence_filter,
                "timestamp": datetime.now().isoformat(),
            },
        }
    except Exception as e:
        # Top-level tool boundary: report every failure to the client instead of raising.
        return {"success": False, "error": f"Export generation failed: {str(e)}", "exports": None}
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)
def extract_text_results(collection: Collection, confidence_threshold: float) -> Dict:
    """Collect recognised text lines whose confidence reaches ``confidence_threshold``.

    Walks every node of every page in ``collection`` via ``page.traverse()``. A node is kept
    when it has a non-empty ``text`` and a non-None ``confidence`` >= ``confidence_threshold``.

    Returns:
        A dict with these keys:
        - "extracted_text": kept texts in traversal order, each followed by "\\n".
        - "text_lines": list of {"text", "confidence", "bbox"} for the kept nodes.
        - "confidence_scores": confidences of the kept nodes.
        - "average_confidence": mean of confidence_scores, or 0 when nothing was kept.
    """
    text_lines = []
    for page in collection.pages:
        for node in page.traverse():
            text = getattr(node, "text", None)
            confidence = getattr(node, "confidence", None)
            # A None confidence is skipped; comparing it with `>=` would raise TypeError
            # and abort the whole request.
            if text and confidence is not None and confidence >= confidence_threshold:
                text_lines.append({
                    "text": text,
                    "confidence": confidence,
                    "bbox": getattr(node, "bbox", None),
                })
    scores = [line["confidence"] for line in text_lines]
    return {
        "extracted_text": "".join(line["text"] + "\n" for line in text_lines),
        "text_lines": text_lines,
        "confidence_scores": scores,
        "average_confidence": sum(scores) / len(scores) if scores else 0,
    }
def serialize_collection_data(collection: Collection) -> Dict:
    """Flatten every text-bearing node of ``collection`` into JSON-friendly dicts.

    Nodes without a truthy ``text`` are skipped. Confidence is not filtered here. A node
    without a ``confidence`` attribute is recorded with 1.0, and one without ``bbox`` with None.

    Returns:
        {"text_elements": [{"text", "confidence", "bbox"}, ...]} in page/traversal order.
    """
    def _text_nodes():
        # Yield the nodes that actually carry text, page by page.
        for page in collection.pages:
            for node in page.traverse():
                if getattr(node, "text", None):
                    yield node

    elements = [
        {
            "text": node.text,
            "confidence": getattr(node, "confidence", 1.0),
            "bbox": getattr(node, "bbox", None),
        }
        for node in _text_nodes()
    ]
    return {"text_elements": elements}
def create_visualization(image, collection_data, visualization_type, show_confidence, highlight_low_confidence):
    """Draw HTR text elements onto a copy of ``image`` (the input image is not modified).

    Args:
        image: PIL image the results were computed on.
        collection_data: dict with a "text_elements" list, as built by serialize_collection_data.
            Each element's "bbox" is passed straight to ImageDraw.rectangle. A missing or null
            "confidence" is treated as 1.0.
        visualization_type: One of three modes; any other value draws nothing.
            - "overlay": box outlines, orange when highlight_low_confidence is set and
              confidence < 0.7, green otherwise.
            - "confidence_heatmap": translucent fills, red < 0.5, yellow < 0.8, green otherwise.
            - "text_regions": outlines cycling through four colours.
        show_confidence: In "overlay" mode, print the confidence above each box.
        highlight_low_confidence: In "overlay" mode, use orange for low-confidence boxes.

    Returns:
        A new PIL image. "confidence_heatmap" output is always RGB.
    """
    # RGB colour tuples cannot be drawn on single-band or palette images (e.g. grayscale "L"
    # scans), so convert those to RGB; RGB/RGBA inputs are copied unchanged.
    viz_image = image.copy() if image.mode in ("RGB", "RGBA") else image.convert("RGB")
    draw = ImageDraw.Draw(viz_image)
    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except (OSError, ImportError):
        # arial.ttf is often missing on Linux hosts, and truetype() needs FreeType support.
        font = ImageFont.load_default()

    region_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
    # The heatmap uses one shared translucent layer, composited once at the end, instead of
    # allocating and compositing a full-size layer per element. Overlapping boxes therefore
    # show the later element's colour rather than stacking alpha.
    heatmap = None
    heatmap_draw = None
    if visualization_type == "confidence_heatmap":
        heatmap = Image.new("RGBA", viz_image.size, (0, 0, 0, 0))
        heatmap_draw = ImageDraw.Draw(heatmap)

    for index, element in enumerate(collection_data.get("text_elements", [])):
        bbox = element.get("bbox")
        if not bbox:
            continue
        confidence = element.get("confidence")
        if confidence is None:
            confidence = 1.0
        if visualization_type == "overlay":
            color = (255, 165, 0) if highlight_low_confidence and confidence < 0.7 else (0, 255, 0)
            draw.rectangle(bbox, outline=color, width=2)
            if show_confidence:
                # Keep the label on-canvas for boxes touching the top edge.
                draw.text((bbox[0], max(0, bbox[1] - 15)), f"{confidence:.2f}", fill=color, font=font)
        elif visualization_type == "confidence_heatmap":
            if confidence < 0.5:
                fill = (255, 0, 0, 100)
            elif confidence < 0.8:
                fill = (255, 255, 0, 100)
            else:
                fill = (0, 255, 0, 100)
            heatmap_draw.rectangle(bbox, fill=fill)
        elif visualization_type == "text_regions":
            # Cycle colours by element position. The old hash(str(bbox)) is salted per process,
            # so colours changed between runs.
            draw.rectangle(bbox, outline=region_colors[index % len(region_colors)], width=3)

    if heatmap is not None:
        return Image.alpha_composite(viz_image.convert("RGBA"), heatmap).convert("RGB")
    return viz_image
def create_htrflow_mcp_server():
    """Assemble the three-tab Gradio app for the HTRflow tools.

    Tabs, each wrapping one tool function (the Gradio API name matches the function name):
    - "HTR Processing": process_htr
    - "Results Visualization": visualize_results
    - "Export Results": export_results

    Returns:
        The gr.TabbedInterface titled "HTRflow MCP Server".
    """
    # Same keys, in the same order, as the Literal on process_htr's document_type.
    document_types = list(PIPELINE_CONFIGS)
    state_placeholder = "Paste processing results from HTR tool"

    htr_tab = gr.Interface(
        fn=process_htr,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Dropdown(choices=document_types, value="letter_english", label="Document Type"),
            gr.Slider(0.0, 1.0, value=0.8, label="Confidence Threshold"),
            gr.Textbox(label="Custom Settings (JSON)", placeholder="Optional custom pipeline settings"),
        ],
        outputs=gr.JSON(label="Processing Results"),
        title="HTR Processing Tool",
        description="Process handwritten text using configurable HTRflow pipelines",
        api_name="process_htr",
    )

    visualization_tab = gr.Interface(
        fn=visualize_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder=state_placeholder),
            gr.Dropdown(choices=["overlay", "confidence_heatmap", "text_regions"], value="overlay", label="Visualization Type"),
            gr.Checkbox(value=True, label="Show Confidence Scores"),
            gr.Checkbox(value=True, label="Highlight Low Confidence"),
            gr.Image(type="pil", label="Image (optional)"),
        ],
        outputs=gr.JSON(label="Visualization Results"),
        title="Results Visualization Tool",
        description="Generate interactive visualizations of HTR results",
        api_name="visualize_results",
    )

    export_tab = gr.Interface(
        fn=export_results,
        inputs=[
            gr.Textbox(label="Processing State (JSON)", placeholder=state_placeholder),
            gr.CheckboxGroup(choices=["txt", "json", "alto", "page"], value=["txt"], label="Output Formats"),
            gr.Slider(0.0, 1.0, value=0.0, label="Confidence Filter"),
        ],
        outputs=gr.JSON(label="Export Results"),
        title="Export Tool",
        description="Export HTR results to multiple formats",
        api_name="export_results",
    )

    return gr.TabbedInterface(
        [htr_tab, visualization_tab, export_tab],
        ["HTR Processing", "Results Visualization", "Export Results"],
        title="HTRflow MCP Server",
    )
if __name__ == "__main__":
    # Script entry point: build the tabbed Gradio app and launch it with mcp_server=True,
    # which asks Gradio to serve the app's API endpoints as an MCP server as well.
    # NOTE(review): keeping the module-level name `demo` is presumably intentional, since
    # Gradio tooling (e.g. reload mode) looks for it by default; confirm before renaming.
    demo = create_htrflow_mcp_server()
    demo.launch(mcp_server=True)