BioMedNorm-MCP-Server / openai_utils.py
RohanKarthikeyan's picture
Upload 9 files
046bc11 verified
"""
Helper functions for structured OpenAI API calls using Pydantic models.
Includes NER and RAG-specific prompting logic with retry and error handling.
"""
import os
from typing import Literal, Optional, overload, Union
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from tenacity import retry, retry_if_result, stop_after_attempt, wait_random_exponential
from tqdm.auto import tqdm
# Load variables defined in a .env file into the process environment.
load_dotenv()

# Refuse to import the module without a usable key: an unset or empty
# OPENAI_API_KEY is reported here rather than on the first API request.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise EnvironmentError("Missing OPENAI_API_KEY in environment.")

# Single module-wide async client, configured with a 120.0 timeout.
client = AsyncOpenAI(
    api_key=api_key,
    timeout=120.0,
)
# Response schema used by ask_openai when usage == 'ner': it is passed as the
# `text_format` of the request, and only its `answer` field is returned to the
# caller.
# NOTE(review): documentation is kept in `#` comments on purpose -- a class
# docstring would presumably end up in the JSON schema generated from this
# model and thereby change what is sent to the API; confirm before adding one.
class NEROutput(BaseModel):
    # Required field (no default); the description is attached as Field metadata.
    answer: list[str] = Field(..., description="List of extracted entities")
# Response schema used by ask_openai for every usage other than 'ner' (i.e.
# 'rag'): it is passed as the `text_format` of the request.
# NOTE(review): documentation is kept in `#` comments on purpose -- a class
# docstring would presumably end up in the JSON schema generated from this
# model and thereby change what is sent to the API; confirm before adding one.
class RAGOutput(BaseModel):
    # Required field; this is the only value ask_openai returns to its caller.
    answer: str = Field(..., description="Closest match to input term")
    # Required field; requested from the model but not returned by ask_openai.
    reason: str = Field(..., description="Why you chose the answer match to input term")
def is_invalid_result(result):
    """Return True when *result* is ``None``, the only value treated as invalid.

    Used as the predicate for ``retry_if_result``; every other value,
    including empty strings and empty lists, counts as valid.
    """
    if result is not None:
        return False
    return True
@overload
async def ask_openai(user_prompt: str, usage: Literal["ner"], model: str = ...) -> Optional[list[str]]: ...


@overload
async def ask_openai(user_prompt: str, usage: Literal["rag"], model: str = ...) -> Optional[str]: ...


@retry(
    retry=retry_if_result(is_invalid_result),
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
)
async def ask_openai(
    user_prompt: str,
    usage: Literal['ner', 'rag'],
    model: str = "o4-mini-2025-04-16",
) -> Optional[Union[list[str], str]]:
    """
    Send a single user prompt to the OpenAI Responses API and return the
    ``answer`` field of the parsed structured output.

    Args:
        user_prompt: Text sent as the sole ``user`` message.
        usage: Selects the response schema: ``'ner'`` parses into
            ``NEROutput`` (``answer`` is a list of strings), ``'rag'`` parses
            into ``RAGOutput`` (``answer`` is a string; its ``reason`` field
            is discarded).
        model: Model name forwarded to the API.

    Returns:
        The parsed ``answer``. A single attempt yields ``None`` when the
        response carries no parsed output; the retry decorator treats that
        as invalid and calls again (up to 6 attempts, random exponential
        wait between 1 and 60 seconds).

    Raises:
        ValueError: If ``model`` is one of the names listed below as lacking
            structured-output support, or if ``usage`` is neither ``'ner'``
            nor ``'rag'``.
        Exception: Anything raised while calling the API or reading the
            parsed output is reported through ``tqdm.write`` and re-raised.
            Exceptions are not retried, because the retry predicate only
            inspects return values.
        tenacity.RetryError: When every attempt produced ``None``; this is
            tenacity's behaviour once the stop condition is reached, since
            neither ``reraise`` nor ``retry_error_callback`` is configured.
    """
    if model in ("chatgpt-4o-latest", "o1-mini"):
        raise ValueError(f"Model {model} does not support structured outputs.")

    # Reject unknown usages explicitly; silently falling back to the RAG
    # schema would hand the caller an answer of the wrong shape.
    if usage == 'ner':
        response_format = NEROutput
    elif usage == 'rag':
        response_format = RAGOutput
    else:
        raise ValueError(f"Unknown usage {usage!r}; expected 'ner' or 'rag'.")

    try:
        response = await client.responses.parse(
            model=model,
            input=[{"role": "user", "content": user_prompt}],
            text_format=response_format,
            # temperature=0.05,
        )
        response_obj = response.output_parsed
    except Exception as e:
        tqdm.write(f"❌ Unexpected error. Error: {e}")
        raise

    # No parsed output -> None, which makes the retry decorator try again.
    if response_obj is None:
        return None
    return response_obj.answer