Commit 83dd2a8 · Initial commit without binary files

Files changed:
- .gitattributes +4 -0
- .github/workflows/ci-cd.yml +117 -0
- .gitignore +89 -0
- Dockerfile +46 -0
- README.md +54 -0
- app.py +181 -0
- chatbot_server.py +68 -0
- insert_dummy_data.py +51 -0
- requirements.txt +32 -0
- src/api/main.py +104 -0
- src/config/config.yaml +66 -0
- src/frontend/app.py +299 -0
- src/frontend/temp_uploaded_image.jpg +0 -0
- src/models/layoutlm.py +144 -0
- src/scripts/train.py +202 -0
- tests/test_model.py +93 -0
- transformers +1 -0
.gitattributes
ADDED
@@ -0,0 +1,4 @@
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci-cd.yml
ADDED
@@ -0,0 +1,117 @@
+name: FormIQ CI/CD
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.8, 3.9]
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install -r requirements.txt
+        pip install pytest pytest-cov
+
+    - name: Run tests
+      run: |
+        pytest tests/ --cov=src/ --cov-report=xml
+
+    - name: Upload coverage to Codecov
+      uses: codecov/codecov-action@v2
+      with:
+        file: ./coverage.xml
+        fail_ci_if_error: true
+
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.9
+
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 black isort
+
+    - name: Run linters
+      run: |
+        flake8 src/ tests/
+        black --check src/ tests/
+        isort --check-only src/ tests/
+
+  build-and-push:
+    needs: [test, lint]
+    runs-on: ubuntu-latest
+    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Set up Docker Buildx
+      uses: docker/setup-buildx-action@v1
+
+    - name: Login to DockerHub
+      uses: docker/login-action@v1
+      with:
+        username: ${{ secrets.DOCKERHUB_USERNAME }}
+        password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+    - name: Build and push
+      uses: docker/build-push-action@v2
+      with:
+        context: .
+        push: true
+        tags: |
+          ${{ secrets.DOCKERHUB_USERNAME }}/formiq:latest
+          ${{ secrets.DOCKERHUB_USERNAME }}/formiq:${{ github.sha }}
+        cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/formiq:buildcache
+        cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/formiq:buildcache,mode=max
+
+  deploy:
+    needs: build-and-push
+    runs-on: ubuntu-latest
+    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+    steps:
+    - uses: actions/checkout@v2
+
+    - name: Configure AWS credentials
+      uses: aws-actions/configure-aws-credentials@v1
+      with:
+        aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+        aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+        aws-region: us-east-1
+
+    - name: Deploy to SageMaker
+      run: |
+        # Update SageMaker endpoint with new model
+        aws sagemaker update-endpoint \
+          --endpoint-name formiq-endpoint \
+          --endpoint-config-name formiq-config-${{ github.sha }}
+
+    - name: Deploy to ECS
+      run: |
+        # Update ECS service with new container
+        aws ecs update-service \
+          --cluster formiq-cluster \
+          --service formiq-service \
+          --force-new-deployment
.gitignore
ADDED
@@ -0,0 +1,89 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+ENV/
+env/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Project specific
+temp_uploaded_image.jpg
+.env
+*.log
+.DS_Store
+
+# Model files
+*.pt
+*.pth
+*.onnx
+*.h5
+*.model
+
+# Data
+*.csv
+*.json
+*.xlsx
+*.db
+*.sqlite3
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*.ipynb
+
+# Logs
+logs/
+*.log
+
+# MLflow
+mlruns/
+mlflow.db
+
+# DVC
+.dvc/
+.dvc/cache/
+
+# Testing
+.coverage
+coverage.xml
+htmlcov/
+.pytest_cache/
+
+# Docker
+.docker/
+
+# AWS
+.aws/
+*.pem
+
+# Environment variables
+.env.*
+
+# Distribution
+dist/
+build/
+*.egg-info/
Dockerfile
ADDED
@@ -0,0 +1,46 @@
+# Use NVIDIA CUDA base image for GPU support
+FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
+
+# Set environment variables
+ENV PYTHONUNBUFFERED=1 \
+    DEBIAN_FRONTEND=noninteractive \
+    PYTHON_VERSION=3.9
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python${PYTHON_VERSION} \
+    python3-pip \
+    python${PYTHON_VERSION}-dev \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements first to leverage Docker cache
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY . .
+
+# Create necessary directories
+RUN mkdir -p data/train data/val data/test logs
+
+# Set environment variables for the application
+ENV MODEL_SAVE_DIR=/app/models \
+    DATA_DIR=/app/data \
+    LOG_DIR=/app/logs
+
+# Expose ports
+EXPOSE 8000 8501
+
+# Create a non-root user
+RUN useradd -m -u 1000 appuser
+RUN chown -R appuser:appuser /app
+USER appuser
+
+# Start the application
+CMD ["sh", "-c", "uvicorn src.api.main:app --host 0.0.0.0 --port 8000 & streamlit run src/frontend/app.py --server.port 8501 --server.address 0.0.0.0"]
README.md
ADDED
@@ -0,0 +1,54 @@
+# FormIQ - Intelligent Document Parser
+
+FormIQ is an intelligent document parser that uses advanced AI models to extract and validate information from various types of documents.
+
+## Features
+
+- Document image upload and processing
+- OCR text extraction using Tesseract
+- Advanced document understanding using LayoutLMv3
+- Structured information extraction using Perplexity AI
+- Interactive web interface built with Streamlit
+
+## Technologies Used
+
+- **Frontend**: Streamlit
+- **OCR**: Tesseract
+- **Document Understanding**: LayoutLMv3
+- **Text Processing**: Perplexity AI
+- **Data Processing**: Pandas, NumPy
+- **Visualization**: Plotly
+
+## Setup
+
+1. Clone the repository
+2. Install dependencies:
+```bash
+pip install -r requirements.txt
+```
+3. Set up environment variables in a `.env` file:
+```bash
+PERPLEXITY_API_KEY=your_api_key_here
+```
+
+## Usage
+
+1. Run the Streamlit app:
+```bash
+streamlit run app.py
+```
+2. Open your browser and navigate to the provided URL
+3. Upload a document image
+4. Click "Process Document" to extract information
+
+## Hugging Face Spaces Deployment
+
+This project is deployed on Hugging Face Spaces. You can access the live demo at: [Your Spaces URL]
+
+## Contributing
+
+Contributions are welcome! Please feel free to submit a Pull Request.
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
app.py
ADDED
@@ -0,0 +1,181 @@
+import streamlit as st
+import torch
+from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+from PIL import Image
+import io
+import json
+import pandas as pd
+import plotly.express as px
+import numpy as np
+from typing import Dict, Any
+import logging
+import pytesseract
+import re
+from openai import OpenAI
+import os
+from dotenv import load_dotenv
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Load environment variables
+load_dotenv()
+
+# Initialize OpenAI client for Perplexity
+client = OpenAI(
+    api_key=os.getenv('PERPLEXITY_API_KEY'),
+    base_url="https://api.perplexity.ai"
+)
+
+# Initialize LayoutLM model
+@st.cache_resource
+def load_model():
+    model_name = "microsoft/layoutlmv3-base"
+    processor = LayoutLMv3Processor.from_pretrained(model_name)
+    model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
+    return processor, model
+
+def extract_json_from_llm_output(llm_result):
+    match = re.search(r'\{.*\}', llm_result, re.DOTALL)
+    if match:
+        return match.group(0)
+    return None
+
+def extract_fields(image_path):
+    # OCR
+    text = pytesseract.image_to_string(Image.open(image_path))
+
+    # Display OCR output for debugging
+    st.subheader("Raw OCR Output")
+    st.code(text)
+
+    # Regex patterns for fields
+    patterns = {
+        "name": r"Mrs\s+\w+\s+\w+",
+        "date": r"Date[:\s]+([\d/]+)",
+        "product": r"\d+\s+\w+.*Style\s+\d+",
+        "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
+        "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
+    }
+
+    results = {}
+    for field, pattern in patterns.items():
+        match = re.search(pattern, text, re.IGNORECASE)
+        if match:
+            results[field] = match.group(1) if match.groups() else match.group(0)
+        else:
+            results[field] = None
+
+    return results
+
+def extract_with_perplexity_llm(ocr_text):
+    prompt = f"""
+Extract the following fields from this receipt text:
+- name
+- date
+- product
+- amount_paid
+- receipt_no
+
+Text:
+\"\"\"{ocr_text}\"\"\"
+
+Return the result as a JSON object with those fields.
+"""
+    messages = [
+        {
+            "role": "system",
+            "content": "You are an AI assistant that extracts structured information from text."
+        },
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+
+    response = client.chat.completions.create(
+        model="sonar-pro",
+        messages=messages
+    )
+    return response.choices[0].message.content
+
+def main():
+    st.set_page_config(
+        page_title="FormIQ - Intelligent Document Parser",
+        page_icon="📄",
+        layout="wide"
+    )
+
+    st.title("FormIQ: Intelligent Document Parser")
+    st.markdown("""
+    Upload your documents to extract and validate information using advanced AI models.
+    """)
+
+    # Sidebar
+    with st.sidebar:
+        st.header("Settings")
+        document_type = st.selectbox(
+            "Document Type",
+            options=["invoice", "receipt", "form"],
+            index=0
+        )
+
+        confidence_threshold = st.slider(
+            "Confidence Threshold",
+            min_value=0.0,
+            max_value=1.0,
+            value=0.5,
+            step=0.05
+        )
+
+        st.markdown("---")
+        st.markdown("### About")
+        st.markdown("""
+        FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
+        """)
+
+    # Main content
+    uploaded_file = st.file_uploader(
+        "Upload Document",
+        type=["png", "jpg", "jpeg", "pdf"],
+        help="Upload a document image to process"
+    )
+
+    if uploaded_file is not None:
+        # Display uploaded image
+        image = Image.open(uploaded_file)
+        st.image(image, caption="Uploaded Document", width=600)
+
+        # Process button
+        if st.button("Process Document"):
+            with st.spinner("Processing document..."):
+                try:
+                    # Save the uploaded file to a temporary location
+                    temp_path = "temp_uploaded_image.jpg"
+                    image.save(temp_path)
+
+                    # Extract fields using OCR + regex
+                    fields = extract_fields(temp_path)
+
+                    # Extract with Perplexity LLM
+                    with st.spinner("Extracting structured data with Perplexity LLM..."):
+                        try:
+                            llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
+                            st.subheader("Structured Data (Perplexity LLM)")
+                            st.code(llm_result, language="json")
+
+                            # Display extracted fields
+                            st.subheader("Extracted Fields")
+                            fields_df = pd.DataFrame([fields])
+                            st.dataframe(fields_df)
+
+                        except Exception as e:
+                            st.error(f"LLM extraction failed: {e}")
+
+                except Exception as e:
+                    logger.error(f"Error processing document: {str(e)}")
+                    st.error(f"Error processing document: {str(e)}")
+
+if __name__ == "__main__":
+    main()
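The regex patterns in `extract_fields` are tuned to one specific receipt layout (a "Mrs <first> <last>" salutation, "Total Paid", "Receipt No"). A minimal sketch of how they behave, using a receipt string invented purely for illustration:

```python
import re

# Hypothetical OCR output, invented for illustration only
sample_text = """Mrs Jane Doe
Date: 12/05/2024
1 Dress Style 4032
Total Paid $89.99
Receipt No. 100382"""

patterns = {
    "name": r"Mrs\s+\w+\s+\w+",
    "date": r"Date[:\s]+([\d/]+)",
    "product": r"\d+\s+\w+.*Style\s+\d+",
    "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
    "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)",
}

for field, pattern in patterns.items():
    m = re.search(pattern, sample_text, re.IGNORECASE)
    # Group 1 holds the captured value when the pattern defines one
    value = (m.group(1) if m.groups() else m.group(0)) if m else None
    print(field, "->", value)
```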
chatbot_server.py
ADDED
@@ -0,0 +1,68 @@
+import os
+from fastapi import FastAPI
+from pydantic import BaseModel
+from openai import OpenAI
+from dotenv import load_dotenv
+import boto3
+
+# Load environment variables from .env
+load_dotenv()
+
+# Initialize OpenAI client for Perplexity
+client = OpenAI(
+    api_key=os.getenv('PERPLEXITY_API_KEY'),
+    base_url="https://api.perplexity.ai"
+)
+
+app = FastAPI()
+
+class ChatRequest(BaseModel):
+    question: str
+
+@app.post("/chat")
+def chat_endpoint(chat_request: ChatRequest):
+    # Connect to DynamoDB
+    dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
+    table = dynamodb.Table('Receipts')
+
+    # Get question and search DynamoDB
+    question = chat_request.question
+    response = table.scan()
+    items = response.get('Items', [])
+
+    # Format items for context with all receipt details
+    context = "\n".join([
+        f"Receipt {item['receipt_no']}:\n"
+        f"  Name: {item['name']}\n"
+        f"  Date: {item['date']}\n"
+        f"  Product: {item['product']}\n"
+        f"  Amount Paid: {item['amount_paid']}\n"
+        for item in items
+    ])
+    question = f"Based on these receipts:\n{context}\n\nQuestion: {question}\nPlease provide a 2-3 line answer."
+
+    # Prepare messages for the chat
+    messages = [
+        {
+            "role": "system",
+            "content": (
+                "You are an artificial intelligence assistant and you need to "
+                "engage in a helpful, detailed, polite conversation with a user. "
+                "Give a 2-3 line answer."
+            )
+        },
+        {
+            "role": "user",
+            "content": question
+        }
+    ]
+
+    try:
+        # Get response from Perplexity
+        response = client.chat.completions.create(
+            model="sonar",
+            messages=messages
+        )
+        return {"answer": response.choices[0].message.content}
+    except Exception as e:
+        return {"error": f"Error from LLM: {str(e)}"}
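A minimal client sketch for the `/chat` endpoint above, assuming the server is run locally (for example with `uvicorn chatbot_server:app --port 8000`, matching the frontend's `API_URL`):

```python
import requests

# Hypothetical question; the endpoint returns {"answer": ...} on success
# or {"error": ...} if the LLM call fails
resp = requests.post(
    "http://localhost:8000/chat",
    json={"question": "How much did Jane Smith pay?"},
)
payload = resp.json()
print(payload.get("answer") or payload.get("error"))
```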
insert_dummy_data.py
ADDED
@@ -0,0 +1,51 @@
+import boto3
+
+# Initialize DynamoDB resource (ensure AWS credentials and region are set)
+dynamodb = boto3.resource('dynamodb', region_name='us-east-1')  # Change region if needed
+table = dynamodb.Table('Receipts')  # Replace with your table name
+
+# List of dummy items to insert with meaningful receipt numbers
+dummy_items = [
+    {
+        'receipt_no': 'RCPT-2024-0001',
+        'amount_paid': '100.00',
+        'date': '2024-01-01',
+        'name': 'John Doe',
+        'product': 'Widget A'
+    },
+    {
+        'receipt_no': 'RCPT-2024-0002',
+        'amount_paid': '250.50',
+        'date': '2024-02-15',
+        'name': 'Jane Smith',
+        'product': 'Gadget B'
+    },
+    {
+        'receipt_no': 'RCPT-2024-0003',
+        'amount_paid': '75.25',
+        'date': '2024-03-10',
+        'name': 'Alice Johnson',
+        'product': 'Thingamajig C'
+    },
+    {
+        'receipt_no': 'RCPT-2024-0004',
+        'amount_paid': '180.00',
+        'date': '2024-04-05',
+        'name': 'Bob Lee',
+        'product': 'Gizmo D'
+    },
+    {
+        'receipt_no': 'RCPT-2024-0005',
+        'amount_paid': '320.75',
+        'date': '2024-05-20',
+        'name': 'Carol King',
+        'product': 'Device E'
+    }
+]
+
+# Insert each item
+for item in dummy_items:
+    table.put_item(Item=item)
+    print(f"Inserted: {item['receipt_no']}")
+
+print("Dummy data inserted successfully.")
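A quick read-back sketch to verify the seeded rows, assuming `receipt_no` is the table's partition key (the script above does not create the table itself):

```python
import boto3

dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
table = dynamodb.Table('Receipts')

# Fetch one of the seeded items by its key
item = table.get_item(Key={'receipt_no': 'RCPT-2024-0001'}).get('Item')
print(item)  # e.g. {'receipt_no': 'RCPT-2024-0001', 'name': 'John Doe', ...}
```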
requirements.txt
ADDED
@@ -0,0 +1,32 @@
+# Core ML dependencies
+torch==2.2.0
+transformers==4.37.2
+datasets>=2.12.0
+pytesseract==0.3.10
+Pillow==10.2.0
+
+# API and Web Framework
+fastapi==0.109.2
+uvicorn==0.27.1
+streamlit==1.32.0
+python-multipart==0.0.9
+
+# MLOps and Monitoring
+wandb>=0.15.0
+mlflow>=2.4.0
+dvc>=3.0.0
+hydra-core>=1.3.2
+evidently>=0.2.0
+tensorboard>=2.12.0
+
+# Cloud and Deployment
+boto3==1.34.34
+sagemaker>=2.160.0
+
+# Utilities
+numpy==1.26.3
+pandas==2.2.0
+python-dotenv==1.0.1
+pydantic>=2.0.0
+openai
+plotly==5.18.0
src/api/main.py
ADDED
@@ -0,0 +1,104 @@
+from fastapi import FastAPI, File, UploadFile, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from PIL import Image
+import io
+import logging
+from typing import Dict, Any
+import json
+
+from src.models.layoutlm import FormIQModel
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="FormIQ API",
+    description="Intelligent Document Parser API",
+    version="1.0.0"
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Initialize model
+model = FormIQModel()
+
+@app.get("/")
+async def root():
+    """Health check endpoint."""
+    return {"status": "healthy", "service": "FormIQ API"}
+
+@app.post("/extract")
+async def extract_information(
+    file: UploadFile = File(...),
+    confidence_threshold: float = 0.5,
+    document_type: str = "invoice"
+) -> Dict[str, Any]:
+    """Extract information from uploaded document.
+
+    Args:
+        file: Uploaded document image
+        confidence_threshold: Minimum confidence score for predictions
+        document_type: Type of document being processed
+
+    Returns:
+        Dictionary containing extracted fields and metadata
+    """
+    try:
+        # Read and validate image
+        contents = await file.read()
+        image = Image.open(io.BytesIO(contents))
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        # Process image
+        extraction_results = model.predict(
+            image=image,
+            confidence_threshold=confidence_threshold
+        )
+
+        # Validate extraction
+        validation_results = model.validate_extraction(
+            extracted_fields=extraction_results,
+            document_type=document_type
+        )
+
+        # Combine results
+        response = {
+            "extraction": extraction_results,
+            "validation": validation_results,
+            "metadata": {
+                "document_type": document_type,
+                "confidence_threshold": confidence_threshold
+            }
+        }
+
+        return response
+
+    except Exception as e:
+        logger.error(f"Error processing document: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error processing document: {str(e)}"
+        )
+
+@app.get("/model-info")
+async def get_model_info() -> Dict[str, Any]:
+    """Get information about the current model."""
+    return {
+        "model_name": model.model.config.model_type,
+        "device": model.device,
+        "version": "1.0.0"
+    }
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
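A minimal client sketch for the `/extract` route, assuming the API is served locally and `sample.png` is a hypothetical document image; since the two scalar arguments sit alongside a `File(...)` parameter, FastAPI reads them as query parameters:

```python
import requests

with open("sample.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/extract",
        files={"file": ("sample.png", f, "image/png")},
        params={"confidence_threshold": 0.5, "document_type": "invoice"},
    )

result = resp.json()
print(result["extraction"]["fields"])  # token-level labels from FormIQModel.predict
print(result["validation"])            # placeholder validation results
```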
src/config/config.yaml
ADDED
@@ -0,0 +1,66 @@
+# Model Configuration
+model:
+  name: "microsoft/layoutlmv3-base"
+  device: "cuda"  # or "cpu"
+  confidence_threshold: 0.5
+  max_length: 512
+
+# Training Configuration
+training:
+  batch_size: 8
+  learning_rate: 2.0e-5  # decimal point needed so YAML parses this as a float
+  num_epochs: 10
+  warmup_steps: 100
+  weight_decay: 0.01
+  gradient_accumulation_steps: 4
+
+# Dataset Configuration
+dataset:
+  train_path: "data/train"
+  val_path: "data/val"
+  test_path: "data/test"
+  max_samples: null  # Set to null for all samples
+  augmentation:
+    enabled: true
+    rotation_range: 10
+    width_shift_range: 0.1
+    height_shift_range: 0.1
+    zoom_range: 0.1
+    fill_mode: "nearest"
+
+# Logging Configuration
+logging:
+  level: "INFO"
+  wandb:
+    enabled: true
+    project: "formiq"
+    entity: null  # Set your W&B username
+  tensorboard:
+    enabled: true
+    log_dir: "logs"
+
+# API Configuration
+api:
+  host: "0.0.0.0"
+  port: 8000
+  workers: 4
+  timeout: 60
+
+# Frontend Configuration
+frontend:
+  host: "localhost"
+  port: 8501
+  debug: false
+
+# MLOps Configuration
+mlops:
+  dvc:
+    remote: "s3://formiq-data"
+    cache_dir: ".dvc/cache"
+  mlflow:
+    tracking_uri: "http://localhost:5000"
+    experiment_name: "formiq"
+  evidently:
+    enabled: true
+    drift_threshold: 0.1
+    window_size: 1000
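Outside of Hydra, this config can be read directly with OmegaConf; a small sketch, with paths relative to the repository root:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("src/config/config.yaml")

print(cfg.model.name)                 # microsoft/layoutlmv3-base
print(cfg.dataset.train_path)         # data/train
print(cfg.mlops.mlflow.tracking_uri)  # http://localhost:5000
```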
src/frontend/app.py
ADDED
@@ -0,0 +1,299 @@
+import streamlit as st
+import requests
+from PIL import Image
+import io
+import json
+import pandas as pd
+import plotly.express as px
+import numpy as np
+from typing import Dict, Any
+import logging
+import pytesseract
+import re
+from openai import OpenAI, OpenAIError
+import boto3
+from botocore.exceptions import ClientError
+import os
+from dotenv import load_dotenv
+load_dotenv()
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Constants
+API_URL = "http://localhost:8000"
+SUPPORTED_DOCUMENT_TYPES = ["invoice", "receipt", "form"]
+
+api_key = os.getenv("PERPLEXITY_API_KEY")
+client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
+
+REGION = "us-east-1"
+dynamodb = boto3.resource('dynamodb', region_name=REGION)
+
+def extract_json_from_llm_output(llm_result):
+    match = re.search(r'\{.*\}', llm_result, re.DOTALL)
+    if match:
+        return match.group(0)
+    return None
+
+def save_to_dynamodb(data, table_name="Receipts"):
+    dynamodb = boto3.resource("dynamodb", region_name=REGION)  # pin the region explicitly
+    table = dynamodb.Table(table_name)
+    try:
+        table.put_item(Item=data)
+        return True
+    except ClientError as e:
+        st.error(f"Failed to save to DynamoDB: {e}")
+        return False
+
+def main():
+    st.set_page_config(
+        page_title="FormIQ - Intelligent Document Parser",
+        page_icon="📄",
+        layout="wide"
+    )
+
+    st.title("FormIQ: Intelligent Document Parser")
+    st.markdown("""
+    Upload your documents to extract and validate information using advanced AI models.
+    """)
+
+    # Sidebar
+    with st.sidebar:
+        st.header("Settings")
+        document_type = st.selectbox(
+            "Document Type",
+            options=SUPPORTED_DOCUMENT_TYPES,
+            index=0
+        )
+
+        confidence_threshold = st.slider(
+            "Confidence Threshold",
+            min_value=0.0,
+            max_value=1.0,
+            value=0.5,
+            step=0.05
+        )
+
+        st.markdown("---")
+        st.markdown("### About")
+        st.markdown("""
+        FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
+        """)
+
+    # Main content
+    uploaded_file = st.file_uploader(
+        "Upload Document",
+        type=["png", "jpg", "jpeg", "pdf"],
+        help="Upload a document image to process"
+    )
+
+    if uploaded_file is not None:
+        # Display uploaded image
+        image = Image.open(uploaded_file)
+        st.image(image, caption="Uploaded Document", width=600)
+
+        # Process button
+        if st.button("Process Document"):
+            with st.spinner("Processing document..."):
+                try:
+                    # Save the uploaded file to a temporary location
+                    temp_path = "temp_uploaded_image.jpg"
+                    image.save(temp_path)
+
+                    # Extract fields using OCR + regex
+                    fields = extract_fields(temp_path)
+
+                    # Extract with Perplexity LLM using the provided API key
+                    with st.spinner("Extracting structured data with Perplexity LLM..."):
+                        try:
+                            llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
+                            st.subheader("Structured Data (Perplexity LLM)")
+                            st.code(llm_result, language="json")
+
+                            # Extract and save JSON to DynamoDB
+                            raw_json = extract_json_from_llm_output(llm_result)
+                            if raw_json:
+                                try:
+                                    llm_data = json.loads(raw_json)
+                                    if save_to_dynamodb(llm_data):
+                                        st.success("Data saved to DynamoDB!")
+                                except Exception as e:
+                                    st.error(f"Failed to parse/save JSON: {e}")
+                            else:
+                                st.error("No valid JSON found in LLM output.")
+                        except Exception as e:
+                            st.error(f"LLM extraction failed: {e}")
+
+                except Exception as e:
+                    logger.error(f"Error processing document: {str(e)}")
+                    st.error(f"Error processing document: {str(e)}")
+
+def display_results(results: Dict[str, Any]):
+    """Display extraction and validation results."""
+
+    # Create tabs for different views
+    tab1, tab2, tab3 = st.tabs(["Extracted Fields", "Validation", "Visualization"])
+
+    with tab1:
+        st.subheader("Extracted Fields")
+        if "fields" in results["extraction"]:
+            fields_df = pd.DataFrame(results["extraction"]["fields"])
+            st.dataframe(fields_df)
+        else:
+            st.info("No fields extracted")
+
+    with tab2:
+        st.subheader("Validation Results")
+        validation = results["validation"]
+
+        # Display validation status
+        status_color = "green" if validation["is_valid"] else "red"
+        st.markdown(f"### Status: :{status_color}[{validation['is_valid']}]")
+
+        # Display validation errors if any
+        if validation["validation_errors"]:
+            st.error("Validation Errors:")
+            for error in validation["validation_errors"]:
+                st.markdown(f"- {error}")
+
+        # Display confidence score
+        st.metric(
+            "Overall Confidence",
+            f"{validation['confidence_score']:.2%}"
+        )
+
+    with tab3:
+        st.subheader("Confidence Visualization")
+        if "confidence_scores" in results["extraction"]["metadata"]:
+            scores = results["extraction"]["metadata"]["confidence_scores"]
+
+            # Create confidence distribution plot
+            fig = px.histogram(
+                x=scores,
+                nbins=20,
+                title="Confidence Score Distribution",
+                labels={"x": "Confidence Score", "y": "Count"}
+            )
+            st.plotly_chart(fig)
+
+        # Display heatmap if available
+        if "bbox" in results["extraction"]["fields"][0]:
+            st.subheader("Field Location Heatmap")
+            # TODO: Implement heatmap visualization
+            st.info("Heatmap visualization coming soon!")
+
+def group_tokens_by_label(tokens, labels):
+    structured = {}
+    current_label = None
+    current_tokens = []
+    for token, label in zip(tokens, labels):
+        if label != current_label:
+            if current_label is not None:
+                structured.setdefault(current_label, []).append(' '.join(current_tokens))
+            current_label = label
+            current_tokens = [token]
+        else:
+            current_tokens.append(token)
+    if current_label is not None:
+        structured.setdefault(current_label, []).append(' '.join(current_tokens))
+    return structured
+
+def extract_fields(image_path):
+    # OCR
+    text = pytesseract.image_to_string(Image.open(image_path))
+
+    # Display OCR output for debugging
+    st.subheader("Raw OCR Output (for debugging)")
+    st.code(text)
+
+    # Regex patterns for fields
+    patterns = {
+        "name": r"Mrs\s+\w+\s+\w+",
+        "date": r"Date[:\s]+([\d/]+)",
+        "product": r"\d+\s+\w+.*Style\s+\d+",
+        "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
+        # Receipt number pattern handles an optional dot, colon, and spaces
+        "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
+    }
+
+    results = {}
+    for field, pattern in patterns.items():
+        match = re.search(pattern, text, re.IGNORECASE)
+        if match:
+            results[field] = match.group(1) if match.groups() else match.group(0)
+        else:
+            results[field] = None
+
+    return results
+
+def extract_with_perplexity_llm(ocr_text):
+    prompt = f"""
+Extract the following fields from this receipt text:
+- name
+- date
+- product
+- amount_paid
+- receipt_no
+
+Text:
+\"\"\"{ocr_text}\"\"\"
+
+Return the result as a JSON object with those fields.
+"""
+    messages = [
+        {
+            "role": "system",
+            "content": (
+                "You are an artificial intelligence assistant. "
+                "Answer user questions as concisely and directly as possible. "
+                "Limit your responses to 2-3 sentences unless the user asks for more detail."
+            )
+        },
+        {
+            "role": "user",
+            "content": prompt
+        }
+    ]
+    response = client.chat.completions.create(
+        model="sonar-pro",  # Use a valid model name for your account
+        messages=messages,
+    )
+    return response.choices[0].message.content
+
+def interactive_chatbot_ui():
+    st.header("🤖 Chatbot")
+    if "chat_history" not in st.session_state:
+        st.session_state.chat_history = []
+
+    # Display chat history as chat bubbles
+    for sender, msg in st.session_state.chat_history:
+        if sender == "You":
+            st.markdown(f"<div style='text-align: right; background: #262730; color: #fff; padding: 8px 12px; border-radius: 12px; margin: 4px 0 4px 40px;'><b>You:</b> {msg}</div>", unsafe_allow_html=True)
+        else:
+            st.markdown(f"<div style='text-align: left; background: #31333F; color: #fff; padding: 8px 12px; border-radius: 12px; margin: 4px 40px 4px 0;'><b>Bot:</b> {msg}</div>", unsafe_allow_html=True)
+
+    # Input at the bottom
+    with st.form(key="chat_form", clear_on_submit=True):
+        user_input = st.text_input("Type your message...", key="chat_input_main", placeholder="Ask me anything...")
+        submitted = st.form_submit_button("Send")
+        if submitted and user_input:
+            st.session_state.chat_history.append(("You", user_input))
+            try:
+                response = requests.post(
+                    f"{API_URL}/chat",
+                    json={"question": user_input}
+                )
+                if response.status_code == 200:
+                    bot_reply = response.json()["answer"]
+                else:
+                    bot_reply = f"Error: Server returned status code {response.status_code}"
+            except Exception as e:
+                bot_reply = f"Error: {e}"
+            st.session_state.chat_history.append(("Bot", bot_reply))
+
+if __name__ == "__main__":
+    main()
+    st.markdown("---")
+    interactive_chatbot_ui()
src/frontend/temp_uploaded_image.jpg
ADDED
src/models/layoutlm.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+from PIL import Image
+import numpy as np
+from typing import Any, Dict, List, Tuple, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+class FormIQModel:
+    def __init__(
+        self,
+        model_name: str = "microsoft/layoutlmv3-base",
+        device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    ):
+        """Initialize the FormIQ model with LayoutLMv3.
+
+        Args:
+            model_name: Name of the pre-trained model to use
+            device: Device to run the model on ('cuda' or 'cpu')
+        """
+        self.device = device
+        self.processor = LayoutLMv3Processor.from_pretrained(model_name)
+        self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
+        self.model.to(device)
+        logger.info(f"Model initialized on {device}")
+
+    def preprocess_image(self, image: Image.Image) -> Dict[str, torch.Tensor]:
+        """Preprocess the input image for the model.
+
+        Args:
+            image: PIL Image to process
+
+        Returns:
+            Dictionary of processed inputs
+        """
+        try:
+            # Process image and text
+            encoding = self.processor(
+                image,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512
+            )
+
+            # Move tensors to device
+            encoding = {k: v.to(self.device) for k, v in encoding.items()}
+            return encoding
+
+        except Exception as e:
+            logger.error(f"Error preprocessing image: {str(e)}")
+            raise
+
+    def predict(
+        self,
+        image: Image.Image,
+        confidence_threshold: float = 0.5
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        """Extract information from the document image.
+
+        Args:
+            image: PIL Image of the document
+            confidence_threshold: Minimum confidence score for predictions
+
+        Returns:
+            Dictionary containing extracted fields and their metadata
+        """
+        try:
+            # Preprocess image
+            inputs = self.preprocess_image(image)
+
+            # Get model predictions
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+                predictions = outputs.logits.argmax(-1).squeeze().cpu().numpy()
+                scores = torch.softmax(outputs.logits, dim=-1).max(-1)[0].squeeze().cpu().numpy()
+
+            # Process predictions
+            extracted_fields = self._process_predictions(predictions, scores, confidence_threshold)
+
+            return {
+                "fields": extracted_fields,
+                "metadata": {
+                    "confidence_scores": scores.tolist(),
+                    "model_version": self.model.config.model_type
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Error during prediction: {str(e)}")
+            raise
+
+    def _process_predictions(
+        self,
+        predictions: np.ndarray,
+        scores: np.ndarray,
+        confidence_threshold: float
+    ) -> List[Dict[str, Any]]:
+        """Process raw model predictions into structured output.
+
+        Args:
+            predictions: Array of predicted class indices
+            scores: Array of confidence scores
+            confidence_threshold: Minimum confidence score
+
+        Returns:
+            List of dictionaries containing field information
+        """
+        # TODO: Implement field-specific post-processing
+        # This is a placeholder implementation
+        processed_fields = []
+
+        for pred, score in zip(predictions, scores):
+            if score >= confidence_threshold:
+                field_info = {
+                    "label": self.model.config.id2label[pred],
+                    "confidence": float(score),
+                    "bbox": None  # TODO: Add bounding box information
+                }
+                processed_fields.append(field_info)
+
+        return processed_fields
+
+    def validate_extraction(
+        self,
+        extracted_fields: Dict[str, List[Dict[str, Any]]],
+        document_type: str
+    ) -> Dict[str, Any]:
+        """Validate extracted fields based on document type rules.
+
+        Args:
+            extracted_fields: Dictionary of extracted fields
+            document_type: Type of document (e.g., 'invoice', 'receipt')
+
+        Returns:
+            Dictionary containing validation results
+        """
+        # TODO: Implement field validation logic
+        # This is a placeholder implementation
+        return {
+            "is_valid": True,
+            "validation_errors": [],
+            "confidence_score": 1.0
+        }
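A minimal sketch of driving `FormIQModel` directly, mirroring what the `/extract` route does; note that `LayoutLMv3Processor` runs its own OCR by default, so Tesseract must be available, and `sample.png` is a hypothetical input:

```python
from PIL import Image

from src.models.layoutlm import FormIQModel

model = FormIQModel(device="cpu")
image = Image.open("sample.png").convert("RGB")

results = model.predict(image, confidence_threshold=0.5)
for field in results["fields"]:
    # Labels come from the base checkpoint's head until fine-tuning is done
    print(field["label"], field["confidence"])
```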
src/scripts/train.py
ADDED
@@ -0,0 +1,202 @@
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import torch
+from torch.utils.data import DataLoader
+from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+from datasets import load_dataset
+import mlflow
+import wandb
+from pathlib import Path
+import logging
+from typing import Dict, Any
+import numpy as np
+from tqdm import tqdm
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class FormIQTrainer:
+    def __init__(self, config: DictConfig):
+        """Initialize the trainer with configuration."""
+        self.config = config
+        self.device = torch.device(config.model.device)
+
+        # Initialize model and processor
+        self.processor = LayoutLMv3Processor.from_pretrained(config.model.name)
+        self.model = LayoutLMv3ForTokenClassification.from_pretrained(
+            config.model.name,
+            num_labels=config.model.num_labels
+        )
+        self.model.to(self.device)
+
+        # Initialize optimizer
+        self.optimizer = torch.optim.AdamW(
+            self.model.parameters(),
+            lr=config.training.learning_rate,
+            weight_decay=config.training.weight_decay
+        )
+
+        # Setup logging
+        self.setup_logging()
+
+    def setup_logging(self):
+        """Setup MLflow and W&B logging."""
+        if self.config.logging.mlflow.enabled:
+            mlflow.set_tracking_uri(self.config.logging.mlflow.tracking_uri)
+            mlflow.set_experiment(self.config.logging.mlflow.experiment_name)
+
+        if self.config.logging.wandb.enabled:
+            wandb.init(
+                project=self.config.logging.wandb.project,
+                entity=self.config.logging.wandb.entity,
+                config=OmegaConf.to_container(self.config, resolve=True)
+            )
+
+    def prepare_dataset(self):
+        """Prepare the dataset for training."""
+        # TODO: Implement dataset preparation
+        # This is a placeholder implementation
+        return None, None
+
+    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
+        """Train for one epoch.
+
+        Args:
+            train_loader: DataLoader for training data
+
+        Returns:
+            Dictionary containing training metrics
+        """
+        self.model.train()
+        total_loss = 0
+        correct_predictions = 0
+        total_predictions = 0
+
+        progress_bar = tqdm(train_loader, desc="Training")
+        for batch in progress_bar:
+            # Move batch to device
+            batch = {k: v.to(self.device) for k, v in batch.items()}
+
+            # Forward pass
+            outputs = self.model(**batch)
+            loss = outputs.loss
+
+            # Backward pass
+            loss.backward()
+
+            # Update weights
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+            # Update metrics
+            total_loss += loss.item()
+            predictions = outputs.logits.argmax(-1)
+            correct_predictions += (predictions == batch["labels"]).sum().item()
+            total_predictions += batch["labels"].numel()
+
+            # Update progress bar
+            progress_bar.set_postfix({
+                "loss": loss.item(),
+                "accuracy": correct_predictions / total_predictions
+            })
+
+        # Calculate epoch metrics
+        metrics = {
+            "train_loss": total_loss / len(train_loader),
+            "train_accuracy": correct_predictions / total_predictions
+        }
+
+        return metrics
+
+    def evaluate(self, eval_loader: DataLoader) -> Dict[str, float]:
+        """Evaluate the model.
+
+        Args:
+            eval_loader: DataLoader for evaluation data
+
+        Returns:
+            Dictionary containing evaluation metrics
+        """
+        self.model.eval()
+        total_loss = 0
+        correct_predictions = 0
+        total_predictions = 0
+
+        with torch.no_grad():
+            for batch in tqdm(eval_loader, desc="Evaluating"):
+                # Move batch to device
+                batch = {k: v.to(self.device) for k, v in batch.items()}
+
+                # Forward pass
+                outputs = self.model(**batch)
+                loss = outputs.loss
+
+                # Update metrics
+                total_loss += loss.item()
+                predictions = outputs.logits.argmax(-1)
+                correct_predictions += (predictions == batch["labels"]).sum().item()
+                total_predictions += batch["labels"].numel()
+
+        # Calculate evaluation metrics
+        metrics = {
+            "eval_loss": total_loss / len(eval_loader),
+            "eval_accuracy": correct_predictions / total_predictions
+        }
+
+        return metrics
+
+    def train(self):
+        """Train the model."""
+        # Prepare datasets
+        train_loader, eval_loader = self.prepare_dataset()
+
+        # Training loop
+        best_eval_loss = float('inf')
+        for epoch in range(self.config.training.num_epochs):
+            logger.info(f"Epoch {epoch + 1}/{self.config.training.num_epochs}")
+
+            # Train
+            train_metrics = self.train_epoch(train_loader)
+
+            # Evaluate
+            eval_metrics = self.evaluate(eval_loader)
+
+            # Log metrics
+            metrics = {**train_metrics, **eval_metrics}
+            if self.config.logging.mlflow.enabled:
+                mlflow.log_metrics(metrics, step=epoch)
+            if self.config.logging.wandb.enabled:
+                wandb.log(metrics, step=epoch)
+
+            # Save best model
+            if eval_metrics["eval_loss"] < best_eval_loss:
+                best_eval_loss = eval_metrics["eval_loss"]
+                self.save_model("best_model")
+
+            # Save checkpoint
+            self.save_model(f"checkpoint_epoch_{epoch + 1}")
+
+    def save_model(self, name: str):
+        """Save the model.
+
+        Args:
+            name: Name of the saved model
+        """
+        save_path = Path(self.config.model.save_dir) / name
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        self.model.save_pretrained(save_path)
+        self.processor.save_pretrained(save_path)
+
+        if self.config.logging.mlflow.enabled:
+            mlflow.log_artifacts(str(save_path), f"models/{name}")
+
+@hydra.main(config_path="../config", config_name="config")
+def main(config: DictConfig):
+    """Main training function."""
+    trainer = FormIQTrainer(config)
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
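For completeness, a sketch of composing the same config programmatically with Hydra's compose API. Two caveats: the trainer reads `model.num_labels`, `model.save_dir`, and `logging.mlflow.*`, none of which `config.yaml` above defines (MLflow settings live under `mlops.mlflow` there), so they are supplied here as hypothetical overrides; and `prepare_dataset` is still a placeholder, so `train()` cannot run end to end yet.

```python
from hydra import compose, initialize

from src.scripts.train import FormIQTrainer

# Hypothetical overrides supplying keys the trainer reads but config.yaml
# does not define; `+` adds a key that is absent from the base config
overrides = [
    "+model.num_labels=5",
    "+model.save_dir=models",
    "+logging.mlflow.enabled=false",
    "logging.wandb.enabled=false",  # skip W&B login for a local dry run
]

with initialize(config_path="../config", version_base=None):
    cfg = compose(config_name="config", overrides=overrides)
    trainer = FormIQTrainer(cfg)
    # trainer.train() becomes runnable once prepare_dataset() is implemented
```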
tests/test_model.py
ADDED
@@ -0,0 +1,93 @@
+import pytest
+import torch
+from PIL import Image
+import numpy as np
+from src.models.layoutlm import FormIQModel
+
+@pytest.fixture
+def model():
+    """Create a model instance for testing."""
+    return FormIQModel(device="cpu")
+
+@pytest.fixture
+def sample_image():
+    """Create a sample image for testing."""
+    # Create a random image
+    image_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+    return Image.fromarray(image_array)
+
+def test_model_initialization(model):
+    """Test model initialization."""
+    assert model.device == "cpu"
+    assert model.model is not None
+    assert model.processor is not None
+
+def test_preprocess_image(model, sample_image):
+    """Test image preprocessing."""
+    processed = model.preprocess_image(sample_image)
+
+    # Check if all required keys are present
+    assert "input_ids" in processed
+    assert "attention_mask" in processed
+    assert "bbox" in processed
+    assert "pixel_values" in processed
+
+    # Check tensor types and shapes
+    assert isinstance(processed["input_ids"], torch.Tensor)
+    assert isinstance(processed["attention_mask"], torch.Tensor)
+    assert isinstance(processed["bbox"], torch.Tensor)
+    assert isinstance(processed["pixel_values"], torch.Tensor)
+
+def test_predict(model, sample_image):
+    """Test prediction functionality."""
+    results = model.predict(sample_image, confidence_threshold=0.5)
+
+    # Check result structure
+    assert "fields" in results
+    assert "metadata" in results
+    assert isinstance(results["fields"], list)
+    assert isinstance(results["metadata"], dict)
+
+    # Check metadata
+    assert "confidence_scores" in results["metadata"]
+    assert "model_version" in results["metadata"]
+
+def test_validate_extraction(model):
+    """Test field validation."""
+    # Create sample extraction results
+    sample_extraction = {
+        "fields": [
+            {"label": "amount", "confidence": 0.95, "value": "100.00"},
+            {"label": "date", "confidence": 0.85, "value": "2024-03-20"}
+        ]
+    }
+
+    # Test validation
+    validation_results = model.validate_extraction(
+        sample_extraction,
+        document_type="invoice"
+    )
+
+    # Check validation results structure
+    assert "is_valid" in validation_results
+    assert "validation_errors" in validation_results
+    assert "confidence_score" in validation_results
+
+    # Check types
+    assert isinstance(validation_results["is_valid"], bool)
+    assert isinstance(validation_results["validation_errors"], list)
+    assert isinstance(validation_results["confidence_score"], float)
+
+def test_error_handling(model):
+    """Test error handling."""
+    # Test with invalid image
+    with pytest.raises(Exception):
+        model.predict(Image.new("RGB", (0, 0)))
+
+    # Test with invalid confidence threshold
+    with pytest.raises(Exception):
+        model.predict(Image.new("RGB", (224, 224)), confidence_threshold=2.0)
+
+    # Test with invalid document type
+    with pytest.raises(Exception):
+        model.validate_extraction({}, document_type="invalid_type")
transformers
ADDED
@@ -0,0 +1 @@
+Subproject commit b3db4ddb2255bb4c8c4340fa630a53ac1cc53dee