File size: 4,964 Bytes
9b98f2e
 
 
 
 
 
 
 
 
 
 
8cccf42
 
 
 
9b98f2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665c6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cccf42
665c6da
5bca8f8
 
8cccf42
 
 
 
5bca8f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cccf42
 
 
 
5bca8f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cccf42
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import pytest
from unittest.mock import Mock, patch, MagicMock
from app import generate_prompt

# Dummy fixture values shared by the tests below; none of them is a real
# credential or a real model response.

# Injected as the SYSTEM_PROMPT environment variable via patch.dict.
TEST_SYSTEM_PROMPT = "PROMPT ROULETTE ..."
# Injected as the OPENAI_API_KEY environment variable via patch.dict.
TEST_OPENAI_API_KEY = "sk-prompt-roulette-..."
# Message content carried by the mocked chat-completion response.
TEST_OPENAI_CHAT_RESPONSE_CONTENT = "You are ..."
# session_hash assigned to the mocked Gradio request object.
TEST_GRADIO_SESSION_HASH = "prompt-roulette-session-..."

class TestGeneratePrompt:
    """Tests for ``app.generate_prompt`` with the OpenAI client mocked out.

    Every test that reaches the API path patches ``app.OpenAI`` so that no
    network request is made, and injects the required environment variables
    with ``patch.dict``. Names starting with an underscore are helpers; they
    do not match pytest's ``test`` prefix and are therefore not collected.
    """

    # Environment applied by patch.dict for tests that need a configured app.
    _CONFIGURED_ENV = {
        "SYSTEM_PROMPT": TEST_SYSTEM_PROMPT,
        "OPENAI_API_KEY": TEST_OPENAI_API_KEY,
    }

    @staticmethod
    def _make_gradio_request():
        """Return a mock Gradio request carrying the test session hash."""
        return Mock(session_hash=TEST_GRADIO_SESSION_HASH)

    @staticmethod
    def _make_chat_response(total_tokens=None):
        """Return a mock chat-completion response.

        The response exposes ``choices[0].message.content`` set to
        ``TEST_OPENAI_CHAT_RESPONSE_CONTENT``. When *total_tokens* is given,
        ``usage.total_tokens`` is set to it; otherwise ``usage`` stays an
        auto-created ``Mock`` attribute.
        """
        message = Mock(content=TEST_OPENAI_CHAT_RESPONSE_CONTENT)
        response = Mock(choices=[Mock(message=message)])
        if total_tokens is not None:
            response.usage.total_tokens = total_tokens
        return response

    @staticmethod
    def _install_client(mock_openai_class):
        """Make the patched ``OpenAI`` class return a fresh mock client.

        Returns the client so the caller can configure
        ``client.chat.completions.create``.
        """
        client = Mock()
        mock_openai_class.return_value = client
        return client

    def test_gradio_app_loads(self):
        """The ``app`` module exposes a non-None ``demo`` object."""
        from app import demo
        assert demo is not None

    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_generate_prompt_success(self, mock_openai_class):
        """Test a successful OpenAI API call."""
        client = self._install_client(mock_openai_class)
        client.chat.completions.create.return_value = self._make_chat_response()

        result = generate_prompt(self._make_gradio_request())

        assert result == TEST_OPENAI_CHAT_RESPONSE_CONTENT

    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_generate_prompt_api_failure(self, mock_openai_class):
        """Test API failure handling."""
        client = self._install_client(mock_openai_class)
        # Any call to create() raises, simulating an API outage.
        client.chat.completions.create.side_effect = Exception("OpenAI is down.")

        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️  Could not generate a prompt. No fish today."

    # Decorators apply bottom-up, so the app.OpenAI mock is passed first and
    # the app.logger mock second.
    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_logging_on_success(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on successful API call."""
        api_total_token_count = 500

        client = self._install_client(mock_openai_class)
        client.chat.completions.create.return_value = self._make_chat_response(
            total_tokens=api_total_token_count
        )

        generate_prompt(self._make_gradio_request())

        mock_logger.info.assert_any_call(
            f"Making OpenAI API request - Session: {TEST_GRADIO_SESSION_HASH}"
        )
        mock_logger.info.assert_any_call(
            f"API request successful - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - tokens used: {api_total_token_count}"
        )

    @patch("app.logger")
    @patch("app.OpenAI")
    @patch.dict(os.environ, _CONFIGURED_ENV)
    def test_logging_on_failure(self, mock_openai_class, mock_logger):
        """Test that logging works correctly on failed API call."""
        api_error_message = "OpenAI is down."

        client = self._install_client(mock_openai_class)
        client.chat.completions.create.side_effect = Exception(api_error_message)

        generate_prompt(self._make_gradio_request())

        mock_logger.error.assert_any_call(
            f"API request failed - Session: {TEST_GRADIO_SESSION_HASH}"
            f" - Error: {api_error_message}"
        )

    @patch.dict(os.environ, {}, clear=True)
    def test_missing_env_vars(self):
        """Test what happens when SYSTEM_PROMPT and/or OPENAI_API_KEY are missing."""
        # NOTE(review): app.OpenAI is not patched here; this presumably relies
        # on generate_prompt returning before any client call when the
        # environment is empty -- confirm against app.py.
        result = generate_prompt(self._make_gradio_request())

        assert result == "⚠️  Service temporarily unavailable. Please try again later."