# Spaces: Sleeping
import os
from unittest.mock import MagicMock, Mock, patch

import pytest

from app import generate_prompt
# Placeholder values used as test inputs and expected outputs throughout this
# module. None of them is a real credential or a real model response.
TEST_GRADIO_SESSION_HASH: str = "prompt-roulette-session-..."
TEST_OPENAI_CHAT_RESPONSE_CONTENT: str = "You are ..."
TEST_OPENAI_API_KEY: str = "sk-prompt-roulette-..."
TEST_SYSTEM_PROMPT: str = "PROMPT ROULETTE ..."
class TestGeneratePrompt:
    """Unit tests for ``app.generate_prompt``.

    Tests that reach the OpenAI client patch ``app.OpenAI``, so no network
    request is made. They also pin the environment variables the function
    reads, so results do not depend on the developer's shell.

    NOTE(review): the patch targets ``app.OpenAI`` and ``app.logger`` are
    inferred from the mock parameter names. Confirm that ``app`` imports
    ``OpenAI`` into its own namespace and defines a module-level ``logger``.
    The environment patches also assume ``generate_prompt`` reads
    ``SYSTEM_PROMPT`` / ``OPENAI_API_KEY`` at call time rather than at import.
    """

    # Environment that lets generate_prompt get past its configuration check
    # and reach the (mocked) OpenAI client.
    _API_ENV = {
        "SYSTEM_PROMPT": TEST_SYSTEM_PROMPT,
        "OPENAI_API_KEY": TEST_OPENAI_API_KEY,
    }

    @staticmethod
    def _make_gradio_request():
        """Return a stand-in for the gradio request, carrying a fixed session hash."""
        request = Mock()
        request.session_hash = TEST_GRADIO_SESSION_HASH
        return request

    @staticmethod
    def _stub_chat_success(mock_openai_class, content, total_tokens=0):
        """Make ``chat.completions.create`` on the patched client return *content*.

        The response also reports *total_tokens* in ``usage.total_tokens``.
        Returns the mock client so callers can inspect the requests it received.
        """
        client = Mock()
        mock_openai_class.return_value = client
        response = Mock()
        response.choices = [Mock()]
        response.choices[0].message.content = content
        response.usage.total_tokens = total_tokens
        client.chat.completions.create.return_value = response
        return client

    @staticmethod
    def _stub_chat_failure(mock_openai_class, error_message):
        """Make ``chat.completions.create`` on the patched client raise an Exception.

        The raised ``Exception`` carries *error_message* as its text.
        Returns the mock client so callers can inspect the requests it received.
        """
        client = Mock()
        mock_openai_class.return_value = client
        client.chat.completions.create.side_effect = Exception(error_message)
        return client

    def test_gradio_app_loads(self):
        """Smoke test: ``app`` exposes a non-None ``demo`` object."""
        from app import demo

        assert demo is not None

    # patch.dict is kept outermost: it injects no argument, and the
    # @patch decorators below it pass their mocks bottom-up
    # (the closest decorator supplies the first mock parameter).
    @patch.dict(os.environ, _API_ENV)
    @patch("app.OpenAI")
    def test_generate_prompt_success(self, mock_openai_class):
        """A successful chat completion's message content is returned as-is."""
        client = self._stub_chat_success(
            mock_openai_class, TEST_OPENAI_CHAT_RESPONSE_CONTENT
        )

        result = generate_prompt(self._make_gradio_request())

        assert result == TEST_OPENAI_CHAT_RESPONSE_CONTENT
        client.chat.completions.create.assert_called_once()

    @patch.dict(os.environ, _API_ENV)
    @patch("app.OpenAI")
    def test_generate_prompt_api_failure(self, mock_openai_class):
        """An exception from the API yields the user-facing fallback message."""
        self._stub_chat_failure(mock_openai_class, "OpenAI is down.")

        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️ Could not generate a prompt. No fish today."

    @patch.dict(os.environ, _API_ENV)
    @patch("app.logger")
    @patch("app.OpenAI")
    def test_logging_on_success(self, mock_openai_class, mock_logger):
        """The request and its token usage are both logged via ``logger.info``."""
        api_total_token_count = 500
        self._stub_chat_success(
            mock_openai_class,
            TEST_OPENAI_CHAT_RESPONSE_CONTENT,
            total_tokens=api_total_token_count,
        )

        generate_prompt(self._make_gradio_request())

        mock_logger.info.assert_any_call(
            f"Making OpenAI API request - Session: {TEST_GRADIO_SESSION_HASH}"
        )
        mock_logger.info.assert_any_call(
            f"API request successful - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - tokens used: {api_total_token_count}"
        )

    @patch.dict(os.environ, _API_ENV)
    @patch("app.logger")
    @patch("app.OpenAI")
    def test_logging_on_failure(self, mock_openai_class, mock_logger):
        """A failed request is logged via ``logger.error`` with the exception text."""
        api_error_message = "OpenAI is down."
        self._stub_chat_failure(mock_openai_class, api_error_message)

        generate_prompt(self._make_gradio_request())

        mock_logger.error.assert_any_call(
            f"API request failed - Session: {TEST_GRADIO_SESSION_HASH} - Error: {api_error_message}"
        )

    def test_missing_env_vars(self, monkeypatch):
        """Test what happens when SYSTEM_PROMPT and/or OPENAI_API_KEY are missing."""
        # Remove only the variables under test. The result then no longer
        # depends on the caller's shell, and unrelated variables (PATH, HOME, ...)
        # are left intact. monkeypatch restores the removed values afterwards.
        for name in self._API_ENV:
            monkeypatch.delenv(name, raising=False)

        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️ Service temporarily unavailable. Please try again later."