Create model_logic.py

model_logic.py (ADDED, +402 -0)
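The new module centralizes provider configuration and streaming: API_KEYS and API_URLS map each provider to its key environment variable and endpoint, MODELS_BY_PROVIDER catalogs model display names and IDs, and generate_stream() adapts a shared chat-message list to each provider's streaming protocol.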
import os
import requests
import json
import logging
from collections.abc import Iterator

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Environment variable that holds each provider's API key.
API_KEYS = {
    "HUGGINGFACE": 'HF_TOKEN',
    "GROQ": 'GROQ_API_KEY',
    "OPENROUTER": 'OPENROUTER_API_KEY',
    "TOGETHERAI": 'TOGETHERAI_API_KEY',
    "COHERE": 'COHERE_API_KEY',
    "XAI": 'XAI_API_KEY',
    "OPENAI": 'OPENAI_API_KEY',
    "GOOGLE": 'GOOGLE_API_KEY',
}

API_URLS = {
    "HUGGINGFACE": 'https://api-inference.huggingface.co/models/',
    "GROQ": 'https://api.groq.com/openai/v1/chat/completions',
    "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions',
    "TOGETHERAI": 'https://api.together.ai/v1/chat/completions',
    "COHERE": 'https://api.cohere.ai/v1/chat',
    "XAI": 'https://api.x.ai/v1/chat/completions',
    "OPENAI": 'https://api.openai.com/v1/chat/completions',
    "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
}

# Display name -> model ID per provider, plus each provider's default model ID.
MODELS_BY_PROVIDER = {
    "groq": {
        "default": "llama3-8b-8192",
        "models": {
            "Llama 3 8B (Groq)": "llama3-8b-8192",
            "Llama 3 70B (Groq)": "llama3-70b-8192",
            "Mixtral 8x7B (Groq)": "mixtral-8x7b-32768",
            "Gemma 7B (Groq)": "gemma-7b-it",
        }
    },
    "openrouter": {
        "default": "nousresearch/llama-3-8b-instruct",
        "models": {
            "Nous Llama-3 8B Instruct (OpenRouter)": "nousresearch/llama-3-8b-instruct",
            "Mistral 7B Instruct v0.2 (OpenRouter)": "mistralai/mistral-7b-instruct:free",
            "Gemma 7B Instruct (OpenRouter)": "google/gemma-7b-it:free",
            "Mixtral 8x7B Instruct v0.1 (OpenRouter)": "mistralai/mixtral-8x7b-instruct",
            "Llama 2 70B Chat (OpenRouter)": "meta-llama/llama-2-70b-chat",
            "Neural Chat 7B v3.1 (OpenRouter)": "intel/neural-chat-7b-v3-1",
            "Goliath 120B (OpenRouter)": "twob/goliath-v2-120b",
        }
    },
    "togetherai": {
        "default": "meta-llama/Llama-3-8b-chat-hf",
        "models": {
            "Llama 3 8B Chat (TogetherAI)": "meta-llama/Llama-3-8b-chat-hf",
            "Llama 3 70B Chat (TogetherAI)": "meta-llama/Llama-3-70b-chat-hf",
            "Mixtral 8x7B Instruct (TogetherAI)": "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "Gemma 7B Instruct (TogetherAI)": "google/gemma-7b-it",
            "RedPajama INCITE Chat 3B (TogetherAI)": "togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        }
    },
    "google": {
        "default": "gemini-1.5-flash-latest",
        "models": {
            "Gemini 1.5 Flash (Latest)": "gemini-1.5-flash-latest",
            "Gemini 1.5 Pro (Latest)": "gemini-1.5-pro-latest",
        }
    },
    "cohere": {
        "default": "command-light",
        "models": {
            "Command R (Cohere)": "command-r",
            "Command R+ (Cohere)": "command-r-plus",
            "Command Light (Cohere)": "command-light",
            "Command (Cohere)": "command",
        }
    },
    "huggingface": {
        "default": "HuggingFaceH4/zephyr-7b-beta",
        "models": {
            "Zephyr 7B Beta (H4/HF Inf.)": "HuggingFaceH4/zephyr-7b-beta",
            "Mistral 7B Instruct v0.2 (HF Inf.)": "mistralai/Mistral-7B-Instruct-v0.2",
            "Llama 2 13B Chat (Meta/HF Inf.)": "meta-llama/Llama-2-13b-chat-hf",
            "OpenAssistant/oasst-sft-4-pythia-12b (HF Inf.)": "OpenAssistant/oasst-sft-4-pythia-12b",
        }
    },
    "openai": {
        "default": "gpt-3.5-turbo",
        "models": {
            "GPT-4o (OpenAI)": "gpt-4o",
            "GPT-4o mini (OpenAI)": "gpt-4o-mini",
            "GPT-4 Turbo (OpenAI)": "gpt-4-turbo",
            "GPT-3.5 Turbo (OpenAI)": "gpt-3.5-turbo",
        }
    },
    "xai": {
        "default": "grok-1",
        "models": {
            "Grok-1 (xAI)": "grok-1",
        }
    }
}

def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """Resolve an API key: UI override first, then the provider's environment variable."""
    if ui_api_key_override:
        return ui_api_key_override.strip()

    env_var_name = API_KEYS.get(provider.upper())
    if env_var_name:
        env_key = os.getenv(env_var_name)
        if env_key:
            return env_key.strip()

    if provider.lower() == 'huggingface':
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            return hf_token.strip()

    logger.warning(f"API Key not found for provider '{provider}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None

def get_available_providers() -> list[str]:
    return sorted(MODELS_BY_PROVIDER.keys())

def get_models_for_provider(provider: str) -> list[str]:
    return sorted(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys())

def get_default_model_for_provider(provider: str) -> str | None:
    """Return the display name of the provider's default model, falling back to the first name alphabetically."""
    models_dict = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    default_model_id = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("default")
    if default_model_id:
        for display_name, model_id in models_dict.items():
            if model_id == default_model_id:
                return display_name
    if models_dict:
        return sorted(models_dict.keys())[0]
    return None

def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return models.get(display_name)

def generate_stream(provider: str, model_display_name: str, api_key_override: str | None, messages: list[dict]) -> Iterator[str]:
    """Stream chat completions from the selected provider, yielding text chunks (or a single error string)."""
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)

    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)

    if not api_key:
        env_var_name = API_KEYS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI override or environment variable '{env_var_name}'."
        return
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
        yield f"Error: Unknown model '{model_display_name}' for provider '{provider}'. Please select a valid model."
        return

    headers = {}
    payload = {}
    request_url = base_url

    logger.info(f"Calling {provider}/{model_display_name} (ID: {model_id}) stream...")

    try:
        if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
            # These providers share the OpenAI-compatible chat-completions SSE protocol.
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            payload = {
                "model": model_id,
                "messages": messages,
                "stream": True
            }
            if provider_lower == "openrouter":
                headers["HTTP-Referer"] = os.getenv("SPACE_HOST") or "https://github.com/your_username/ai-space-builder"
                headers["X-Title"] = "AI Space Builder"

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                while b'\n' in byte_buffer:
                    line, byte_buffer = byte_buffer.split(b'\n', 1)
                    decoded_line = line.decode('utf-8', errors='ignore')
                    if decoded_line.startswith('data: '):
                        data = decoded_line[6:]
                        if data == '[DONE]':
                            byte_buffer = b''
                            break
                        try:
                            event_data = json.loads(data)
                            if event_data.get("choices") and len(event_data["choices"]) > 0:
                                delta = event_data["choices"][0].get("delta")
                                if delta and delta.get("content"):
                                    yield delta["content"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to decode JSON from stream line: {decoded_line}")
                        except Exception as e:
                            logger.error(f"Error processing stream data: {e}, Data: {decoded_line}")

            if byte_buffer:
                # Flush any final event left in the buffer when the stream closed without a trailing newline.
                remaining_line = byte_buffer.decode('utf-8', errors='ignore')
                if remaining_line.startswith('data: '):
                    data = remaining_line[6:]
                    if data != '[DONE]':
                        try:
                            event_data = json.loads(data)
                            if event_data.get("choices") and len(event_data["choices"]) > 0:
                                delta = event_data["choices"][0].get("delta")
                                if delta and delta.get("content"):
                                    yield delta["content"]
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to decode final stream buffer JSON: {remaining_line}")
                        except Exception as e:
                            logger.error(f"Error processing final stream buffer data: {e}, Data: {remaining_line}")

        elif provider_lower == "google":
            # Gemini uses its own schema: "system" maps to system_instruction, "assistant" to
            # role "model", and message text lives under "parts".
            system_instruction = None
            filtered_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    system_instruction = msg["content"]
                else:
                    role = "model" if msg["role"] == "assistant" else msg["role"]
                    filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

            payload = {
                "contents": filtered_messages,
                "safetySettings": [
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
                ],
                "generationConfig": {
                    "temperature": 0.7,
                }
            }
            if system_instruction:
                payload["system_instruction"] = {"parts": [{"text": system_instruction}]}

            request_url = f"{base_url}{model_id}:streamGenerateContent"
            headers = {"Content-Type": "application/json"}
            request_url = f"{request_url}?key={api_key}"

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                while b'\n' in byte_buffer:
                    line, byte_buffer = byte_buffer.split(b'\n', 1)
                    decoded_line = line.decode('utf-8', errors='ignore')

                    if decoded_line.startswith('data: '):
                        decoded_line = decoded_line[6:].strip()

                    if not decoded_line:
                        continue

                    try:
                        # streamGenerateContent returns fragments of a JSON array; wrapping the
                        # line in brackets lets json.loads accept a bare object.
                        event_data_list = json.loads(f"[{decoded_line}]")
                        if not isinstance(event_data_list, list):
                            event_data_list = [event_data_list]

                        for event_data in event_data_list:
                            if not isinstance(event_data, dict):
                                continue

                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
                                candidate = event_data["candidates"][0]
                                if candidate.get("content") and candidate["content"].get("parts"):
                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
                                    if full_text_chunk:
                                        yield full_text_chunk

                    except json.JSONDecodeError:
                        logger.warning(f"Failed to decode JSON from Google stream chunk: {decoded_line}")
                    except Exception as e:
                        logger.error(f"Error processing Google stream data: {e}, Data: {decoded_line}")

            if byte_buffer:
                remaining_line = byte_buffer.decode('utf-8', errors='ignore').strip()
                if remaining_line:
                    try:
                        event_data_list = json.loads(f"[{remaining_line}]")
                        if not isinstance(event_data_list, list):
                            event_data_list = [event_data_list]
                        for event_data in event_data_list:
                            if not isinstance(event_data, dict):
                                continue
                            if event_data.get("candidates") and len(event_data["candidates"]) > 0:
                                candidate = event_data["candidates"][0]
                                if candidate.get("content") and candidate["content"].get("parts"):
                                    full_text_chunk = "".join(part.get("text", "") for part in candidate["content"]["parts"])
                                    if full_text_chunk:
                                        yield full_text_chunk
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to decode final Google stream buffer JSON: {remaining_line}")
                    except Exception as e:
                        logger.error(f"Error processing final Google stream buffer data: {e}, Data: {remaining_line}")

        elif provider_lower == "cohere":
            # Cohere's /v1/chat expects the latest user turn as "message", earlier turns as
            # "chat_history" (assistant -> "chatbot"), and the system prompt as "preamble".
            headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
            request_url = base_url

            chat_history_for_cohere = []
            system_prompt_for_cohere = None
            current_message_for_cohere = ""

            temp_history = []
            for msg in messages:
                if msg["role"] == "system":
                    system_prompt_for_cohere = msg["content"]
                elif msg["role"] in ("user", "assistant"):
                    temp_history.append(msg)

            if temp_history:
                current_message_for_cohere = temp_history[-1]["content"]
                chat_history_for_cohere = [
                    {"role": ("chatbot" if m["role"] == "assistant" else m["role"]), "message": m["content"]}
                    for m in temp_history[:-1]
                ]

            if not current_message_for_cohere:
                yield "Error: User message not found for Cohere API call."
                return

            payload = {
                "model": model_id,
                "message": current_message_for_cohere,
                "stream": True,
                "temperature": 0.7
            }
            if chat_history_for_cohere:
                payload["chat_history"] = chat_history_for_cohere
            if system_prompt_for_cohere:
                payload["preamble"] = system_prompt_for_cohere

            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            byte_buffer = b""
            for chunk in response.iter_content(chunk_size=8192):
                byte_buffer += chunk
                # Cohere emits server-sent events separated by blank lines.
                while b'\n\n' in byte_buffer:
                    event_chunk, byte_buffer = byte_buffer.split(b'\n\n', 1)
                    lines = event_chunk.strip().split(b'\n')
                    event_type = None
                    event_data = None

                    for l in lines:
                        if l.startswith(b"event: "):
                            event_type = l[7:].strip().decode('utf-8', errors='ignore')
                        elif l.startswith(b"data: "):
                            try:
                                event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                            except json.JSONDecodeError:
                                logger.warning(f"Cohere: Failed to decode event data JSON: {l[6:].strip()}")

                    if event_type == "text-generation" and event_data and "text" in event_data:
                        yield event_data["text"]
                    elif event_type == "stream-end":
                        byte_buffer = b''
                        break

            if byte_buffer:
                # Handle a trailing event that arrived without the final blank-line separator.
                event_chunk = byte_buffer.strip()
                if event_chunk:
                    lines = event_chunk.split(b'\n')
                    event_type = None
                    event_data = None
                    for l in lines:
                        if l.startswith(b"event: "):
                            event_type = l[7:].strip().decode('utf-8', errors='ignore')
                        elif l.startswith(b"data: "):
                            try:
                                event_data = json.loads(l[6:].strip().decode('utf-8', errors='ignore'))
                            except json.JSONDecodeError:
                                logger.warning(f"Cohere: Failed to decode final event data JSON: {l[6:].strip()}")

                    if event_type == "text-generation" and event_data and "text" in event_data:
                        yield event_data["text"]

        elif provider_lower == "huggingface":
            yield ("Error: Direct Hugging Face Inference API streaming for chat models is experimental "
                   "and model-dependent. Consider using OpenRouter or TogetherAI for HF models with "
                   "standardized streaming.")
            return

        else:
            yield f"Error: Unsupported provider '{provider}' for streaming chat."
            return

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 'N/A'
        error_text = e.response.text if e.response is not None else 'No response text'
        logger.error(f"HTTP error during streaming for {provider}/{model_id}: {e}")
        yield f"API HTTP Error ({status_code}): {error_text}\nDetails: {e}"
    except requests.exceptions.RequestException as e:
        logger.error(f"Request error during streaming for {provider}/{model_id}: {e}")
        yield f"API Request Error: Could not connect or receive response from {provider} ({e})"
    except Exception as e:
        logger.exception(f"Unexpected error during streaming for {provider}/{model_id}:")
        yield f"An unexpected error occurred: {e}"
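
A minimal usage sketch follows (not part of the commit): it assumes model_logic.py is importable and that GROQ_API_KEY is set in the environment; the provider and model display name are illustrative picks from MODELS_BY_PROVIDER above.

# Hypothetical driver script, not included in this commit.
from model_logic import generate_stream

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]

# generate_stream yields text chunks as they arrive; error conditions are
# yielded as single "Error: ..." strings rather than raised.
for chunk in generate_stream("groq", "Llama 3 8B (Groq)", None, messages):
    print(chunk, end="", flush=True)
print()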