"""Receipt OCR service: configuration, model loading and text utilities."""

import os
import re

import gradio as gr
from dotenv import load_dotenv
from fastapi import FastAPI, File, Header, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pdf2image import convert_from_bytes
from PIL import Image
from transformers import pipeline

# Load .env
load_dotenv()
API_KEY = os.getenv("API_KEY")
MODEL_ID = "scb10x/typhoon-ocr-7b"

# The model is loaded once, at import time, and shared by every request.
ocr_pipeline = pipeline("image-to-text", model=MODEL_ID)

# FastAPI app init
app = FastAPI()
# NOTE(review): CORS is wide open (any origin, method and header). Confirm
# this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- UTILS ---

# Line-break tags become newlines; every other tag is dropped.
_BR_TAG_RE = re.compile(r"<br\b[^>]*>", re.IGNORECASE)
_ANY_TAG_RE = re.compile(r"<.*?>")
_NEWLINES_RE = re.compile(r"\n+")
_WHITESPACE_RUN_RE = re.compile(r"\s{2,}")

# One compiled, case-insensitive pattern per extracted field. Group 1 of each
# pattern is the captured value.
_FIELD_PATTERNS = {
    field: re.compile(pattern, re.IGNORECASE)
    for field, pattern in {
        "tax_id": r"(?:TAX\s*ID|เลขที่ผู้เสียภาษี)[\s:\-\.]*([0-9]{10,13})",
        "tax_invoice": r"(?:TAX\s*INV\.?|เลขที่ใบกำกับภาษี|ใบกำกับ)[\s:\-\.]*([0-9A-Z\-\/]{6,20})",
        "tax_date": r"(?:DATE|วันที่|ออกใบกำกับวันที่)?[\s:\-\.]*([0-9]{2,4}/[0-9]{1,2}/[0-9]{1,2})",
        "amount": r"(?:จำนวนเงิน(?:\s*บาทต่อลิตร)?|AMOUNT\s*THB|รวมเงิน)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
        "baht_per_litre": r"(?:บาทต่อลิตร|ราคาต่อลิตร|Baht/Litr|Bath/Ltr)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
        "litre": r"(?:ลิตร|Ltr\.?|Ltrs?\.?)[\s:\-\.]*([0-9,]+\.[0-9]{2,3})",
        "vat": r"(?:VAT|ภาษีมูลค่าเพิ่ม)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
        "total": r"(?:TOTAL\s*THB|ยอดรวม|รวมทั้งสิ้น|รวมเงินทั้งสิ้น|ยอดเงินสุทธิ)[\s:\-\.]*([0-9,]+\.[0-9]{2})",
    }.items()
}


def pdf_to_image(file_bytes: bytes) -> Image.Image:
    """Render the first page of a PDF as an image.

    Args:
        file_bytes: Raw contents of the PDF file.

    Returns:
        The first page of the document.

    Raises:
        ValueError: If the conversion produces no pages.
    """
    # Only the first page is used, so do not rasterise the others.
    pages = convert_from_bytes(file_bytes, first_page=1, last_page=1)
    if not pages:
        raise ValueError("PDF contains no renderable pages")
    return pages[0]


def run_ocr(image: Image.Image) -> str:
    """Run the shared OCR pipeline on *image* and return its generated text."""
    result = ocr_pipeline(image)
    return result[0]["generated_text"]


def preprocess_text(text: str) -> str:
    """Strip markup from OCR output and normalise its whitespace.

    ``<br>`` tags (any casing, with or without attributes or a trailing
    slash) become newlines and all remaining tags are removed. Runs of
    newlines are then collapsed to one newline, and any run of two or more
    whitespace characters is collapsed to a single space.

    Args:
        text: Raw text produced by the OCR model.

    Returns:
        The cleaned text with leading and trailing whitespace removed.
    """
    # The previous first pattern, "]*>", replaced every ">" with a newline,
    # which left nothing for the tag-stripping pattern below to match.
    text = _BR_TAG_RE.sub("\n", text)
    text = _ANY_TAG_RE.sub("", text)
    text = _NEWLINES_RE.sub("\n", text)
    text = _WHITESPACE_RUN_RE.sub(" ", text)
    return text.strip()


def extract_fields_regex(text: str) -> dict:
    """Extract receipt fields from cleaned OCR text.

    Each field is searched independently and case-insensitively; the first
    match in *text* wins.

    Args:
        text: Text to search, normally the output of ``preprocess_text``.

    Returns:
        A dict with the keys ``tax_id``, ``tax_invoice``, ``tax_date``,
        ``amount``, ``baht_per_litre``, ``litre``, ``vat`` and ``total``.
        Each value is the captured string, or ``None`` when the field's
        pattern does not match.
    """
    results = {}
    for field, pattern in _FIELD_PATTERNS.items():
        match = pattern.search(text)
        results[field] = match.group(1).strip() if match else None
    return results


# --- API Endpoint ---
@app.post("/api/ocr_receipt")
async def ocr_receipt(
    file: UploadFile = File(...),
    x_api_key: str | None = Header(None),
):
    """OCR an uploaded receipt and return the extracted fields.

    Uploads whose filename ends in ``.pdf`` are rendered through
    ``pdf_to_image``; anything else is decoded as an image and converted to
    RGB.

    Args:
        file: The uploaded receipt (image or PDF).
        x_api_key: Value of the ``X-API-Key`` request header. It is only
            checked when ``API_KEY`` is configured.

    Returns:
        A dict with ``raw_ocr`` (model output), ``preprocessed_text``
        (cleaned output) and ``extracted_fields`` (field name to value).

    Raises:
        HTTPException: 401 when an API key is configured and the header does
            not match; 500 when decoding, OCR or extraction fails.
    """
    # Function-scope imports keep this block self-contained.
    import io
    import secrets

    if API_KEY:
        # Constant-time comparison, done on bytes so that non-ASCII header
        # values are rejected instead of raising TypeError.
        supplied_key = (x_api_key or "").encode("utf-8")
        if not secrets.compare_digest(supplied_key, API_KEY.encode("utf-8")):
            raise HTTPException(status_code=401, detail="Invalid API key")

    content = await file.read()
    try:
        # The filename may be missing from the upload.
        if (file.filename or "").lower().endswith(".pdf"):
            image = pdf_to_image(content)
        else:
            # Decode from the bytes already read: file.file is at EOF after
            # the read above, so opening it directly always fails.
            with Image.open(io.BytesIO(content)) as uploaded:
                image = uploaded.convert("RGB")
        text = run_ocr(image)
        text_cleaned = preprocess_text(text)
        extracted = extract_fields_regex(text_cleaned)
        return {
            "raw_ocr": text,
            "preprocessed_text": text_cleaned,
            "extracted_fields": extracted,
        }
    except Exception as exc:
        # Top-level boundary: report any failure as a 500, keeping the cause.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


# --- Gradio UI ---
def gradio_interface(image_path: str | Image.Image | None):
    """Run OCR for the Gradio UI.

    Args:
        image_path: A path to an image or PDF file, an already loaded image,
            or ``None`` when nothing has been uploaded.

    Returns:
        A ``(cleaned_text, extracted_fields)`` tuple.

    Raises:
        gr.Error: If no file was provided.
    """
    if image_path is None:
        # Without this guard an empty upload fails with AttributeError.
        raise gr.Error("Please upload a receipt image or PDF first.")

    if isinstance(image_path, str):
        if image_path.lower().endswith(".pdf"):
            with open(image_path, "rb") as f:
                image = pdf_to_image(f.read())
        else:
            # Close the underlying file once the RGB copy exists.
            with Image.open(image_path) as opened:
                image = opened.convert("RGB")
    else:
        image = image_path.convert("RGB")

    text = run_ocr(image)
    text_cleaned = preprocess_text(text)
    extracted = extract_fields_regex(text_cleaned)
    return text_cleaned, extracted


# NOTE(review): the original indentation was lost, so the nesting of the
# components inside gr.Row() below is a reconstruction; confirm the layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🧾 Thai Receipt OCR (Typhoon 7B)")
    with gr.Row():
        img = gr.Image(type="filepath", label="📤 Upload receipt (Image or PDF)")
        out_text = gr.Textbox(label="📝 OCR Text", lines=12)
        out_fields = gr.JSON(label="🧠 Extracted Fields")
    gr.Button("🔍 Run OCR").click(
        fn=gradio_interface,
        inputs=img,
        outputs=[out_text, out_fields],
    )

# Launch the UI only when run as a script, so that importing this module
# (for example from an ASGI server serving `app`) does not start Gradio.
if __name__ == "__main__":
    demo.launch()