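# NOTE: the next three lines are web-page header residue captured with the
# file (not program text); kept as comments so the module parses.
# igitman's picture
# inference_endpoint (#2)
# b1a665d verified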
import json
import logging
from http.server import HTTPServer, BaseHTTPRequestHandler
from handler import EndpointHandler
# Configure logging: the root logger emits INFO and above; `logger` is this
# module's named logger and is used by RequestHandler and the entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the single EndpointHandler instance shared by every request. It is
# constructed at import time; RequestHandler.do_GET later checks
# `handler.initialized` and calls `handler._initialize_components()` when it
# is still false.
handler = EndpointHandler()
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP front end for the module-level ``handler`` object.

    Routes:
        POST <any path>  Decode the JSON request body, call ``handler(data)``
                         and return its result as JSON (200).
        GET /health      Readiness probe: 200 when ``handler.initialized`` is
                         true, 503 otherwise.
        GET <other>      404 with an empty body.

    Error bodies always have the shape
    ``[{'error': <message>, 'generated_text': ''}]``. Malformed requests are
    answered with 400, failures while processing a request with 500.
    """

    def _send_bytes(self, status, body):
        """Send a complete JSON response whose body is already encoded.

        Args:
            status: HTTP status code to send.
            body: UTF-8 encoded JSON document (``bytes``).

        A client that disconnects before the reply is written is ignored:
        there is nobody left to report the failure to.
        """
        try:
            self.send_response(status)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)
        except (BrokenPipeError, ConnectionResetError):
            logger.debug('Client disconnected before the response was sent')

    def _send_json(self, status, payload):
        """Serialize *payload* to JSON and send it with *status*.

        Serialization happens before the status line is written, so a
        payload that cannot be serialized raises without leaving a
        half-sent response on the wire.
        """
        self._send_bytes(status, json.dumps(payload).encode('utf-8'))

    def _send_error_json(self, status, message):
        """Send *message* wrapped in the endpoint's error body shape."""
        self._send_json(status, [{'error': message, 'generated_text': ''}])

    def _read_json_body(self):
        """Read the request body and decode it as UTF-8 JSON.

        Returns:
            The decoded JSON value.

        Raises:
            ValueError: if Content-Length is missing, not an integer or
                negative, or if the body is not valid UTF-8 encoded JSON
                (``UnicodeDecodeError`` and ``json.JSONDecodeError`` are
                both ``ValueError`` subclasses).
        """
        raw_length = self.headers.get('Content-Length')
        if raw_length is None:
            raise ValueError('Missing Content-Length header')
        try:
            content_length = int(raw_length)
        except ValueError:
            raise ValueError(
                f'Invalid Content-Length header: {raw_length!r}'
            ) from None
        if content_length < 0:
            # rfile.read() with a negative size reads until EOF, i.e. it
            # would block until the client closes the connection.
            raise ValueError(f'Invalid Content-Length header: {raw_length!r}')
        post_data = self.rfile.read(content_length)
        return json.loads(post_data.decode('utf-8'))

    @staticmethod
    def _count_inputs(data):
        """Return the number of inputs in *data*; used for logging only.

        Never raises, so the log line cannot fail a request: payloads that
        are not dicts count as 0 inputs and an unsized ``inputs`` value
        counts as 1.
        """
        inputs = data.get('inputs', []) if isinstance(data, dict) else []
        try:
            return len(inputs)
        except TypeError:
            return 1

    def do_POST(self):
        """Run the JSON request body through ``handler`` and reply with JSON."""
        try:
            data = self._read_json_body()
        except ValueError as e:
            logger.warning('Rejected malformed request: %s', e)
            self._send_error_json(400, str(e))
            return

        try:
            logger.info(
                f'Received request with {self._count_inputs(data)} inputs'
            )
            # Process the request
            result = handler(data)
            # Encode before any header is written so that a result which is
            # not JSON-serializable yields a clean 500 instead of a 200
            # status line followed by an error.
            body = json.dumps(result).encode('utf-8')
        except Exception as e:
            # Top-level boundary: report every processing failure as a 500.
            logger.exception('Error processing request: %s', e)
            self._send_error_json(500, str(e))
            return

        self._send_bytes(200, body)

    def do_GET(self):
        """Serve the ``/health`` readiness probe; every other path is a 404."""
        if self.path != '/health':
            self.send_response(404)
            self.send_header('Content-Length', '0')
            self.end_headers()
            return

        # Trigger initialisation if needed. NOTE: this call is synchronous,
        # so the first health check waits until initialisation finishes or
        # fails.
        if not handler.initialized:
            try:
                handler._initialize_components()
            except Exception as e:
                logger.exception(
                    'Initialization failed during health check: %s', e
                )

        is_ready = handler.initialized
        health_response = {
            'status': 'healthy' if is_ready else 'unhealthy',
            'model_ready': is_ready,
        }
        self._send_json(200 if is_ready else 503, health_response)

    def log_message(self, format, *args):
        """Suppress the default per-request access log to keep output clean."""
        # The parameter name `format` mirrors the base-class signature.
        pass
if __name__ == "__main__":
    # Listen on all interfaces, port 80. The context manager closes the
    # listening socket when the serve loop exits for any reason.
    with HTTPServer(('0.0.0.0', 80), RequestHandler) as server:
        logger.info('HTTP server started on port 80')
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is a normal way to stop the server; exit without a
            # traceback.
            logger.info('HTTP server shutting down')