|
""" |
|
Helper functions for structured OpenAI API calls using Pydantic models. |
|
Includes NER and RAG-specific prompting logic with retry and error handling. |
|
""" |
|
|
|
import os |
|
from typing import Literal, Optional, overload, Union |
|
|
|
from dotenv import load_dotenv |
|
from openai import AsyncOpenAI |
|
from pydantic import BaseModel, Field |
|
from tenacity import retry, retry_if_result, stop_after_attempt, wait_random_exponential |
|
from tqdm.auto import tqdm |
|
|
|
# Module-level setup: runs once at import time.
# load_dotenv() is called before the lookup below so that a key supplied via
# a .env file (python-dotenv) is visible in the process environment.
load_dotenv()

# Fail fast at import if the key is absent or an empty string, so no client is
# ever constructed without credentials.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise EnvironmentError("Missing OPENAI_API_KEY in environment.")

# Single shared async client reused by every helper in this module;
# a timeout of 120.0 is passed to the SDK for each request.
client = AsyncOpenAI(timeout=120.0, api_key=api_key)
|
|
|
|
|
class NEROutput(BaseModel):
    # Structured-output schema selected when `usage == 'ner'`; only the
    # `answer` field is handed back to callers of `ask_openai`.
    # NOTE(review): documented with `#` comments on purpose -- Pydantic copies
    # a class docstring into the generated JSON schema, which would change the
    # schema passed as `text_format`.
    answer: list[str] = Field(
        default=...,  # Ellipsis marks the field as required
        description="List of extracted entities",
    )
|
|
|
|
|
class RAGOutput(BaseModel):
    # Structured-output schema selected for any `usage` other than 'ner'
    # (i.e. 'rag'); `ask_openai` returns only `answer`, while `reason` is
    # requested from the model but not surfaced to callers.
    # NOTE(review): documented with `#` comments on purpose -- Pydantic copies
    # a class docstring into the generated JSON schema, which would change the
    # schema passed as `text_format`.
    answer: str = Field(
        default=...,  # Ellipsis marks the field as required
        description="Closest match to input term",
    )
    reason: str = Field(
        default=...,  # Ellipsis marks the field as required
        description="Why you chose the answer match to input term",
    )
|
|
|
|
|
def is_invalid_result(result: object) -> bool:
    """Return ``True`` only when *result* is ``None``.

    Used as the predicate for ``retry_if_result`` on ``ask_openai``: a
    ``None`` result means "no usable answer, try again". Falsy-but-present
    values such as ``""`` or ``[]`` are NOT treated as invalid.
    """
    if result is None:
        return True
    return False
|
|
|
@overload
async def ask_openai(
    user_prompt: str, usage: Literal["ner"], model: str = ...
) -> Optional[list[str]]: ...


@overload
async def ask_openai(
    user_prompt: str, usage: Literal["rag"], model: str = ...
) -> Optional[str]: ...


@retry(
    retry=retry_if_result(is_invalid_result),
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
)
async def ask_openai(
    user_prompt: str,
    usage: Literal["ner", "rag"],
    model: str = "o4-mini-2025-04-16",
) -> Optional[Union[list[str], str]]:
    """Send one user prompt to the OpenAI Responses API and return the parsed answer.

    The reply is parsed into ``NEROutput`` (``usage='ner'``) or ``RAGOutput``
    (``usage='rag'``) and only its ``answer`` field is returned.

    Retry behaviour (tenacity): the call is repeated, with random exponential
    backoff bounded by ``min=1``/``max=60``, for as long as the result is
    ``None``, up to 6 attempts. The decorator only inspects *results*, so
    exceptions are not retried by it.

    Args:
        user_prompt: Text sent as the single ``user`` message.
        usage: ``'ner'`` to get a list of entities, ``'rag'`` to get a single
            best-match string.
        model: Model identifier passed through to the API.

    Returns:
        ``list[str]`` for ``'ner'``, ``str`` for ``'rag'``. A ``None`` attempt
        result (no parsed output) is what triggers a retry, so callers
        normally see either a value or a ``RetryError``.

    Raises:
        ValueError: If ``model`` is one of the models listed as lacking
            structured-output support, or if ``usage`` is not ``'ner'`` or
            ``'rag'``.
        tenacity.RetryError: If all 6 attempts produced ``None`` (tenacity's
            default when neither ``reraise`` nor ``retry_error_callback`` is
            configured).
        Exception: Any error raised while calling the API or reading the
            parsed output is logged via ``tqdm.write`` and re-raised unchanged.
    """
    if model in ("chatgpt-4o-latest", "o1-mini"):
        raise ValueError(f"Model {model} does not support structured outputs.")

    # `Literal[...]` is only a type hint and is not enforced at runtime.
    # Validate explicitly so that a typo such as "NER" fails loudly instead of
    # silently being treated as a RAG request with a different return type.
    if usage == "ner":
        response_format = NEROutput
    elif usage == "rag":
        response_format = RAGOutput
    else:
        raise ValueError(f"Unknown usage {usage!r}; expected 'ner' or 'rag'.")

    try:
        response = await client.responses.parse(
            model=model,
            input=[{"role": "user", "content": user_prompt}],
            text_format=response_format,
        )
        parsed = response.output_parsed
        # No parsed output -> return None so the retry predicate fires.
        return parsed.answer if parsed is not None else None
    except Exception as e:
        # Log through tqdm so progress bars are not corrupted, then propagate.
        tqdm.write(f"❌ Unexpected error. Error: {e}")
        raise
|
|