"""Unit tests for ``app.generate_prompt`` and the Gradio ``demo`` object."""

import os
import pytest
from unittest.mock import Mock, patch, MagicMock

from app import generate_prompt

TEST_SYSTEM_PROMPT = "PROMPT ROULETTE ..."
TEST_OPENAI_API_KEY = "sk-prompt-roulette-..."
TEST_OPENAI_CHAT_RESPONSE_CONTENT = "You are ..."
TEST_GRADIO_SESSION_HASH = "prompt-roulette-session-..."

# Environment variables the tests install (via patch.dict) before calling
# generate_prompt; shared so every test patches the exact same mapping.
TEST_ENV = {
    "SYSTEM_PROMPT": TEST_SYSTEM_PROMPT,
    "OPENAI_API_KEY": TEST_OPENAI_API_KEY,
}


def _make_gradio_request():
    """Return a stand-in for the Gradio request carrying the test session hash."""
    request = Mock()
    request.session_hash = TEST_GRADIO_SESSION_HASH
    return request


def _install_mock_client(mock_openai_class, *, content=None, total_tokens=None, error=None):
    """Wire a fake client into the patched ``app.OpenAI`` class and return it.

    Args:
        mock_openai_class: The mock that replaces ``app.OpenAI``.
        content: Message content of the first choice in the fake response.
        total_tokens: If given, set as ``response.usage.total_tokens``;
            otherwise that attribute is left as an auto-created Mock.
        error: If given, ``chat.completions.create`` raises it instead of
            returning a response.

    Returns:
        The fake client instance that ``app.OpenAI(...)`` will return.
    """
    client = Mock()
    mock_openai_class.return_value = client
    create = client.chat.completions.create

    if error is not None:
        create.side_effect = error
        return client

    response = Mock()
    response.choices = [Mock()]
    response.choices[0].message.content = content
    if total_tokens is not None:
        response.usage.total_tokens = total_tokens
    create.return_value = response
    return client


class TestGeneratePrompt:
    """Tests for prompt generation, its error handling, and its logging."""

    def test_gradio_app_loads(self):
        """The Gradio app object can be imported."""
        from app import demo
        assert demo is not None

    @patch("app.OpenAI")
    @patch.dict(os.environ, TEST_ENV)
    def test_generate_prompt_success(self, mock_openai_class):
        """Test a successful OpenAI API call."""
        _install_mock_client(
            mock_openai_class, content=TEST_OPENAI_CHAT_RESPONSE_CONTENT
        )

        result = generate_prompt(_make_gradio_request())

        assert result == TEST_OPENAI_CHAT_RESPONSE_CONTENT

    @patch("app.OpenAI")
    @patch.dict(os.environ, TEST_ENV)
    def test_generate_prompt_api_failure(self, mock_openai_class):
        """Test API failure handling."""
        _install_mock_client(
            mock_openai_class, error=Exception("OpenAI is down.")
        )

        result = generate_prompt(_make_gradio_request())

        assert result == "⚠️ Could not generate a prompt. No fish today."

    # Stacked @patch decorators inject mocks bottom-up, so the OpenAI mock
    # arrives before the logger mock in the parameter list.
    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, TEST_ENV)
    def test_logging_on_success(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on successful API call."""
        api_total_token_count = 500
        _install_mock_client(
            mock_openai_class,
            content=TEST_OPENAI_CHAT_RESPONSE_CONTENT,
            total_tokens=api_total_token_count,
        )

        generate_prompt(_make_gradio_request())

        mock_logger.info.assert_any_call(
            f"Making OpenAI API request - Session: {TEST_GRADIO_SESSION_HASH}"
        )
        mock_logger.info.assert_any_call(
            f"API request successful - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - tokens used: {api_total_token_count}"
        )

    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, TEST_ENV)
    def test_logging_on_failure(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on failed API call."""
        api_error_message = "OpenAI is down."
        _install_mock_client(
            mock_openai_class, error=Exception(api_error_message)
        )

        generate_prompt(_make_gradio_request())

        mock_logger.error.assert_any_call(
            f"API request failed - Session: {TEST_GRADIO_SESSION_HASH} - Error: {api_error_message}"
        )

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_env_vars(self):
        """Test what happens when SYSTEM_PROMPT and/or OPENAI_API_KEY are missing."""
        result = generate_prompt(_make_gradio_request())

        assert result == "⚠️ Service temporarily unavailable. Please try again later."