|
import asyncio
import io
import os
import re
import secrets

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pdf2image import convert_from_bytes
from pdf2image.exceptions import PDFPageCountError, PDFSyntaxError
from PIL import Image, UnidentifiedImageError
from transformers import pipeline
|
|
|
|
|
load_dotenv()

# Shared secret for the OCR upload endpoint. When unset or empty, the endpoint
# performs no API-key check.
API_KEY = os.getenv("API_KEY")

# Model id for OCR. Overridable via the OCR_MODEL_ID environment variable; an
# unset or empty value falls back to the Typhoon OCR checkpoint.
MODEL_ID = os.getenv("OCR_MODEL_ID") or "scb10x/typhoon-ocr-7b"

# Built once at import time so every request reuses the same pipeline object.
# Uses MODEL_ID instead of repeating the literal so the two cannot drift apart.
# NOTE(review): confirm this checkpoint works with the generic "image-to-text"
# task and its default generation length; chat-style VLMs may need the
# "image-text-to-text" task and an explicit prompt.
ocr_pipeline = pipeline("image-to-text", model=MODEL_ID)
|
|
|
|
|
# ASGI application that hosts the receipt-OCR REST endpoint.
app = FastAPI()

# Fully permissive CORS: every origin, method and request header is allowed.
# allow_credentials is not passed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
|
def pdf_to_image(file_bytes: bytes) -> Image.Image:
    """Render the first page of a PDF document to a PIL image.

    Args:
        file_bytes: Raw bytes of a PDF file.

    Returns:
        The first page as a PIL image.

    Raises:
        ValueError: If rendering produced no pages. pdf2image may raise its
            own exceptions for unreadable input.
    """
    # Only page 1 is ever used, so ask pdf2image to render just that page
    # instead of rasterising the whole document and discarding the rest.
    pages = convert_from_bytes(file_bytes, first_page=1, last_page=1)
    if not pages:
        raise ValueError("PDF contains no renderable pages")
    return pages[0]
|
|
|
def run_ocr(image: Image.Image) -> str:
    """Return the text the module-level OCR pipeline produces for *image*.

    Only the first prediction in the pipeline's result is used; its
    ``"generated_text"`` entry is returned unchanged.
    """
    predictions = ocr_pipeline(image)
    first_prediction = predictions[0]
    return first_prediction["generated_text"]
|
|
|
def preprocess_text(text: str) -> str: |
|
text = re.sub(r"</?(figure|table|tr|td|th|b|i|u|p|div|span)[^>]*>", "\n", text) |
|
text = re.sub(r"<.*?>", "", text) |
|
text = re.sub(r"\n+", "\n", text) |
|
text = re.sub(r"\s{2,}", " ", text) |
|
return text.strip() |
|
|
|
# Label-then-value patterns per receipt field. Each field maps to one or more
# patterns tried in order; group(1) of the first one that matches anywhere in
# the text is the extracted value. Matching is case-insensitive.
_FIELD_PATTERNS: dict[str, tuple[str, ...]] = {
    "tax_id": (
        r"(?:TAX\s*ID|เลขที่ผู้เสียภาษี)[\s:\-\.]*([0-9]{10,13})",
    ),
    "tax_invoice": (
        r"(?:TAX\s*INV\.?|เลขที่ใบกำกับภาษี|ใบกำกับ)[\s:\-\.]*([0-9A-Z\-\/]{6,20})",
    ),
    # Prefer a date that follows a date label, else take the first date-like
    # token anywhere. The year part accepts 2-4 digits and the digit guards
    # stop e.g. 15/01/2567 from being cut to 15/01/25.
    "tax_date": (
        r"(?:DATE|วันที่|ออกใบกำกับวันที่)[\s:\-\.]*([0-9]{1,4}/[0-9]{1,2}/[0-9]{2,4})(?![0-9])",
        r"(?<![0-9])([0-9]{1,4}/[0-9]{1,2}/[0-9]{2,4})(?![0-9])",
    ),
    "amount": (
        r"(?:จำนวนเงิน(?:\s*บาทต่อลิตร)?|AMOUNT\s*THB|รวมเงิน)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
    ),
    "baht_per_litre": (
        r"(?:บาทต่อลิตร|ราคาต่อลิตร|Baht/Litr|Bath/Ltr)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
    ),
    # "ลิตร" / "Ltr" also occur inside the per-litre price labels
    # ("บาทต่อลิตร", "ราคาต่อลิตร", "Bath/Ltr"). The lookbehinds skip those
    # so the unit price is not reported as the volume.
    "litre": (
        r"(?:(?<!ต่อ)(?<!ต่อ )ลิตร|(?<!/)Ltrs?\.?)[\s:\-\.]*([0-9,]+\.[0-9]{2,3})",
    ),
    "vat": (
        r"(?:VAT|ภาษีมูลค่าเพิ่ม)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
    ),
    "total": (
        r"(?:TOTAL\s*THB|ยอดรวม|รวมทั้งสิ้น|รวมเงินทั้งสิ้น|ยอดเงินสุทธิ)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
    ),
}

# Compiled once at import instead of on every call.
_COMPILED_FIELD_PATTERNS: dict[str, tuple[re.Pattern[str], ...]] = {
    field: tuple(re.compile(p, re.IGNORECASE) for p in patterns)
    for field, patterns in _FIELD_PATTERNS.items()
}


def extract_fields_regex(text: str) -> dict:
    """Extract receipt fields from cleaned OCR text using regexes.

    Args:
        text: Plain OCR text. ``None`` or ``""`` yields all-``None`` fields.

    Returns:
        A dict with keys ``tax_id``, ``tax_invoice``, ``tax_date``,
        ``amount``, ``baht_per_litre``, ``litre``, ``vat`` and ``total``.
        Each value is the matched substring with surrounding whitespace
        stripped (numbers keep their thousands separators), or ``None``
        when nothing matched.
    """
    text = text or ""
    results = {}
    for field, patterns in _COMPILED_FIELD_PATTERNS.items():
        value = None
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                value = match.group(1).strip()
                break
        results[field] = value
    return results
|
|
|
|
|
# Serialises OCR work so concurrent requests do not run several copies of the
# model at once. Waiting requests yield to the event loop instead of blocking it.
_OCR_LOCK = asyncio.Lock()


@app.post("/api/ocr_receipt")
async def ocr_receipt(
    file: UploadFile = File(...),
    x_api_key: str | None = Header(None),
):
    """OCR an uploaded receipt (image or PDF) and extract structured fields.

    Args:
        file: The uploaded receipt. It is treated as a PDF when its filename
            ends in ``.pdf`` (case-insensitive) or its content type is
            ``application/pdf``; otherwise it is decoded as an image and
            converted to RGB.
        x_api_key: API-key request header. Checked only when ``API_KEY`` is
            set.

    Returns:
        A dict with ``raw_ocr`` (model output), ``preprocessed_text`` (output
        of ``preprocess_text``) and ``extracted_fields`` (output of
        ``extract_fields_regex``).

    Raises:
        HTTPException: 401 for a wrong or missing API key; 400 for an empty
            upload or one PIL/pdf2image cannot decode; 500 for any other
            processing failure.
    """
    # Constant-time comparison so the key cannot be probed via response timing.
    if API_KEY and not secrets.compare_digest(
        (x_api_key or "").encode(), API_KEY.encode()
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # UploadFile.filename may be None, so guard before lower-casing.
    is_pdf = (file.filename or "").lower().endswith(".pdf") or (
        file.content_type == "application/pdf"
    )

    def _process() -> dict:
        # Decoding, PDF rendering and model inference are all blocking, so
        # this runs in a worker thread rather than on the event loop.
        if is_pdf:
            image = pdf_to_image(content)
        else:
            image = Image.open(io.BytesIO(content)).convert("RGB")
        text = run_ocr(image)
        text_cleaned = preprocess_text(text)
        return {
            "raw_ocr": text,
            "preprocessed_text": text_cleaned,
            "extracted_fields": extract_fields_regex(text_cleaned),
        }

    try:
        async with _OCR_LOCK:
            return await run_in_threadpool(_process)
    except (UnidentifiedImageError, PDFPageCountError, PDFSyntaxError) as e:
        # The upload itself is not a decodable image/PDF: a client error.
        raise HTTPException(
            status_code=400, detail=f"Could not decode upload: {e}"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
def gradio_interface(image_path: str | Image.Image | None):
    """Gradio click handler: OCR one uploaded receipt.

    Args:
        image_path: Either a local file path or an already-decoded PIL image.
            A path ending in ``.pdf`` (case-insensitive) is rendered with
            ``pdf_to_image``; any other path is opened with PIL. Either kind
            of image is converted to RGB. ``None`` means nothing was uploaded.

    Returns:
        ``(preprocessed_text, extracted_fields)``, matching the Textbox and
        JSON outputs wired up in the UI.

    Raises:
        gr.Error: If no file was provided, so Gradio shows a readable message
            instead of a traceback.
    """
    if image_path is None:
        raise gr.Error("Please upload a receipt image or PDF first.")

    if isinstance(image_path, str):
        if image_path.lower().endswith(".pdf"):
            with open(image_path, "rb") as f:
                image = pdf_to_image(f.read())
        else:
            # The context manager closes the underlying file once converted.
            with Image.open(image_path) as src:
                image = src.convert("RGB")
    else:
        image = image_path.convert("RGB")

    text = run_ocr(image)
    text_cleaned = preprocess_text(text)
    return text_cleaned, extract_fields_regex(text_cleaned)
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("## 🧾 Thai Receipt OCR (Typhoon 7B)")
    with gr.Row():
        # gr.Image is an image-only component, so the advertised PDF upload
        # could never reach gradio_interface's PDF branch. gr.File accepts
        # both kinds and hands the handler a local file path.
        img = gr.File(
            type="filepath",
            file_types=["image", ".pdf"],
            label="📤 Upload receipt (Image or PDF)",
        )
        out_text = gr.Textbox(label="📝 OCR Text", lines=12)
        out_fields = gr.JSON(label="🧠 Extracted Fields")
    gr.Button("🔍 Run OCR").click(
        fn=gradio_interface, inputs=img, outputs=[out_text, out_fields]
    )


# Launch the UI only when executed as a script. Previously demo.launch() ran
# as an import side effect, so importing this module to serve the FastAPI
# `app` (e.g. `uvicorn <module>:app`) started the Gradio server instead.
if __name__ == "__main__":
    demo.launch()
|
|