chandini2595 committed
Commit 83dd2a8 · 0 Parent(s)

Initial commit without binary files
.gitattributes ADDED
@@ -0,0 +1,4 @@
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/ci-cd.yml ADDED
@@ -0,0 +1,117 @@
+ name: FormIQ CI/CD
+
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+     branches: [ main ]
+
+ jobs:
+   test:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         python-version: [3.8, 3.9]
+
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v2
+         with:
+           python-version: ${{ matrix.python-version }}
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install -r requirements.txt
+           pip install pytest pytest-cov
+
+       - name: Run tests
+         run: |
+           pytest tests/ --cov=src/ --cov-report=xml
+
+       - name: Upload coverage to Codecov
+         uses: codecov/codecov-action@v2
+         with:
+           file: ./coverage.xml
+           fail_ci_if_error: true
+
+   lint:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.9
+
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install flake8 black isort
+
+       - name: Run linters
+         run: |
+           flake8 src/ tests/
+           black --check src/ tests/
+           isort --check-only src/ tests/
+
+   build-and-push:
+     needs: [test, lint]
+     runs-on: ubuntu-latest
+     if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Set up Docker Buildx
+         uses: docker/setup-buildx-action@v1
+
+       - name: Login to DockerHub
+         uses: docker/login-action@v1
+         with:
+           username: ${{ secrets.DOCKERHUB_USERNAME }}
+           password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+       - name: Build and push
+         uses: docker/build-push-action@v2
+         with:
+           context: .
+           push: true
+           tags: |
+             ${{ secrets.DOCKERHUB_USERNAME }}/formiq:latest
+             ${{ secrets.DOCKERHUB_USERNAME }}/formiq:${{ github.sha }}
+           cache-from: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/formiq:buildcache
+           cache-to: type=registry,ref=${{ secrets.DOCKERHUB_USERNAME }}/formiq:buildcache,mode=max
+
+   deploy:
+     needs: build-and-push
+     runs-on: ubuntu-latest
+     if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+
+     steps:
+       - uses: actions/checkout@v2
+
+       - name: Configure AWS credentials
+         uses: aws-actions/configure-aws-credentials@v1
+         with:
+           aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+           aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+           aws-region: us-east-1
+
+       - name: Deploy to SageMaker
+         run: |
+           # Update SageMaker endpoint with new model
+           aws sagemaker update-endpoint \
+             --endpoint-name formiq-endpoint \
+             --endpoint-config-name formiq-config-${{ github.sha }}
+
+       - name: Deploy to ECS
+         run: |
+           # Update ECS service with new container
+           aws ecs update-service \
+             --cluster formiq-cluster \
+             --service formiq-service \
+             --force-new-deployment
.gitignore ADDED
@@ -0,0 +1,89 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ ENV/
+ env/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Project specific
+ temp_uploaded_image.jpg
+ .env
+ *.log
+ .DS_Store
+
+ # Model files
+ *.pt
+ *.pth
+ *.onnx
+ *.h5
+ *.model
+
+ # Data
+ *.csv
+ *.json
+ *.xlsx
+ *.db
+ *.sqlite3
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ *.ipynb
+
+ # Logs
+ logs/
+ *.log
+
+ # MLflow
+ mlruns/
+ mlflow.db
+
+ # DVC
+ .dvc/
+ .dvc/cache/
+
+ # Testing
+ .coverage
+ coverage.xml
+ htmlcov/
+ .pytest_cache/
+
+ # Docker
+ .docker/
+
+ # AWS
+ .aws/
+ *.pem
+
+ # Environment variables
+ .env.*
+
+ # Distribution
+ dist/
+ build/
+ *.egg-info/
Dockerfile ADDED
@@ -0,0 +1,46 @@
+ # Use NVIDIA CUDA base image for GPU support
+ FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     DEBIAN_FRONTEND=noninteractive \
+     PYTHON_VERSION=3.9
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python${PYTHON_VERSION} \
+     python3-pip \
+     python${PYTHON_VERSION}-dev \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Set working directory
+ WORKDIR /app
+
+ # Copy requirements first to leverage Docker cache
+ COPY requirements.txt .
+
+ # Install Python dependencies
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create necessary directories
+ RUN mkdir -p data/train data/val data/test logs
+
+ # Set environment variables for the application
+ ENV MODEL_SAVE_DIR=/app/models \
+     DATA_DIR=/app/data \
+     LOG_DIR=/app/logs
+
+ # Expose ports
+ EXPOSE 8000 8501
+
+ # Create a non-root user
+ RUN useradd -m -u 1000 appuser
+ RUN chown -R appuser:appuser /app
+ USER appuser
+
+ # Start the application
+ CMD ["sh", "-c", "uvicorn src.api.main:app --host 0.0.0.0 --port 8000 & streamlit run src/frontend/app.py --server.port 8501 --server.address 0.0.0.0"]
README.md ADDED
@@ -0,0 +1,54 @@
+ # FormIQ - Intelligent Document Parser
+
+ FormIQ is an intelligent document parser that uses advanced AI models to extract and validate information from various types of documents.
+
+ ## Features
+
+ - Document image upload and processing
+ - OCR text extraction using Tesseract
+ - Advanced document understanding using LayoutLMv3
+ - Structured information extraction using Perplexity AI
+ - Interactive web interface built with Streamlit
+
+ ## Technologies Used
+
+ - **Frontend**: Streamlit
+ - **OCR**: Tesseract
+ - **Document Understanding**: LayoutLMv3
+ - **Text Processing**: Perplexity AI
+ - **Data Processing**: Pandas, NumPy
+ - **Visualization**: Plotly
+
+ ## Setup
+
+ 1. Clone the repository
+ 2. Install dependencies:
+    ```bash
+    pip install -r requirements.txt
+    ```
+ 3. Set up environment variables:
+    ```bash
+    PERPLEXITY_API_KEY=your_api_key_here
+    ```
+
+ ## Usage
+
+ 1. Run the Streamlit app:
+    ```bash
+    streamlit run app.py
+    ```
+ 2. Open your browser and navigate to the provided URL
+ 3. Upload a document image
+ 4. Click "Process Document" to extract information
+
+ ## Hugging Face Spaces Deployment
+
+ This project is deployed on Hugging Face Spaces. You can access the live demo at: [Your Spaces URL]
+
+ ## Contributing
+
+ Contributions are welcome! Please feel free to submit a Pull Request.
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
app.py ADDED
@@ -0,0 +1,181 @@
+ import streamlit as st
+ import torch
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+ from PIL import Image
+ import io
+ import json
+ import pandas as pd
+ import plotly.express as px
+ import numpy as np
+ from typing import Dict, Any
+ import logging
+ import pytesseract
+ import re
+ from openai import OpenAI
+ import os
+ from dotenv import load_dotenv
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Load environment variables
+ load_dotenv()
+
+ # Initialize OpenAI client for Perplexity
+ client = OpenAI(
+     api_key=os.getenv('PERPLEXITY_API_KEY'),
+     base_url="https://api.perplexity.ai"
+ )
+
+ # Initialize LayoutLM model
+ @st.cache_resource
+ def load_model():
+     model_name = "microsoft/layoutlmv3-base"
+     processor = LayoutLMv3Processor.from_pretrained(model_name)
+     model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
+     return processor, model
+
+ def extract_json_from_llm_output(llm_result):
+     match = re.search(r'\{.*\}', llm_result, re.DOTALL)
+     if match:
+         return match.group(0)
+     return None
+
+ def extract_fields(image_path):
+     # OCR
+     text = pytesseract.image_to_string(Image.open(image_path))
+
+     # Display OCR output for debugging
+     st.subheader("Raw OCR Output")
+     st.code(text)
+
+     # Improved Regex patterns for fields
+     patterns = {
+         "name": r"Mrs\s+\w+\s+\w+",
+         "date": r"Date[:\s]+([\d/]+)",
+         "product": r"\d+\s+\w+.*Style\s+\d+",
+         "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
+         "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
+     }
+
+     results = {}
+     for field, pattern in patterns.items():
+         match = re.search(pattern, text, re.IGNORECASE)
+         if match:
+             results[field] = match.group(1) if match.groups() else match.group(0)
+         else:
+             results[field] = None
+
+     return results
+
+ def extract_with_perplexity_llm(ocr_text):
+     prompt = f"""
+     Extract the following fields from this receipt text:
+     - name
+     - date
+     - product
+     - amount_paid
+     - receipt_no
+
+     Text:
+     \"\"\"{ocr_text}\"\"\"
+
+     Return the result as a JSON object with those fields.
+     """
+     messages = [
+         {
+             "role": "system",
+             "content": "You are an AI assistant that extracts structured information from text."
+         },
+         {
+             "role": "user",
+             "content": prompt
+         }
+     ]
+
+     response = client.chat.completions.create(
+         model="sonar-pro",
+         messages=messages
+     )
+     return response.choices[0].message.content
+
+ def main():
+     st.set_page_config(
+         page_title="FormIQ - Intelligent Document Parser",
+         page_icon="📄",
+         layout="wide"
+     )
+
+     st.title("FormIQ: Intelligent Document Parser")
+     st.markdown("""
+     Upload your documents to extract and validate information using advanced AI models.
+     """)
+
+     # Sidebar
+     with st.sidebar:
+         st.header("Settings")
+         document_type = st.selectbox(
+             "Document Type",
+             options=["invoice", "receipt", "form"],
+             index=0
+         )
+
+         confidence_threshold = st.slider(
+             "Confidence Threshold",
+             min_value=0.0,
+             max_value=1.0,
+             value=0.5,
+             step=0.05
+         )
+
+         st.markdown("---")
+         st.markdown("### About")
+         st.markdown("""
+         FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
+         """)
+
+     # Main content
+     uploaded_file = st.file_uploader(
+         "Upload Document",
+         type=["png", "jpg", "jpeg", "pdf"],
+         help="Upload a document image to process"
+     )
+
+     if uploaded_file is not None:
+         # Display uploaded image
+         image = Image.open(uploaded_file)
+         st.image(image, caption="Uploaded Document", width=600)
+
+         # Process button
+         if st.button("Process Document"):
+             with st.spinner("Processing document..."):
+                 try:
+                     # Save the uploaded file to a temporary location
+                     temp_path = "temp_uploaded_image.jpg"
+                     image.save(temp_path)
+
+                     # Extract fields using OCR + regex
+                     fields = extract_fields(temp_path)
+
+                     # Extract with Perplexity LLM
+                     with st.spinner("Extracting structured data with Perplexity LLM..."):
+                         try:
+                             llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
+                             st.subheader("Structured Data (Perplexity LLM)")
+                             st.code(llm_result, language="json")
+
+                             # Display extracted fields
+                             st.subheader("Extracted Fields")
+                             fields_df = pd.DataFrame([fields])
+                             st.dataframe(fields_df)
+
+                         except Exception as e:
+                             st.error(f"LLM extraction failed: {e}")
+
+                 except Exception as e:
+                     logger.error(f"Error processing document: {str(e)}")
+                     st.error(f"Error processing document: {str(e)}")
+
+ if __name__ == "__main__":
+     main()
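
The regex patterns in `extract_fields` are tuned to one specific receipt layout. A small standalone sketch of how they behave, using a hypothetical OCR string (the sample text and values below are made up for illustration):

```python
import re

# Hypothetical OCR output, shaped like the receipts these patterns expect.
sample_text = """Mrs Jane Smith
Receipt No: 48213
Date: 12/05/2024
1 Jacket Style 882
Total Paid $129.99"""

patterns = {
    "name": r"Mrs\s+\w+\s+\w+",
    "date": r"Date[:\s]+([\d/]+)",
    "product": r"\d+\s+\w+.*Style\s+\d+",
    "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
    "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)",
}

for field, pattern in patterns.items():
    match = re.search(pattern, sample_text, re.IGNORECASE)
    if match:
        # Mirror extract_fields(): prefer the capture group when one exists
        value = match.group(1) if match.groups() else match.group(0)
    else:
        value = None
    print(f"{field}: {value}")
```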
chatbot_server.py ADDED
@@ -0,0 +1,68 @@
+ import os
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from openai import OpenAI
+ from dotenv import load_dotenv
+ import boto3
+
+ # Load environment variables from .env
+ load_dotenv()
+
+ # Initialize OpenAI client for Perplexity
+ client = OpenAI(
+     api_key=os.getenv('PERPLEXITY_API_KEY'),
+     base_url="https://api.perplexity.ai"
+ )
+
+ app = FastAPI()
+
+ class ChatRequest(BaseModel):
+     question: str
+
+ @app.post("/chat")
+ def chat_endpoint(chat_request: ChatRequest):
+     # Connect to DynamoDB
+     dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
+     table = dynamodb.Table('Receipts')
+
+     # Get question and search DynamoDB
+     question = chat_request.question
+     response = table.scan()
+     items = response.get('Items', [])
+
+     # Format items for context with all receipt details
+     context = "\n".join([
+         f"Receipt {item['receipt_no']}:\n"
+         f" Name: {item['name']}\n"
+         f" Date: {item['date']}\n"
+         f" Product: {item['product']}\n"
+         f" Amount Paid: {item['amount_paid']}\n"
+         for item in items
+     ])
+     question = f"Based on these receipts:\n{context}\n\nQuestion: {question}\nPlease provide a 2-3 line answer."
+
+     # Prepare messages for the chat
+     messages = [
+         {
+             "role": "system",
+             "content": (
+                 "You are an artificial intelligence assistant and you need to "
+                 "engage in a helpful, detailed, polite conversation with a user. "
+                 "Give a 2-3 line answer."
+             )
+         },
+         {
+             "role": "user",
+             "content": question
+         }
+     ]
+
+     try:
+         # Get response from Perplexity
+         response = client.chat.completions.create(
+             model="sonar",
+             messages=messages
+         )
+         return {"answer": response.choices[0].message.content}
+     except Exception as e:
+         return {"error": f"Error from LLM: {str(e)}"}
insert_dummy_data.py ADDED
@@ -0,0 +1,51 @@
+ import boto3
+
+ # Initialize DynamoDB resource (ensure AWS credentials and region are set)
+ dynamodb = boto3.resource('dynamodb', region_name='us-east-1')  # Change region if needed
+ table = dynamodb.Table('Receipts')  # Replace with your table name
+
+ # List of dummy items to insert with meaningful receipt numbers
+ dummy_items = [
+     {
+         'receipt_no': 'RCPT-2024-0001',
+         'amount_paid': '100.00',
+         'date': '2024-01-01',
+         'name': 'John Doe',
+         'product': 'Widget A'
+     },
+     {
+         'receipt_no': 'RCPT-2024-0002',
+         'amount_paid': '250.50',
+         'date': '2024-02-15',
+         'name': 'Jane Smith',
+         'product': 'Gadget B'
+     },
+     {
+         'receipt_no': 'RCPT-2024-0003',
+         'amount_paid': '75.25',
+         'date': '2024-03-10',
+         'name': 'Alice Johnson',
+         'product': 'Thingamajig C'
+     },
+     {
+         'receipt_no': 'RCPT-2024-0004',
+         'amount_paid': '180.00',
+         'date': '2024-04-05',
+         'name': 'Bob Lee',
+         'product': 'Gizmo D'
+     },
+     {
+         'receipt_no': 'RCPT-2024-0005',
+         'amount_paid': '320.75',
+         'date': '2024-05-20',
+         'name': 'Carol King',
+         'product': 'Device E'
+     }
+ ]
+
+ # Insert each item
+ for item in dummy_items:
+     table.put_item(Item=item)
+     print(f"Inserted: {item['receipt_no']}")
+
+ print("Dummy data inserted successfully.")
requirements.txt ADDED
@@ -0,0 +1,33 @@
+ # Core ML dependencies
+ torch==2.2.0
+ transformers==4.37.2
+ datasets>=2.12.0
+ pytesseract==0.3.10
+ Pillow==10.2.0
+
+ # API and Web Framework
+ fastapi==0.109.2
+ uvicorn==0.27.1
+ streamlit==1.32.0
+ python-multipart==0.0.9
+
+ # MLOps and Monitoring
+ wandb>=0.15.0
+ mlflow>=2.4.0
+ dvc>=3.0.0
+ hydra-core>=1.3.2
+ evidently>=0.2.0
+ tensorboard>=2.12.0
+
+ # Cloud and Deployment
+ boto3==1.34.34
+ sagemaker>=2.160.0
+
+ # Utilities
+ numpy==1.26.3
+ pandas==2.2.0
+ python-dotenv==1.0.1
+ pydantic>=2.0.0
+ openai
+ streamlit
+ plotly==5.18.0
src/api/main.py ADDED
@@ -0,0 +1,104 @@
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from PIL import Image
+ import io
+ import logging
+ from typing import Dict, Any
+ import json
+
+ from src.models.layoutlm import FormIQModel
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(
+     title="FormIQ API",
+     description="Intelligent Document Parser API",
+     version="1.0.0"
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Initialize model
+ model = FormIQModel()
+
+ @app.get("/")
+ async def root():
+     """Health check endpoint."""
+     return {"status": "healthy", "service": "FormIQ API"}
+
+ @app.post("/extract")
+ async def extract_information(
+     file: UploadFile = File(...),
+     confidence_threshold: float = 0.5,
+     document_type: str = "invoice"
+ ) -> Dict[str, Any]:
+     """Extract information from uploaded document.
+
+     Args:
+         file: Uploaded document image
+         confidence_threshold: Minimum confidence score for predictions
+         document_type: Type of document being processed
+
+     Returns:
+         Dictionary containing extracted fields and metadata
+     """
+     try:
+         # Read and validate image
+         contents = await file.read()
+         image = Image.open(io.BytesIO(contents))
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         # Process image
+         extraction_results = model.predict(
+             image=image,
+             confidence_threshold=confidence_threshold
+         )
+
+         # Validate extraction
+         validation_results = model.validate_extraction(
+             extracted_fields=extraction_results,
+             document_type=document_type
+         )
+
+         # Combine results
+         response = {
+             "extraction": extraction_results,
+             "validation": validation_results,
+             "metadata": {
+                 "document_type": document_type,
+                 "confidence_threshold": confidence_threshold
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error processing document: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error processing document: {str(e)}"
+         )
+
+ @app.get("/model-info")
+ async def get_model_info() -> Dict[str, Any]:
+     """Get information about the current model."""
+     return {
+         "model_name": model.model.config.model_type,
+         "device": model.device,
+         "version": "1.0.0"
+     }
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
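
A hedged example of calling the `/extract` endpoint from Python with `requests`, assuming the API is running locally on port 8000 and that a test image exists at the (hypothetical) path used below:

```python
import requests

API_URL = "http://localhost:8000"  # assumes: uvicorn src.api.main:app --port 8000

# "sample_receipt.jpg" is a hypothetical test image path.
with open("sample_receipt.jpg", "rb") as f:
    response = requests.post(
        f"{API_URL}/extract",
        files={"file": ("sample_receipt.jpg", f, "image/jpeg")},
        params={"confidence_threshold": 0.5, "document_type": "receipt"},
        timeout=120,
    )

response.raise_for_status()
payload = response.json()
print(payload["metadata"])        # document_type and confidence_threshold echoed back
print(len(payload["extraction"]["fields"]), "fields above threshold")
print(payload["validation"])      # is_valid, validation_errors, confidence_score
```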
src/config/config.yaml ADDED
@@ -0,0 +1,66 @@
+ # Model Configuration
+ model:
+   name: "microsoft/layoutlmv3-base"
+   device: "cuda"  # or "cpu"
+   confidence_threshold: 0.5
+   max_length: 512
+
+ # Training Configuration
+ training:
+   batch_size: 8
+   learning_rate: 2e-5
+   num_epochs: 10
+   warmup_steps: 100
+   weight_decay: 0.01
+   gradient_accumulation_steps: 4
+
+ # Dataset Configuration
+ dataset:
+   train_path: "data/train"
+   val_path: "data/val"
+   test_path: "data/test"
+   max_samples: null  # Set to null for all samples
+   augmentation:
+     enabled: true
+     rotation_range: 10
+     width_shift_range: 0.1
+     height_shift_range: 0.1
+     zoom_range: 0.1
+     fill_mode: "nearest"
+
+ # Logging Configuration
+ logging:
+   level: "INFO"
+   wandb:
+     enabled: true
+     project: "formiq"
+     entity: null  # Set your W&B username
+   tensorboard:
+     enabled: true
+     log_dir: "logs"
+
+ # API Configuration
+ api:
+   host: "0.0.0.0"
+   port: 8000
+   workers: 4
+   timeout: 60
+
+ # Frontend Configuration
+ frontend:
+   host: "localhost"
+   port: 8501
+   debug: false
+
+ # MLOps Configuration
+ mlops:
+   dvc:
+     remote: "s3://formiq-data"
+     cache_dir: ".dvc/cache"
+   mlflow:
+     tracking_uri: "http://localhost:5000"
+     experiment_name: "formiq"
+   evidently:
+     enabled: true
+     drift_threshold: 0.1
+     window_size: 1000
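
For a quick sanity check outside of Hydra, the same file can be read directly with OmegaConf (a sketch; the path is relative to the repo root, and `src/scripts/train.py` normally loads it via `@hydra.main` instead):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("src/config/config.yaml")

print(cfg.model.name)              # microsoft/layoutlmv3-base
print(cfg.training.batch_size)     # 8
print(cfg.logging.wandb.project)   # formiq
print(list(OmegaConf.to_container(cfg, resolve=True)))  # top-level sections
```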
src/frontend/app.py ADDED
@@ -0,0 +1,299 @@
+ import streamlit as st
+ import requests
+ from PIL import Image
+ import io
+ import json
+ import pandas as pd
+ import plotly.express as px
+ import numpy as np
+ from typing import Dict, Any
+ import logging
+ import pytesseract
+ import re
+ from openai import OpenAI, OpenAIError
+ import boto3
+ from botocore.exceptions import ClientError
+ import os
+ from dotenv import load_dotenv
+ load_dotenv()
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Constants
+ API_URL = "http://localhost:8000"
+ SUPPORTED_DOCUMENT_TYPES = ["invoice", "receipt", "form"]
+
+ api_key = os.getenv("PERPLEXITY_API_KEY")
+ client = OpenAI(api_key=api_key, base_url="https://api.perplexity.ai")
+
+ REGION = "us-east-1"
+ dynamodb = boto3.resource('dynamodb', region_name=REGION)
+
+ def extract_json_from_llm_output(llm_result):
+     match = re.search(r'\{.*\}', llm_result, re.DOTALL)
+     if match:
+         return match.group(0)
+     return None
+
+ def save_to_dynamodb(data, table_name="Receipts"):
+     dynamodb = boto3.resource("dynamodb")
+     table = dynamodb.Table(table_name)
+     try:
+         table.put_item(Item=data)
+         return True
+     except ClientError as e:
+         st.error(f"Failed to save to DynamoDB: {e}")
+         return False
+
+ def main():
+     st.set_page_config(
+         page_title="FormIQ - Intelligent Document Parser",
+         page_icon="📄",
+         layout="wide"
+     )
+
+     st.title("FormIQ: Intelligent Document Parser")
+     st.markdown("""
+     Upload your documents to extract and validate information using advanced AI models.
+     """)
+
+     # Sidebar
+     with st.sidebar:
+         st.header("Settings")
+         document_type = st.selectbox(
+             "Document Type",
+             options=SUPPORTED_DOCUMENT_TYPES,
+             index=0
+         )
+
+         confidence_threshold = st.slider(
+             "Confidence Threshold",
+             min_value=0.0,
+             max_value=1.0,
+             value=0.5,
+             step=0.05
+         )
+
+         st.markdown("---")
+         st.markdown("### About")
+         st.markdown("""
+         FormIQ uses LayoutLMv3 and Perplexity AI to extract and validate information from documents.
+         """)
+
+     # Main content
+     uploaded_file = st.file_uploader(
+         "Upload Document",
+         type=["png", "jpg", "jpeg", "pdf"],
+         help="Upload a document image to process"
+     )
+
+     if uploaded_file is not None:
+         # Display uploaded image
+         image = Image.open(uploaded_file)
+         st.image(image, caption="Uploaded Document", width=600)
+
+         # Process button
+         if st.button("Process Document"):
+             with st.spinner("Processing document..."):
+                 try:
+                     # Save the uploaded file to a temporary location
+                     temp_path = "temp_uploaded_image.jpg"
+                     image.save(temp_path)
+
+                     # Extract fields using OCR + regex
+                     fields = extract_fields(temp_path)
+
+                     # Extract with Perplexity LLM using the provided API key
+                     with st.spinner("Extracting structured data with Perplexity LLM..."):
+                         try:
+                             llm_result = extract_with_perplexity_llm(pytesseract.image_to_string(Image.open(temp_path)))
+                             st.subheader("Structured Data (Perplexity LLM)")
+                             st.code(llm_result, language="json")
+
+                             # Extract and save JSON to DynamoDB
+                             raw_json = extract_json_from_llm_output(llm_result)
+                             if raw_json:
+                                 try:
+                                     llm_data = json.loads(raw_json)
+                                     if save_to_dynamodb(llm_data):
+                                         st.success("Data saved to DynamoDB!")
+                                 except Exception as e:
+                                     st.error(f"Failed to parse/save JSON: {e}")
+                             else:
+                                 st.error("No valid JSON found in LLM output.")
+                         except Exception as e:
+                             st.error(f"LLM extraction failed: {e}")
+
+                 except Exception as e:
+                     logger.error(f"Error processing document: {str(e)}")
+                     st.error(f"Error processing document: {str(e)}")
+
+ def display_results(results: Dict[str, Any]):
+     """Display extraction and validation results."""
+
+     # Create tabs for different views
+     tab1, tab2, tab3 = st.tabs(["Extracted Fields", "Validation", "Visualization"])
+
+     with tab1:
+         st.subheader("Extracted Fields")
+         if "fields" in results["extraction"]:
+             fields_df = pd.DataFrame(results["extraction"]["fields"])
+             st.dataframe(fields_df)
+         else:
+             st.info("No fields extracted")
+
+     with tab2:
+         st.subheader("Validation Results")
+         validation = results["validation"]
+
+         # Display validation status
+         status_color = "green" if validation["is_valid"] else "red"
+         st.markdown(f"### Status: :{status_color}[{validation['is_valid']}]")
+
+         # Display validation errors if any
+         if validation["validation_errors"]:
+             st.error("Validation Errors:")
+             for error in validation["validation_errors"]:
+                 st.markdown(f"- {error}")
+
+         # Display confidence score
+         st.metric(
+             "Overall Confidence",
+             f"{validation['confidence_score']:.2%}"
+         )
+
+     with tab3:
+         st.subheader("Confidence Visualization")
+         if "confidence_scores" in results["extraction"]["metadata"]:
+             scores = results["extraction"]["metadata"]["confidence_scores"]
+
+             # Create confidence distribution plot
+             fig = px.histogram(
+                 x=scores,
+                 nbins=20,
+                 title="Confidence Score Distribution",
+                 labels={"x": "Confidence Score", "y": "Count"}
+             )
+             st.plotly_chart(fig)
+
+             # Display heatmap if available
+             if "bbox" in results["extraction"]["fields"][0]:
+                 st.subheader("Field Location Heatmap")
+                 # TODO: Implement heatmap visualization
+                 st.info("Heatmap visualization coming soon!")
+
+ def group_tokens_by_label(tokens, labels):
+     structured = {}
+     current_label = None
+     current_tokens = []
+     for token, label in zip(tokens, labels):
+         if label != current_label:
+             if current_label is not None:
+                 structured.setdefault(current_label, []).append(' '.join(current_tokens))
+             current_label = label
+             current_tokens = [token]
+         else:
+             current_tokens.append(token)
+     if current_label is not None:
+         structured.setdefault(current_label, []).append(' '.join(current_tokens))
+     return structured
+
+ def extract_fields(image_path):
+     # OCR
+     text = pytesseract.image_to_string(Image.open(image_path))
+
+     # Display OCR output for debugging
+     st.subheader("Raw OCR Output (for debugging)")
+     st.code(text)
+
+     # Improved Regex patterns for fields
+     patterns = {
+         "name": r"Mrs\s+\w+\s+\w+",
+         "date": r"Date[:\s]+([\d/]+)",
+         "product": r"\d+\s+\w+.*Style\s+\d+",
+         "amount_paid": r"Total Paid\s+\$?([\d.,]+)",
+         # Improved pattern for receipt number (handles optional dot, colon, spaces)
+         "receipt_no": r"Receipt No\.?\s*:?\s*(\d+)"
+     }
+
+     results = {}
+     for field, pattern in patterns.items():
+         match = re.search(pattern, text, re.IGNORECASE)
+         if match:
+             results[field] = match.group(1) if match.groups() else match.group(0)
+         else:
+             results[field] = None
+
+     return results
+
+ def extract_with_perplexity_llm(ocr_text):
+     prompt = f"""
+     Extract the following fields from this receipt text:
+     - name
+     - date
+     - product
+     - amount_paid
+     - receipt_no
+
+     Text:
+     \"\"\"{ocr_text}\"\"\"
+
+     Return the result as a JSON object with those fields.
+     """
+     messages = [
+         {
+             "role": "system",
+             "content": (
+                 "You are an artificial intelligence assistant. "
+                 "Answer user questions as concisely and directly as possible. "
+                 "Limit your responses to 2-3 sentences unless the user asks for more detail."
+             )
+         },
+         {
+             "role": "user",
+             "content": prompt
+         }
+     ]
+     response = client.chat.completions.create(
+         model="sonar-pro",  # Use a valid model name for your account
+         messages=messages,
+     )
+     return response.choices[0].message.content
+
+ def interactive_chatbot_ui():
+     st.header("🤖 Chatbot")
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = []
+
+     # Display chat history as chat bubbles
+     for sender, msg in st.session_state.chat_history:
+         if sender == "You":
+             st.markdown(f"<div style='text-align: right; background: #262730; color: #fff; padding: 8px 12px; border-radius: 12px; margin: 4px 0 4px 40px;'><b>You:</b> {msg}</div>", unsafe_allow_html=True)
+         else:
+             st.markdown(f"<div style='text-align: left; background: #31333F; color: #fff; padding: 8px 12px; border-radius: 12px; margin: 4px 40px 4px 0;'><b>Bot:</b> {msg}</div>", unsafe_allow_html=True)
+
+     # Input at the bottom
+     with st.form(key="chat_form", clear_on_submit=True):
+         user_input = st.text_input("Type your message...", key="chat_input_main", placeholder="Ask me anything...")
+         submitted = st.form_submit_button("Send")
+         if submitted and user_input:
+             st.session_state.chat_history.append(("You", user_input))
+             try:
+                 response = requests.post(
+                     f"{API_URL}/chat",
+                     json={"question": user_input}
+                 )
+                 if response.status_code == 200:
+                     bot_reply = response.json()["answer"]
+                 else:
+                     bot_reply = f"Error: Server returned status code {response.status_code}"
+             except Exception as e:
+                 bot_reply = f"Error: {e}"
+             st.session_state.chat_history.append(("Bot", bot_reply))
+
+ if __name__ == "__main__":
+
+     main()
+     st.markdown("---")
+     interactive_chatbot_ui()
src/frontend/temp_uploaded_image.jpg ADDED
src/models/layoutlm.py ADDED
@@ -0,0 +1,144 @@
+ import torch
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+ from PIL import Image
+ import numpy as np
+ from typing import Dict, List, Tuple, Optional
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class FormIQModel:
+     def __init__(
+         self,
+         model_name: str = "microsoft/layoutlmv3-base",
+         device: str = "cuda" if torch.cuda.is_available() else "cpu"
+     ):
+         """Initialize the FormIQ model with LayoutLMv3.
+
+         Args:
+             model_name: Name of the pre-trained model to use
+             device: Device to run the model on ('cuda' or 'cpu')
+         """
+         self.device = device
+         self.processor = LayoutLMv3Processor.from_pretrained(model_name)
+         self.model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
+         self.model.to(device)
+         logger.info(f"Model initialized on {device}")
+
+     def preprocess_image(self, image: Image.Image) -> Dict[str, torch.Tensor]:
+         """Preprocess the input image for the model.
+
+         Args:
+             image: PIL Image to process
+
+         Returns:
+             Dictionary of processed inputs
+         """
+         try:
+             # Process image and text
+             encoding = self.processor(
+                 image,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=512
+             )
+
+             # Move tensors to device
+             encoding = {k: v.to(self.device) for k, v in encoding.items()}
+             return encoding
+
+         except Exception as e:
+             logger.error(f"Error preprocessing image: {str(e)}")
+             raise
+
+     def predict(
+         self,
+         image: Image.Image,
+         confidence_threshold: float = 0.5
+     ) -> Dict[str, List[Dict[str, any]]]:
+         """Extract information from the document image.
+
+         Args:
+             image: PIL Image of the document
+             confidence_threshold: Minimum confidence score for predictions
+
+         Returns:
+             Dictionary containing extracted fields and their metadata
+         """
+         try:
+             # Preprocess image
+             inputs = self.preprocess_image(image)
+
+             # Get model predictions
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 predictions = outputs.logits.argmax(-1).squeeze().cpu().numpy()
+                 scores = torch.softmax(outputs.logits, dim=-1).max(-1)[0].squeeze().cpu().numpy()
+
+             # Process predictions
+             extracted_fields = self._process_predictions(predictions, scores, confidence_threshold)
+
+             return {
+                 "fields": extracted_fields,
+                 "metadata": {
+                     "confidence_scores": scores.tolist(),
+                     "model_version": self.model.config.model_type
+                 }
+             }
+
+         except Exception as e:
+             logger.error(f"Error during prediction: {str(e)}")
+             raise
+
+     def _process_predictions(
+         self,
+         predictions: np.ndarray,
+         scores: np.ndarray,
+         confidence_threshold: float
+     ) -> List[Dict[str, any]]:
+         """Process raw model predictions into structured output.
+
+         Args:
+             predictions: Array of predicted class indices
+             scores: Array of confidence scores
+             confidence_threshold: Minimum confidence score
+
+         Returns:
+             List of dictionaries containing field information
+         """
+         # TODO: Implement field-specific post-processing
+         # This is a placeholder implementation
+         processed_fields = []
+
+         for pred, score in zip(predictions, scores):
+             if score >= confidence_threshold:
+                 field_info = {
+                     "label": self.model.config.id2label[pred],
+                     "confidence": float(score),
+                     "bbox": None  # TODO: Add bounding box information
+                 }
+                 processed_fields.append(field_info)
+
+         return processed_fields
+
+     def validate_extraction(
+         self,
+         extracted_fields: Dict[str, List[Dict[str, any]]],
+         document_type: str
+     ) -> Dict[str, any]:
+         """Validate extracted fields based on document type rules.
+
+         Args:
+             extracted_fields: Dictionary of extracted fields
+             document_type: Type of document (e.g., 'invoice', 'receipt')
+
+         Returns:
+             Dictionary containing validation results
+         """
+         # TODO: Implement field validation logic
+         # This is a placeholder implementation
+         return {
+             "is_valid": True,
+             "validation_errors": [],
+             "confidence_score": 1.0
+         }
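
A short usage sketch of `FormIQModel` outside the API, assuming the repo root is on `PYTHONPATH` and that a sample document image exists at the hypothetical path below:

```python
from PIL import Image

from src.models.layoutlm import FormIQModel

# "data/test/sample_receipt.jpg" is a hypothetical path; any RGB document image works.
model = FormIQModel(device="cpu")  # CPU is enough for a smoke test
image = Image.open("data/test/sample_receipt.jpg").convert("RGB")

results = model.predict(image, confidence_threshold=0.5)
print(results["metadata"]["model_version"])
for field in results["fields"][:5]:
    print(field["label"], round(field["confidence"], 3))

# Validation is currently a placeholder and always reports is_valid=True.
validation = model.validate_extraction(results, document_type="receipt")
print(validation)
```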
src/scripts/train.py ADDED
@@ -0,0 +1,202 @@
+ import hydra
+ from omegaconf import DictConfig, OmegaConf
+ import torch
+ from torch.utils.data import DataLoader
+ from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
+ from datasets import load_dataset
+ import mlflow
+ import wandb
+ from pathlib import Path
+ import logging
+ from typing import Dict, Any
+ import numpy as np
+ from tqdm import tqdm
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class FormIQTrainer:
+     def __init__(self, config: DictConfig):
+         """Initialize the trainer with configuration."""
+         self.config = config
+         self.device = torch.device(config.model.device)
+
+         # Initialize model and processor
+         self.processor = LayoutLMv3Processor.from_pretrained(config.model.name)
+         self.model = LayoutLMv3ForTokenClassification.from_pretrained(
+             config.model.name,
+             num_labels=config.model.num_labels
+         )
+         self.model.to(self.device)
+
+         # Initialize optimizer
+         self.optimizer = torch.optim.AdamW(
+             self.model.parameters(),
+             lr=config.training.learning_rate,
+             weight_decay=config.training.weight_decay
+         )
+
+         # Setup logging
+         self.setup_logging()
+
+     def setup_logging(self):
+         """Setup MLflow and W&B logging."""
+         if self.config.logging.mlflow.enabled:
+             mlflow.set_tracking_uri(self.config.logging.mlflow.tracking_uri)
+             mlflow.set_experiment(self.config.logging.mlflow.experiment_name)
+
+         if self.config.logging.wandb.enabled:
+             wandb.init(
+                 project=self.config.logging.wandb.project,
+                 entity=self.config.logging.wandb.entity,
+                 config=OmegaConf.to_container(self.config, resolve=True)
+             )
+
+     def prepare_dataset(self):
+         """Prepare the dataset for training."""
+         # TODO: Implement dataset preparation
+         # This is a placeholder implementation
+         return None, None
+
+     def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
+         """Train for one epoch.
+
+         Args:
+             train_loader: DataLoader for training data
+
+         Returns:
+             Dictionary containing training metrics
+         """
+         self.model.train()
+         total_loss = 0
+         correct_predictions = 0
+         total_predictions = 0
+
+         progress_bar = tqdm(train_loader, desc="Training")
+         for batch in progress_bar:
+             # Move batch to device
+             batch = {k: v.to(self.device) for k, v in batch.items()}
+
+             # Forward pass
+             outputs = self.model(**batch)
+             loss = outputs.loss
+
+             # Backward pass
+             loss.backward()
+
+             # Update weights
+             self.optimizer.step()
+             self.optimizer.zero_grad()
+
+             # Update metrics
+             total_loss += loss.item()
+             predictions = outputs.logits.argmax(-1)
+             correct_predictions += (predictions == batch["labels"]).sum().item()
+             total_predictions += batch["labels"].numel()
+
+             # Update progress bar
+             progress_bar.set_postfix({
+                 "loss": loss.item(),
+                 "accuracy": correct_predictions / total_predictions
+             })
+
+         # Calculate epoch metrics
+         metrics = {
+             "train_loss": total_loss / len(train_loader),
+             "train_accuracy": correct_predictions / total_predictions
+         }
+
+         return metrics
+
+     def evaluate(self, eval_loader: DataLoader) -> Dict[str, float]:
+         """Evaluate the model.
+
+         Args:
+             eval_loader: DataLoader for evaluation data
+
+         Returns:
+             Dictionary containing evaluation metrics
+         """
+         self.model.eval()
+         total_loss = 0
+         correct_predictions = 0
+         total_predictions = 0
+
+         with torch.no_grad():
+             for batch in tqdm(eval_loader, desc="Evaluating"):
+                 # Move batch to device
+                 batch = {k: v.to(self.device) for k, v in batch.items()}
+
+                 # Forward pass
+                 outputs = self.model(**batch)
+                 loss = outputs.loss
+
+                 # Update metrics
+                 total_loss += loss.item()
+                 predictions = outputs.logits.argmax(-1)
+                 correct_predictions += (predictions == batch["labels"]).sum().item()
+                 total_predictions += batch["labels"].numel()
+
+         # Calculate evaluation metrics
+         metrics = {
+             "eval_loss": total_loss / len(eval_loader),
+             "eval_accuracy": correct_predictions / total_predictions
+         }
+
+         return metrics
+
+     def train(self):
+         """Train the model."""
+         # Prepare datasets
+         train_loader, eval_loader = self.prepare_dataset()
+
+         # Training loop
+         best_eval_loss = float('inf')
+         for epoch in range(self.config.training.num_epochs):
+             logger.info(f"Epoch {epoch + 1}/{self.config.training.num_epochs}")
+
+             # Train
+             train_metrics = self.train_epoch(train_loader)
+
+             # Evaluate
+             eval_metrics = self.evaluate(eval_loader)
+
+             # Log metrics
+             metrics = {**train_metrics, **eval_metrics}
+             if self.config.logging.mlflow.enabled:
+                 mlflow.log_metrics(metrics, step=epoch)
+             if self.config.logging.wandb.enabled:
+                 wandb.log(metrics, step=epoch)
+
+             # Save best model
+             if eval_metrics["eval_loss"] < best_eval_loss:
+                 best_eval_loss = eval_metrics["eval_loss"]
+                 self.save_model("best_model")
+
+             # Save checkpoint
+             self.save_model(f"checkpoint_epoch_{epoch + 1}")
+
+     def save_model(self, name: str):
+         """Save the model.
+
+         Args:
+             name: Name of the saved model
+         """
+         save_path = Path(self.config.model.save_dir) / name
+         save_path.mkdir(parents=True, exist_ok=True)
+
+         self.model.save_pretrained(save_path)
+         self.processor.save_pretrained(save_path)
+
+         if self.config.logging.mlflow.enabled:
+             mlflow.log_artifacts(str(save_path), f"models/{name}")
+
+ @hydra.main(config_path="../config", config_name="config")
+ def main(config: DictConfig):
+     """Main training function."""
+     trainer = FormIQTrainer(config)
+     trainer.train()
+
+ if __name__ == "__main__":
+     main()
tests/test_model.py ADDED
@@ -0,0 +1,93 @@
+ import pytest
+ import torch
+ from PIL import Image
+ import numpy as np
+ from src.models.layoutlm import FormIQModel
+
+ @pytest.fixture
+ def model():
+     """Create a model instance for testing."""
+     return FormIQModel(device="cpu")
+
+ @pytest.fixture
+ def sample_image():
+     """Create a sample image for testing."""
+     # Create a random image
+     image_array = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
+     return Image.fromarray(image_array)
+
+ def test_model_initialization(model):
+     """Test model initialization."""
+     assert model.device == "cpu"
+     assert model.model is not None
+     assert model.processor is not None
+
+ def test_preprocess_image(model, sample_image):
+     """Test image preprocessing."""
+     processed = model.preprocess_image(sample_image)
+
+     # Check if all required keys are present
+     assert "input_ids" in processed
+     assert "attention_mask" in processed
+     assert "bbox" in processed
+     assert "pixel_values" in processed
+
+     # Check tensor types and shapes
+     assert isinstance(processed["input_ids"], torch.Tensor)
+     assert isinstance(processed["attention_mask"], torch.Tensor)
+     assert isinstance(processed["bbox"], torch.Tensor)
+     assert isinstance(processed["pixel_values"], torch.Tensor)
+
+ def test_predict(model, sample_image):
+     """Test prediction functionality."""
+     results = model.predict(sample_image, confidence_threshold=0.5)
+
+     # Check result structure
+     assert "fields" in results
+     assert "metadata" in results
+     assert isinstance(results["fields"], list)
+     assert isinstance(results["metadata"], dict)
+
+     # Check metadata
+     assert "confidence_scores" in results["metadata"]
+     assert "model_version" in results["metadata"]
+
+ def test_validate_extraction(model):
+     """Test field validation."""
+     # Create sample extraction results
+     sample_extraction = {
+         "fields": [
+             {"label": "amount", "confidence": 0.95, "value": "100.00"},
+             {"label": "date", "confidence": 0.85, "value": "2024-03-20"}
+         ]
+     }
+
+     # Test validation
+     validation_results = model.validate_extraction(
+         sample_extraction,
+         document_type="invoice"
+     )
+
+     # Check validation results structure
+     assert "is_valid" in validation_results
+     assert "validation_errors" in validation_results
+     assert "confidence_score" in validation_results
+
+     # Check types
+     assert isinstance(validation_results["is_valid"], bool)
+     assert isinstance(validation_results["validation_errors"], list)
+     assert isinstance(validation_results["confidence_score"], float)
+
+ def test_error_handling(model):
+     """Test error handling."""
+     # Test with invalid image
+     with pytest.raises(Exception):
+         model.predict(Image.new("RGB", (0, 0)))
+
+     # Test with invalid confidence threshold
+     with pytest.raises(Exception):
+         model.predict(Image.new("RGB", (224, 224)), confidence_threshold=2.0)
+
+     # Test with invalid document type
+     with pytest.raises(Exception):
+         model.validate_extraction({}, document_type="invalid_type")
transformers ADDED
@@ -0,0 +1 @@
+ Subproject commit b3db4ddb2255bb4c8c4340fa630a53ac1cc53dee