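# SletcherSystems assistant Space: a Gradio ChatInterface backed by the
# HuggingFaceH4/zephyr-7b-beta model via the Hugging Face Inference API,
# with answers grounded in the local data/site_content.json file.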
import gradio as gr
from huggingface_hub import InferenceClient
import json
from typing import Dict, List, Any
import time

# Initialize the client with retries
MAX_RETRIES = 3
RETRY_DELAY = 2


def create_client(retries=MAX_RETRIES):
    """Create inference client with retry logic."""
    for attempt in range(retries):
        try:
            return InferenceClient(
                "HuggingFaceH4/zephyr-7b-beta",
                timeout=30
            )
        except Exception as e:
            if attempt == retries - 1:
                print(f"Failed to create client after {retries} attempts: {e}")
                return None
            print(f"Attempt {attempt + 1} failed, retrying in {RETRY_DELAY} seconds...")
            time.sleep(RETRY_DELAY)


# Initialize the client
client = create_client()

def load_site_content() -> Dict[str, Any]:
    """Load the site content from JSON file."""
    try:
        with open("data/site_content.json", "r") as f:
            return json.load(f)
    except Exception as e:
        print(f"Error loading JSON: {e}")
        return {}
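
# Expected shape of data/site_content.json, inferred from the accessors in
# get_relevant_context below (an illustrative sketch only -- the real file may
# contain additional fields):
# {
#   "company_info": {"name": ..., "tagline": ..., "mission": ..., "vision": ...,
#                    "location": ..., "payment": ...,
#                    "leadership": {"ceo": {"name": ...}}},
#   "stats": {"<stat_key>": "<stat_text>", ...},
#   "services": [{"name": ..., "description": ..., "key_features": [...], "technologies": [...]}],
#   "solutions": [{"name": ..., "description": ..., "key_features": [...], "technologies": [...]}],
#   "specializations": {"<key>": {"name": ..., "core_features": [...], "key_technologies": [...]}}
# }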

def get_relevant_context(query: str, data: Dict[str, Any]) -> str:
    """Get relevant context based on the query keywords."""
    query = query.lower()
    context_parts = []

    # Company info for general queries
    if any(word in query for word in ['company', 'about', 'who', 'where', 'location', 'south africa', 'african', 'ceo', 'founder', 'wayne', 'sletcher']):
        info = data.get('company_info', {})
        leadership = info.get('leadership', {}).get('ceo', {})
        context_parts.append(f"""
Company Information:
- Name: {info.get('name', '')}
- {info.get('tagline', '')}
- CEO and Founder: {leadership.get('name', 'Wayne Sletcher')}
- Mission: {info.get('mission', '')}
- Vision: {info.get('vision', '')}
- Location: {info.get('location', '')}
- Payment Methods: {info.get('payment', '')}
""")

    # Stats for numbers/achievements queries
    if any(word in query for word in ['many', 'numbers', 'statistics', 'stats', 'achievements']):
        stats = data.get('stats', {})
        context_parts.append("\nAchievements and Statistics:")
        for value in stats.values():
            context_parts.append(f"- {value}")

    # Services for service-related queries
    if any(word in query for word in ['service', 'offer', 'provide', 'can you', 'capability']):
        services = data.get('services', [])
        context_parts.append("\nOur Services:")
        for service in services:
            context_parts.append(f"""
- {service.get('name', '')}
  Description: {service.get('description', '')}
  Key Features: {', '.join(service.get('key_features', []))}
  Technologies: {', '.join(service.get('technologies', []))}""")

    # Solutions for solution-related queries
    if any(word in query for word in ['solution', 'blockchain', 'development', 'network']):
        solutions = data.get('solutions', [])
        context_parts.append("\nOur Solutions:")
        for solution in solutions:
            context_parts.append(f"""
- {solution.get('name', '')}
  Description: {solution.get('description', '')}
  Key Features: {', '.join(solution.get('key_features', []))}
  Technologies: {', '.join(solution.get('technologies', []))}""")

    # Handle specialization queries
    if any(word in query for word in ['specialize', 'specialization', 'education', 'learning', 'game', 'games', 'development', 'rag', 'speech', 'tts', 'stt']):
        specs = data.get('specializations', {})
        context_parts.append("\nOur Specializations:")
        for spec_data in specs.values():
            context_parts.append(f"""
- {spec_data.get('name', '')}
  Features: {', '.join(spec_data.get('core_features', []))}
  Technologies: {', '.join(spec_data.get('key_technologies', []))}""")

    # Add payment information for relevant queries
    if any(word in query for word in ['payment', 'pay', 'cost', 'pricing', 'bank', 'bitcoin', 'eth', 'btc']):
        info = data.get('company_info', {})
        context_parts.append(f"\nPayment Information: {info.get('payment', '')}")

    # If no specific context matched, provide general information
    if not context_parts:
        info = data.get('company_info', {})
        context_parts.append(f"""
{info.get('name', '')} - {info.get('tagline', '')}
{info.get('mission', '')}
{info.get('location', '')}""")

    return "\n".join(context_parts)

def respond(
    message: str,
    history: List[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    global client

    # Ensure client is available
    if client is None:
        client = create_client()
        if client is None:
            yield "I apologize, but I'm having trouble connecting to the language model. Please try again in a moment."
            return

    # Load content
    content = load_site_content()
    if not content:
        yield "I apologize, but I'm having trouble accessing the company information. Please try again in a moment."
        return

    # Get relevant context
    context = get_relevant_context(message, content)

    # Enhanced system message with strict instructions
    enhanced_system_message = f"""{system_message}

IMPORTANT CONTEXT - USE THIS INFORMATION ONLY:
{context}

STRICT INSTRUCTIONS:
1. ONLY use information from the context provided above
2. If information isn't in the context, say "I don't have that specific information"
3. NEVER make assumptions about location - we are a proudly South African company
4. NEVER invent services or capabilities not listed
5. Be accurate about our AI and educational technology focus
6. Acknowledge our cryptocurrency acceptance when relevant
7. Use exact statistics when they're provided in the context
8. Always acknowledge Wayne Sletcher as CEO and Founder when relevant"""

    try:
        # Format conversation history
        messages = [{"role": "system", "content": enhanced_system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        # Stream the response
        response = ""
        for msg in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = msg.choices[0].delta.content
            # Some stream chunks carry no text (e.g. the final chunk); skip them
            # so we never concatenate None onto the response.
            if token:
                response += token
                yield response
    except Exception as e:
        print(f"Error in chat completion: {e}")
        # Try to recreate the client on error
        client = create_client()
        yield "I apologize, but I encountered an error. Please try your question again."

# Create the Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are the official AI assistant for SletcherSystems, a proudly South African technology company. Provide accurate, specific information based only on the provided context.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
        ),
    ],
    title="SletcherSystems AI Assistant",
    description="Welcome! I'm here to help you learn about SletcherSystems, a proudly South African technology company.",
    theme=gr.themes.Soft()
)
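
# Launch the app when this file is run as a script (e.g. `python app.py` locally);
# the Hugging Face Spaces runtime executes this file as well, so the guard below
# also applies there.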
if __name__ == "__main__":
    demo.launch()