# Spaces: Running on CPU Upgrade
"""Test suite for Flask application endpoints and integration.

This module provides comprehensive tests for the Flask application including
route testing, model integration, caching behavior, and error handling.

Run with: python test_app.py or pytest test_app.py
"""
import json
import logging
import os
import unittest
from unittest import mock

# Mock the environment before importing app to avoid initialization errors
with mock.patch.dict(os.environ, {'KEY': 'test_api_key_for_import'}):
    from app import Model, app, setup_cache
class TestFlaskApplication(unittest.TestCase):
    """Exercise the Flask routes through the application's test client.

    NOTE(review): the ``mock.patch`` decorators of this class were missing
    from the file (every test that takes a ``mock_*`` argument had nothing
    supplying it). They have been reconstructed; targets that could not be
    derived from this file are flagged individually below.
    """

    @classmethod
    def setUpClass(cls):
        """Create one test client shared by every test in this class."""
        # unittest calls setUpClass on the class itself, so it must be a
        # classmethod; as a plain function the call fails with a TypeError.
        cls.test_client = app.test_client()
        app.config['TESTING'] = True

    def test_index_route_returns_html(self):
        """GET / answers 200 with an HTML content type."""
        response = self.test_client.get('/')
        self.assertEqual(response.status_code, 200)
        self.assertIn('text/html', response.content_type)

    def test_cache_stats_route(self):
        """GET /cache/stats answers 200 with a JSON object."""
        response = self.test_client.get('/cache/stats')
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content_type, 'application/json')
        data = json.loads(response.data)
        self.assertIsInstance(data, dict)

    # Patching the class attribute intercepts predict() on any Model instance.
    # NOTE(review): assumes the /predict route delegates to Model.predict.
    @mock.patch.object(Model, 'predict')
    def test_predict_route_with_valid_data(self, mock_predict):
        """POST /predict returns the prediction's segments and text."""
        mock_predict.return_value = {
            'segments': [{'type': 'body', 'content': 'test'}],
            'text': 'test output',
        }
        response = self.test_client.post(
            '/predict', data='FINDINGS: Normal chest CT'
        )
        self.assertEqual(response.status_code, 200)
        data = json.loads(response.data)
        self.assertIn('segments', data)
        self.assertIn('text', data)

    def test_predict_route_with_empty_data(self):
        """POST /predict with an empty body is rejected with a 400 payload."""
        response = self.test_client.post('/predict', data='')
        self.assertEqual(response.status_code, 400)
        data = json.loads(response.data)
        self.assertIn('error', data)
        self.assertEqual(data['error'], 'Empty input')
        self.assertIn('message', data)
        self.assertEqual(data['message'], 'Input text is required')
        self.assertIn('max_length', data)

    # NOTE(review): assumes the /predict route delegates to Model.predict.
    @mock.patch.object(Model, 'predict')
    def test_predict_with_custom_headers(self, mock_predict):
        """The X-Model-ID header is forwarded to predict() as model_id."""
        mock_predict.return_value = {'segments': [], 'text': 'test'}
        headers = {
            'X-Use-Cache': 'false',
            'X-Sample-ID': 'test_sample',
            'X-Model-ID': 'gemini-2.5-flash',
        }
        response = self.test_client.post(
            '/predict', data='Test report', headers=headers
        )
        self.assertEqual(response.status_code, 200)
        mock_predict.assert_called_once_with(
            'Test report', model_id='gemini-2.5-flash'
        )

    # NOTE(review): the cache lookup target is not visible in this file;
    # 'app.cache_manager.get_cached_result' is an assumption - confirm
    # against app.py before relying on this test.
    @mock.patch('app.cache_manager.get_cached_result')
    def test_predict_with_cache_hit(self, mock_get_cached):
        """A cached result is returned and flagged with from_cache."""
        cached_response = {
            'segments': [{'type': 'body', 'content': 'cached'}],
            'text': 'cached result',
        }
        mock_get_cached.return_value = cached_response
        response = self.test_client.post(
            '/predict', data='Test report', headers={'X-Use-Cache': 'true'}
        )
        data = json.loads(response.data)
        self.assertTrue(data.get('from_cache'))
        self.assertIn('segments', data)
class TestModelClass(unittest.TestCase):
    """Unit tests for Model construction and prediction dispatch.

    NOTE(review): the ``mock.patch`` decorators of this class were missing
    from the file and have been reconstructed. The ``KEY`` variable name is
    taken from the import-time patch and the expected error message in this
    file; the two constants below are assumptions to confirm against app.py.
    """

    # NOTE(review): assumed name of the structurer class as imported by app.
    _STRUCTURER_TARGET = 'app.RadiologyReportStructurer'
    # NOTE(review): assumed name of the env var selecting the default model.
    _MODEL_ID_ENV = 'MODEL_ID'

    @mock.patch.dict(os.environ, {'KEY': 'test_api_key'})
    def test_model_initialization_with_api_key(self):
        """Model picks up the API key and registers the default structurer."""
        model = Model()
        self.assertEqual(model.gemini_api_key, 'test_api_key')
        self.assertIn('gemini-2.5-flash', model._structurers)

    # Clear the environment so KEY is guaranteed to be absent.
    @mock.patch.dict(os.environ, {}, clear=True)
    def test_model_initialization_without_api_key(self):
        """Model refuses to initialize when KEY is not set."""
        with self.assertRaises(ValueError) as context:
            Model()
        self.assertIn(
            'KEY environment variable not set', str(context.exception)
        )

    @mock.patch.dict(
        os.environ, {'KEY': 'test_api_key', _MODEL_ID_ENV: 'custom-model'}
    )
    def test_model_initialization_with_custom_model(self):
        """A model id configured via the environment gets a structurer."""
        model = Model()
        self.assertIn('custom-model', model._structurers)

    @mock.patch.dict(os.environ, {'KEY': 'test_api_key'})
    @mock.patch(_STRUCTURER_TARGET)
    def test_get_structurer_creates_new_instance(self, mock_structurer_class):
        """Requesting an unseen model id builds a second structurer."""
        model = Model()
        model._get_structurer('new-model')
        # Should be called twice: once for default, once for new model
        self.assertEqual(mock_structurer_class.call_count, 2)

    @mock.patch.dict(os.environ, {'KEY': 'test_api_key'})
    @mock.patch(_STRUCTURER_TARGET)
    def test_predict_calls_structurer(self, mock_structurer_class):
        """predict() forwards the text to the structurer and returns its result."""
        mock_instance = mock.Mock()
        mock_instance.predict.return_value = {'result': 'test'}
        mock_structurer_class.return_value = mock_instance
        model = Model()
        result = model.predict('test data', 'test-model')
        mock_instance.predict.assert_called_once_with('test data')
        self.assertEqual(result, {'result': 'test'})
class TestCacheSetup(unittest.TestCase):
    """Tests for setup_cache() with the filesystem calls mocked out.

    NOTE(review): the ``mock.patch`` decorators of this class were missing
    from the file and have been reconstructed. Patches are applied bottom-up,
    so the decorator nearest the ``def`` supplies the first mock argument.
    The targets assume app.py calls these functions through the ``os`` and
    ``shutil`` modules (not via ``from ... import`` names) - confirm.
    """

    @mock.patch('os.path.exists')
    # NOTE(review): the copy function used by setup_cache is not visible
    # here; shutil.copy2 is an assumption.
    @mock.patch('shutil.copy2')
    @mock.patch('os.makedirs')
    def test_setup_cache_copies_existing_file(
        self, mock_makedirs, mock_copy, mock_exists
    ):
        """When the source cache exists it is copied into /tmp/cache."""
        mock_exists.return_value = True
        cache_dir = setup_cache()
        mock_makedirs.assert_called_once_with('/tmp/cache', exist_ok=True)
        mock_copy.assert_called_once()
        self.assertEqual(cache_dir, '/tmp/cache')

    @mock.patch('os.path.exists')
    @mock.patch('os.makedirs')
    def test_setup_cache_handles_missing_source(
        self, mock_makedirs, mock_exists
    ):
        """A missing source cache still yields the /tmp/cache directory."""
        mock_exists.return_value = False
        cache_dir = setup_cache()
        mock_makedirs.assert_called_once_with('/tmp/cache', exist_ok=True)
        self.assertEqual(cache_dir, '/tmp/cache')
class TestErrorHandling(unittest.TestCase):
    """Check that failures inside prediction surface as HTTP 500 responses.

    NOTE(review): the ``mock.patch`` decorators of this class were missing
    from the file and have been reconstructed. Patches are applied bottom-up,
    so the logger patch (nearest the ``def``) supplies ``mock_logger`` and
    the predict patch supplies ``mock_predict``.
    """

    @classmethod
    def setUpClass(cls):
        """Create one test client shared by every test in this class."""
        # unittest calls setUpClass on the class itself, so it must be a
        # classmethod; as a plain function the call fails with a TypeError.
        cls.test_client = app.test_client()
        app.config['TESTING'] = True

    def setUp(self):
        """Silence logging for the duration of each test."""
        # Suppress all logging during error tests to reduce noise
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        """Restore normal logging after each test."""
        # Re-enable logging
        logging.disable(logging.NOTSET)

    # NOTE(review): assumes the /predict route delegates to Model.predict.
    # The logger mock is never asserted on; the Flask app's own logger is
    # patched because that attribute always exists. If app.py logs through a
    # module-level logger instead, retarget this patch.
    @mock.patch.object(Model, 'predict')
    @mock.patch.object(app, 'logger')
    def test_predict_handles_type_error(self, mock_logger, mock_predict):
        """A TypeError from predict() is reported as a processing error."""
        mock_predict.side_effect = TypeError('Invalid type')
        response = self.test_client.post('/predict', data='Test data')
        self.assertEqual(response.status_code, 500)
        data = json.loads(response.data)
        self.assertIn('Processing error', data['error'])

    # NOTE(review): same patch-target assumptions as the test above.
    @mock.patch.object(Model, 'predict')
    @mock.patch.object(app, 'logger')
    def test_predict_handles_general_exception(
        self, mock_logger, mock_predict
    ):
        """Any other exception's message is echoed in the 500 payload."""
        mock_predict.side_effect = Exception('General error')
        response = self.test_client.post('/predict', data='Test data')
        self.assertEqual(response.status_code, 500)
        data = json.loads(response.data)
        self.assertIn('General error', data['error'])
if __name__ == '__main__':
    # Allows the suite to be run directly with `python test_app.py`.
    unittest.main()