# Source: kawaiipeace — commit 83783c7 ("update")
import asyncio
import io
import os
import re
import secrets

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, UploadFile, File, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pdf2image import convert_from_bytes
from PIL import Image, UnidentifiedImageError
from transformers import pipeline
# Load environment variables from a local .env file (if present) before any
# configuration is read.
load_dotenv()

# Optional shared secret for the REST endpoint. When unset (or empty), the
# endpoint accepts requests without checking the API key.
API_KEY = os.getenv("API_KEY")

# Hugging Face model id used for OCR. Defaults to Typhoon OCR 7B; can be
# overridden with the OCR_MODEL_ID environment variable.
MODEL_ID = os.getenv("OCR_MODEL_ID", "scb10x/typhoon-ocr-7b")

# Build the pipeline from MODEL_ID rather than repeating the literal, so the
# two can never drift apart.
# NOTE(review): typhoon-ocr-7b looks like a vision-language chat model;
# confirm that "image-to-text" is the right pipeline task for it.
ocr_pipeline = pipeline("image-to-text", model=MODEL_ID)
# FastAPI application.
# CORS is fully open: any origin may call the API with any method and header.
# NOTE(review): narrow allow_origins if the API should not be callable from
# arbitrary websites.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# --- UTILS ---
def pdf_to_image(file_bytes: bytes) -> Image.Image:
    """Render the first page of a PDF document as a PIL image.

    Only page 1 is rasterised; later pages are never rendered, which keeps
    multi-page uploads cheap (only the first page is OCR'd for now).

    Args:
        file_bytes: Raw bytes of the PDF file.

    Returns:
        The first page as a PIL image.

    Raises:
        ValueError: If no page could be rendered from the document.
    """
    # Ask pdf2image for page 1 only instead of rendering every page and
    # throwing all but the first away.
    pages = convert_from_bytes(file_bytes, first_page=1, last_page=1)
    if not pages:
        raise ValueError("PDF contains no renderable pages")
    return pages[0]
def run_ocr(image: Image.Image) -> str:
    """Return the text the OCR pipeline generates for *image*.

    Only the first entry of the pipeline's output list is used.
    """
    outputs = ocr_pipeline(image)
    first = outputs[0]
    return first["generated_text"]
def preprocess_text(text: str) -> str:
    """Normalise raw OCR output into plain text while keeping line structure.

    Opening/closing tags of the listed structural and formatting elements
    become line breaks, and every other tag is dropped. Whitespace is then
    normalised:

    * line endings are unified to ``\\n``;
    * spaces/tabs around a line break are trimmed;
    * blank lines are removed;
    * runs of two or more non-newline whitespace characters become a single
      space.

    Args:
        text: Raw OCR output, possibly containing HTML-like markup.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    # Unify Windows/old-Mac line endings; otherwise "\r\n" would be seen as a
    # two-character whitespace run and squashed into a space below.
    text = re.sub(r"\r\n?", "\n", text)
    text = re.sub(r"</?(figure|table|tr|td|th|b|i|u|p|div|span)[^>]*>", "\n", text)
    text = re.sub(r"<.*?>", "", text)
    # Trim horizontal whitespace hugging a newline so that "A \n B" keeps its
    # line break instead of being merged into one line.
    text = re.sub(r"[^\S\n]*\n[^\S\n]*", "\n", text)
    text = re.sub(r"\n+", "\n", text)
    # Collapse only non-newline whitespace runs; newlines are preserved.
    text = re.sub(r"[^\S\n]{2,}", " ", text)
    return text.strip()
def extract_fields_regex(text: str) -> dict:
patterns = {
"tax_id": r"(?:TAX\s*ID|เลขที่ผู้เสียภาษี)[\s:\-\.]*([0-9]{10,13})",
"tax_invoice": r"(?:TAX\s*INV\.?|เลขที่ใบกำกับภาษี|ใบกำกับ)[\s:\-\.]*([0-9A-Z\-\/]{6,20})",
"tax_date": r"(?:DATE|วันที่|ออกใบกำกับวันที่)?[\s:\-\.]*([0-9]{2,4}/[0-9]{1,2}/[0-9]{1,2})",
"amount": r"(?:จำนวนเงิน(?:\s*บาทต่อลิตร)?|AMOUNT\s*THB|รวมเงิน)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
"baht_per_litre": r"(?:บาทต่อลิตร|ราคาต่อลิตร|Baht/Litr|Bath/Ltr)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
"litre": r"(?:ลิตร|Ltr\.?|Ltrs?\.?)[\s:\-\.]*([0-9,]+\.[0-9]{2,3})",
"vat": r"(?:VAT|ภาษีมูลค่าเพิ่ม)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
"total": r"(?:TOTAL\s*THB|ยอดรวม|รวมทั้งสิ้น|รวมเงินทั้งสิ้น|ยอดเงินสุทธิ)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
}
results = {}
for field, pattern in patterns.items():
match = re.search(pattern, text, re.IGNORECASE)
results[field] = match.group(1).strip() if match else None
return results
# --- API Endpoint ---
def _decode_upload(content: bytes, filename: str, content_type: str | None) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image (first page only for PDFs)."""
    # Accept either a ".pdf" extension or a PDF MIME type: some clients upload
    # PDFs without a meaningful filename.
    if filename.lower().endswith(".pdf") or content_type == "application/pdf":
        return pdf_to_image(content)
    # The context manager releases the decoder's buffer once the RGB copy
    # has been made.
    with Image.open(io.BytesIO(content)) as img:
        return img.convert("RGB")


def _build_ocr_response(image: Image.Image) -> dict:
    """Run OCR, clean the text, extract fields and shape the response body."""
    text = run_ocr(image)
    text_cleaned = preprocess_text(text)
    return {
        "raw_ocr": text,
        "preprocessed_text": text_cleaned,
        "extracted_fields": extract_fields_regex(text_cleaned),
    }


@app.post("/api/ocr_receipt")
async def ocr_receipt(
    file: UploadFile = File(...),
    x_api_key: str | None = Header(None),
):
    """OCR an uploaded receipt (image or PDF) and return the text and fields.

    Args:
        file: The uploaded receipt. PDFs are recognised by a ``.pdf`` filename
            or an ``application/pdf`` content type; only their first page is
            processed.
        x_api_key: Value of the ``X-API-Key`` header. It is checked only when
            the ``API_KEY`` environment variable is set.

    Returns:
        JSON object with ``raw_ocr``, ``preprocessed_text`` and
        ``extracted_fields``.

    Raises:
        HTTPException: 401 for a missing or wrong API key; 400 for an empty
            upload or an image Pillow cannot identify; 500 for any other
            processing failure.
    """
    # Constant-time comparison so response timing does not leak how much of
    # the key matched.
    if API_KEY and (
        x_api_key is None
        or not secrets.compare_digest(x_api_key.encode(), API_KEY.encode())
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")

    content = await file.read()
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        # Decoding and model inference are blocking/CPU-heavy; run them in a
        # worker thread so the event loop keeps serving other requests.
        image = await asyncio.to_thread(
            _decode_upload, content, file.filename or "", file.content_type
        )
        return await asyncio.to_thread(_build_ocr_response, image)
    except UnidentifiedImageError as e:
        raise HTTPException(
            status_code=400, detail="Unsupported or corrupt image file"
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# --- Gradio UI ---
def gradio_interface(image_path: str | os.PathLike | Image.Image | None):
    """Gradio callback: OCR an uploaded receipt and extract its fields.

    Args:
        image_path: Path to the uploaded file (an image or a ``.pdf``), an
            already-decoded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        ``(cleaned_text, extracted_fields)`` for the text box and JSON panel.

    Raises:
        gr.Error: If no file was provided; Gradio shows this to the user.
    """
    if image_path is None:
        raise gr.Error("Please upload a receipt image or PDF first.")

    if isinstance(image_path, (str, os.PathLike)):
        path = os.fspath(image_path)
        if path.lower().endswith(".pdf"):
            with open(path, "rb") as f:
                image = pdf_to_image(f.read())
        else:
            # Close the file handle once the image has been decoded into an
            # in-memory RGB copy.
            with Image.open(path) as img:
                image = img.convert("RGB")
    else:
        image = image_path.convert("RGB")

    text_cleaned = preprocess_text(run_ocr(image))
    return text_cleaned, extract_fields_regex(text_cleaned)
with gr.Blocks() as demo:
    gr.Markdown("## 🧾 Thai Receipt OCR (Typhoon 7B)")
    with gr.Row():
        # gr.Image only decodes raster images, so PDFs would be rejected
        # before reaching the callback. gr.File accepts both and passes the
        # callback a file path.
        img = gr.File(
            type="filepath",
            file_types=["image", ".pdf"],
            label="📤 Upload receipt (Image or PDF)",
        )
        out_text = gr.Textbox(label="📝 OCR Text", lines=12)
        out_fields = gr.JSON(label="🧠 Extracted Fields")
    gr.Button("🔍 Run OCR").click(
        fn=gradio_interface, inputs=img, outputs=[out_text, out_fields]
    )

# Launch only when run as a script. demo.launch() blocks, so calling it at
# import time would hang anything that imports this module (e.g. an ASGI
# server loading `app`).
if __name__ == "__main__":
    demo.launch()