Spaces:
Sleeping
Sleeping
File size: 4,964 Bytes
9b98f2e 8cccf42 9b98f2e 665c6da 8cccf42 665c6da 5bca8f8 8cccf42 5bca8f8 8cccf42 5bca8f8 8cccf42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
import pytest
from unittest.mock import Mock, patch, MagicMock
from app import generate_prompt
# Placeholder fixture values used by the tests in this module: the first two
# are injected into os.environ, the last two into the OpenAI response mock and
# the Gradio request mock. None of them is a real credential or real output.
TEST_SYSTEM_PROMPT: str = "PROMPT ROULETTE ..."
TEST_OPENAI_API_KEY: str = "sk-prompt-roulette-..."
TEST_GRADIO_SESSION_HASH: str = "prompt-roulette-session-..."
TEST_OPENAI_CHAT_RESPONSE_CONTENT: str = "You are ..."
class TestGeneratePrompt:
    """Tests for ``app.generate_prompt`` and the Gradio app that uses it.

    Every test that calls ``generate_prompt`` patches ``app.OpenAI`` so no real
    API request can be made, and patches ``os.environ`` so the outcome does not
    depend on the developer's shell environment.
    """

    # Configuration variables that test_missing_env_vars shows the app requires.
    # Applied per test with ``patch.dict``, which restores os.environ afterwards.
    # (Decorator expressions below are evaluated in the class body, so they can
    # refer to this name directly.)
    _TEST_ENV = {
        "SYSTEM_PROMPT": TEST_SYSTEM_PROMPT,
        "OPENAI_API_KEY": TEST_OPENAI_API_KEY,
    }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_gradio_request():
        """Return a stand-in for the Gradio request carrying the test session hash."""
        return Mock(session_hash=TEST_GRADIO_SESSION_HASH)

    @staticmethod
    def _mock_openai_client(mock_openai_class):
        """Make the patched ``app.OpenAI`` return a fresh mock client.

        Args:
            mock_openai_class: The mock installed by ``@patch("app.OpenAI")``.

        Returns:
            The mock client instance that ``app.OpenAI(...)`` will now return.
        """
        mock_client = Mock()
        mock_openai_class.return_value = mock_client
        return mock_client

    @classmethod
    def _mock_chat_response(cls, mock_openai_class, content, total_tokens=None):
        """Configure ``chat.completions.create`` to return a successful response.

        Args:
            mock_openai_class: The mock installed by ``@patch("app.OpenAI")``.
            content: Text exposed as ``response.choices[0].message.content``.
            total_tokens: When given, exposed as ``response.usage.total_tokens``;
                otherwise ``usage`` stays an auto-created ``Mock`` attribute.

        Returns:
            The mock client whose ``chat.completions.create`` was configured.
        """
        mock_client = cls._mock_openai_client(mock_openai_class)
        mock_response = Mock()
        mock_response.choices = [Mock()]
        mock_response.choices[0].message = Mock()
        mock_response.choices[0].message.content = content
        if total_tokens is not None:
            mock_response.usage.total_tokens = total_tokens
        mock_client.chat.completions.create.return_value = mock_response
        return mock_client

    # -------------------------------------------------------------------- tests

    def test_gradio_app_loads(self):
        """The Gradio ``demo`` object can be imported from ``app``."""
        from app import demo

        assert demo is not None

    @patch("app.OpenAI")
    @patch.dict(os.environ, _TEST_ENV)
    def test_generate_prompt_success(self, mock_openai_class):
        """Test a successful OpenAI API call."""
        self._mock_chat_response(
            mock_openai_class, TEST_OPENAI_CHAT_RESPONSE_CONTENT
        )

        result = generate_prompt(self._make_gradio_request())

        assert result == TEST_OPENAI_CHAT_RESPONSE_CONTENT

    @patch("app.OpenAI")
    @patch.dict(os.environ, _TEST_ENV)
    def test_generate_prompt_api_failure(self, mock_openai_class):
        """Test API failure handling: an exception yields the fallback message."""
        mock_client = self._mock_openai_client(mock_openai_class)
        mock_client.chat.completions.create.side_effect = Exception("OpenAI is down.")

        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️ Could not generate a prompt. No fish today."

    # Note: stacked @patch decorators inject mocks bottom-up, so the
    # "app.OpenAI" mock arrives before the "app.logger" mock; patch.dict
    # injects no argument.
    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _TEST_ENV)
    def test_logging_on_success(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on successful API call."""
        api_total_token_count = 500
        self._mock_chat_response(
            mock_openai_class,
            TEST_OPENAI_CHAT_RESPONSE_CONTENT,
            total_tokens=api_total_token_count,
        )

        generate_prompt(self._make_gradio_request())

        mock_logger.info.assert_any_call(
            f"Making OpenAI API request - Session: {TEST_GRADIO_SESSION_HASH}"
        )
        mock_logger.info.assert_any_call(
            f"API request successful - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - tokens used: {api_total_token_count}"
        )

    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _TEST_ENV)
    def test_logging_on_failure(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on failed API call."""
        api_error_message = "OpenAI is down."
        mock_client = self._mock_openai_client(mock_openai_class)
        mock_client.chat.completions.create.side_effect = Exception(api_error_message)

        generate_prompt(self._make_gradio_request())

        mock_logger.error.assert_any_call(
            f"API request failed - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - Error: {api_error_message}"
        )

    @patch("app.OpenAI")
    @patch.dict(os.environ, {}, clear=True)
    def test_missing_env_vars(self, mock_openai_class):
        """Test what happens when SYSTEM_PROMPT and/or OPENAI_API_KEY are missing.

        ``app.OpenAI`` is patched so that, if the configuration check ever
        regresses, the test fails on its assertions instead of reaching the
        real OpenAI client or the network.
        """
        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️ Service temporarily unavailable. Please try again later."
        # Without credentials no completion request should have been issued.
        mock_openai_class.return_value.chat.completions.create.assert_not_called()
|