# prompt-roulette / test_app.py
# Last commit by n8cha (8cccf42): add error handling for missing env vars and respective tests
import os
import pytest
from unittest.mock import Mock, patch, MagicMock
from app import generate_prompt
# Shared values used by the tests in this module.

# Value the tests place in the SYSTEM_PROMPT environment variable.
TEST_SYSTEM_PROMPT = "PROMPT ROULETTE ..."
# Value the tests place in the OPENAI_API_KEY environment variable.
TEST_OPENAI_API_KEY = "sk-prompt-roulette-..."
# Message content carried by the mocked chat-completion response.
TEST_OPENAI_CHAT_RESPONSE_CONTENT = "You are ..."
# session_hash assigned to the mocked gradio request object.
TEST_GRADIO_SESSION_HASH = "prompt-roulette-session-..."
class TestGeneratePrompt:
    """Tests for ``app.generate_prompt``.

    Tests that exercise the API path replace ``app.OpenAI`` with a mock and
    inject ``SYSTEM_PROMPT`` / ``OPENAI_API_KEY`` into ``os.environ`` for the
    duration of the test. The private helpers below build the mock objects
    that every such test needs, so each test body only states what is
    specific to its scenario.
    """

    # Environment variables supplied by the tests that exercise the API path.
    _CONFIGURED_ENV = {
        "SYSTEM_PROMPT": TEST_SYSTEM_PROMPT,
        "OPENAI_API_KEY": TEST_OPENAI_API_KEY,
    }

    @staticmethod
    def _make_request():
        """Return a mock gradio request carrying the test session hash."""
        request = Mock()
        request.session_hash = TEST_GRADIO_SESSION_HASH
        return request

    @staticmethod
    def _make_chat_response(total_tokens=None):
        """Return a mock chat-completion response with the test content.

        Args:
            total_tokens: When given, stored as ``usage.total_tokens`` on the
                response. When omitted, that attribute is left as an
                auto-created ``Mock`` attribute.
        """
        response = Mock()
        response.choices = [Mock()]
        response.choices[0].message = Mock()
        response.choices[0].message.content = TEST_OPENAI_CHAT_RESPONSE_CONTENT
        if total_tokens is not None:
            response.usage.total_tokens = total_tokens
        return response

    @staticmethod
    def _install_client(mock_openai_class, *, response=None, error=None):
        """Make the patched ``OpenAI`` class hand out a mock client.

        Args:
            mock_openai_class: The mock that replaced ``app.OpenAI``.
            response: Value returned by ``chat.completions.create``.
            error: Exception raised by ``chat.completions.create``; takes
                precedence over ``response`` when given.

        Returns:
            The mock client that ``mock_openai_class()`` will return.
        """
        client = Mock()
        mock_openai_class.return_value = client
        if error is not None:
            client.chat.completions.create.side_effect = error
        else:
            client.chat.completions.create.return_value = response
        return client

    def test_gradio_app_loads(self):
        """The ``app`` module can be imported and exposes a ``demo`` object."""
        from app import demo

        assert demo is not None

    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_generate_prompt_success(self, mock_openai_class):
        """Test a successful OpenAI API call."""
        self._install_client(
            mock_openai_class, response=self._make_chat_response()
        )

        result = generate_prompt(self._make_request())

        assert result == TEST_OPENAI_CHAT_RESPONSE_CONTENT

    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_generate_prompt_api_failure(self, mock_openai_class):
        """Test API failure handling."""
        self._install_client(
            mock_openai_class, error=Exception("OpenAI is down.")
        )

        result = generate_prompt(self._make_request())

        assert result == "⚠️ Could not generate a prompt. No fish today."

    # Decorators apply bottom-up: patch.dict injects no argument, so the
    # OpenAI mock arrives first and the logger mock second.
    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_logging_on_success(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on successful API call."""
        api_total_token_count = 500
        self._install_client(
            mock_openai_class,
            response=self._make_chat_response(
                total_tokens=api_total_token_count
            ),
        )

        generate_prompt(self._make_request())

        mock_logger.info.assert_any_call(
            f"Making OpenAI API request - Session: {TEST_GRADIO_SESSION_HASH}"
        )
        mock_logger.info.assert_any_call(
            f"API request successful - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - tokens used: {api_total_token_count}"
        )

    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_logging_on_failure(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on failed API call."""
        api_error_message = "OpenAI is down."
        self._install_client(
            mock_openai_class, error=Exception(api_error_message)
        )

        generate_prompt(self._make_request())

        mock_logger.error.assert_any_call(
            f"API request failed - Session: {TEST_GRADIO_SESSION_HASH} - Error: {api_error_message}"
        )

    @pytest.mark.parametrize(
        "present_env",
        [
            pytest.param({}, id="both-missing"),
            pytest.param(
                {"OPENAI_API_KEY": TEST_OPENAI_API_KEY},
                id="system-prompt-missing",
            ),
            pytest.param(
                {"SYSTEM_PROMPT": TEST_SYSTEM_PROMPT},
                id="api-key-missing",
            ),
        ],
    )
    def test_missing_env_vars(self, present_env):
        """Test what happens when SYSTEM_PROMPT and/or OPENAI_API_KEY are missing.

        Each case runs with an environment that contains only
        ``present_env``; ``clear=True`` removes everything else for the
        duration of the call and ``patch.dict`` restores it afterwards.
        """
        with patch.dict(os.environ, present_env, clear=True):
            result = generate_prompt(self._make_request())

        assert result == "⚠️ Service temporarily unavailable. Please try again later."