Chandima Prabhath: Refactor LLM functions by removing unused summarize, translate, and meme functions; update generate_llm to increase max_tokens limit for improved response handling. (commit 9e02558)
import os
import time
import random
import logging

from openai import OpenAI
from dotenv import load_dotenv

from utils import read_config

# --- Load environment & config ---
load_dotenv()
_config = read_config()["llm"]
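
# Illustrative sketch of the "llm" section of config.yaml that this module reads
# (keys inferred from the _config.get() calls below; the values shown are assumptions):
#
#   llm:
#     model: "openai-large"
#     char: "Eve"
#     system_prompt: "You are {char}, a helpful assistant."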

# --- Logging setup ---
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
logger = logging.getLogger("polLLM")
logger.setLevel(LOG_LEVEL)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
logger.addHandler(handler)

# --- LLM settings from config.yaml ---
_DEFAULT_MODEL = "openai-large"  # _config.get("model", "openai-large")
_SYSTEM_TEMPLATE = _config.get("system_prompt", "")
_CHAR = _config.get("char", "Eve")

# --- Custom exception ---
class LLMBadRequestError(Exception):
    """Raised when the LLM returns HTTP 400 (Bad Request)."""
    pass

# --- OpenAI client init ---
client = OpenAI(
    base_url="https://text.pollinations.ai/openai",
    api_key=os.getenv("OPENAI_API_KEY", "OPENAI_API_KEY"),  # use the env var when set, placeholder otherwise
)

def _build_system_prompt() -> str:
    """
    Substitute {char} into the system prompt template.
    """
    return _SYSTEM_TEMPLATE.replace("{char}", _CHAR)
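
# For example (illustrative values, not taken from config.yaml): with
# _SYSTEM_TEMPLATE = "You are {char}, a helpful assistant." and _CHAR = "Eve",
# _build_system_prompt() returns "You are Eve, a helpful assistant.".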

def generate_llm(prompt: str) -> str:
    """
    Send a chat-completion request to the LLM, with retries and exponential backoff.

    The model, system prompt, and character name are read from config.yaml.
    Raises LLMBadRequestError on HTTP 400; re-raises the last error after
    five failed attempts.
    """
    model = _DEFAULT_MODEL
    system_prompt = _build_system_prompt()
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    backoff = 1
    for attempt in range(1, 6):
        try:
            seed = random.randint(0, 2**31 - 1)
            logger.debug(f"LLM call attempt={attempt}, model={model}, seed={seed}")
            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                seed=seed,
                max_tokens=4000,
            )
            text = resp.choices[0].message.content.strip()
            logger.debug("LLM response received")
            return text
        except Exception as e:
            if getattr(e, "status_code", None) == 400:
                logger.error("LLM error 400 (Bad Request): Not retrying.")
                raise LLMBadRequestError("LLM returned HTTP 400") from e
            logger.error(f"LLM error on attempt {attempt}: {e}")
            if attempt < 5:
                time.sleep(backoff)
                backoff *= 2
            else:
                logger.critical("LLM failed after 5 attempts, raising")
                raise

# Example local test
if __name__ == "__main__":
    logger.info("Testing generate_llm() with a sample prompt")
    try:
        print(generate_llm("Say hello in a poetic style."))
    except LLMBadRequestError as e:
        logger.warning(f"Test failed with bad request: {e}")
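
# Sketch of how another module might call this helper (the importing module name
# and the fallback strings are assumptions, not part of this file):
#
#   from polLLM import generate_llm, LLMBadRequestError
#
#   try:
#       reply = generate_llm("Summarize today's agenda in one sentence.")
#   except LLMBadRequestError:
#       reply = "Sorry, I couldn't process that request."   # prompt rejected (HTTP 400)
#   except Exception:
#       reply = "The model is unavailable right now."       # all 5 retries failed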