sachin committed on
Commit
75bbaa5
·
1 Parent(s): badf26d
Files changed (1) hide show
  1. src/server/main.py +97 -1
src/server/main.py CHANGED
@@ -9,7 +9,7 @@ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, Uploa
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from PIL import Image
12
- from pydantic import BaseModel, field_validator
13
  from pydantic_settings import BaseSettings
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
@@ -26,6 +26,10 @@ from starlette.responses import StreamingResponse
26
  from logging_config import logger
27
  from tts_config import SPEED, ResponseFormat, config as tts_config
28
  import torchaudio
 
 
 
 
29
 
30
  # Device setup
31
  if torch.cuda.is_available():
@@ -296,6 +300,14 @@ class SynthesizeRequest(BaseModel):
296
  class KannadaSynthesizeRequest(BaseModel):
297
  text: str
298
 
 
 
 
 
 
 
 
 
299
  # TTS Functions
300
  def load_audio_from_url(url: str):
301
  response = requests.get(url)
@@ -762,6 +774,90 @@ async def visual_query(
762
  logger.error(f"Error processing request: {str(e)}")
763
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  @app.post("/v1/chat_v2", response_model=ChatResponse)
766
  @limiter.limit(settings.chat_rate_limit)
767
  async def chat_v2(
 
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from PIL import Image
12
+ from pydantic import BaseModel, field_validator, Field
13
  from pydantic_settings import BaseSettings
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
 
26
  from logging_config import logger
27
  from tts_config import SPEED, ResponseFormat, config as tts_config
28
  import torchaudio
29
+ import base64
30
+ from io import BytesIO
31
+ from pypdf import PdfReader
32
+ from olmocr.data.renderpdf import render_pdf_to_base64png
33
 
34
  # Device setup
35
  if torch.cuda.is_available():
 
300
  class KannadaSynthesizeRequest(BaseModel):
301
  text: str
302
 
303
class ExtractTextRequest(BaseModel):
    """Schema for the page-selection option of the PDF text-extraction endpoint.

    Attributes:
        page_number: 1-based page index; validated to be >= 1 and defaults to
            the first page.
    """

    page_number: int = Field(
        default=1,
        description="The page number to extract text from (1-based indexing). Must be a positive integer.",
        ge=1,
        # Pydantic v2 deprecates arbitrary extra keywords on Field (such as
        # ``example=``); ``examples`` is the supported way to publish sample
        # values in the generated JSON schema.
        examples=[1],
    )
310
+
311
  # TTS Functions
312
  def load_audio_from_url(url: str):
313
  response = requests.get(url)
 
774
  logger.error(f"Error processing request: {str(e)}")
775
  raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
776
 
777
@app.post(
    "/v1/extract-text-visual-query/",
    response_model=dict,
    summary="Extract text from a PDF page using visual query",
    description=(
        "Extracts text from a specific page of a PDF file by rendering it as an image and processing it with the internal vision query model. "
        "The query 'describe the image' is used to generate a description of the page content."
    ),
    response_description="A JSON object containing the extracted text from the specified page."
)
async def extract_text_visual_query(
    file: UploadFile = File(..., description="The PDF file to process. Must be a valid PDF."),
    # NOTE(review): `Body` must be importable from fastapi at module level; the
    # import line visible in this change is truncated -- confirm it is imported.
    page_number: int = Body(
        default=1,
        embed=True,
        description=ExtractTextRequest.model_fields["page_number"].description,
        ge=1,
        example=1
    )
):
    """
    Extract text from a specific page of a PDF file using the internal vision query model.

    The upload is written to a temporary ``.pdf`` file, the requested page is
    rendered to a base64 PNG (longest side 1024), decoded into a PIL image and
    passed to ``llm_manager.vision_query`` with the prompt "describe the image".
    The temporary file is always removed before the handler returns or raises.

    Args:
        file (UploadFile): The PDF file to process.
        page_number (int): The page number to extract text from (1-based indexing). Defaults to 1.

    Returns:
        JSONResponse: A dictionary containing:
            - page_content: The extracted text from the specified page via the vision query model.

    Raises:
        HTTPException: 400 if the upload's filename does not end in ``.pdf``
            (or is missing); 500 if rendering, image decoding, the vision
            query, or any other processing step fails.

    Example:
        ```json
        {"page_content": "Here’s a summary of the page in one sentence:\\n\\nThe page displays..."}
        ```
    """
    # Validate file type. The filename can be absent on an upload, so fall back
    # to an empty string instead of failing with AttributeError (a 500).
    filename = file.filename or ""
    if not filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    temp_file_path = None
    try:
        # Save the uploaded PDF to a temporary file; the renderer takes a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            # Record the path before writing so the file is cleaned up even if
            # reading or writing the upload fails.
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # Render the specified page to an image
        try:
            image_base64 = render_pdf_to_base64png(
                temp_file_path, page_number, target_longest_image_dim=1024
            )
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to render PDF page: {str(e)}") from e

        # Decode base64 image to PIL Image
        try:
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}") from e

        # Process image with vision query
        try:
            page_content = await llm_manager.vision_query(image, "describe the image")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Vision query processing failed: {str(e)}") from e

        return JSONResponse(content={"page_content": page_content})

    except HTTPException:
        # Deliberate HTTP errors raised above already carry the right status
        # code and detail; do not re-wrap them as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e
    finally:
        # Single cleanup point for both the success and the error paths.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                # Best-effort: a failed delete must not mask the real outcome.
                logger.warning(f"Could not remove temporary file {temp_file_path}: {cleanup_error}")
860
+
861
  @app.post("/v1/chat_v2", response_model=ChatResponse)
862
  @limiter.limit(settings.chat_rate_limit)
863
  async def chat_v2(