arthuroe commited on
Commit
e85f548
·
verified ·
1 Parent(s): 5561947

Create openrouter_llm.py

Browse files
Files changed (1) hide show
  1. openrouter_llm.py +428 -0
openrouter_llm.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import json
4
+ import requests
5
+ from typing import List, Dict, Any, Optional, Union
6
+
7
# Configure module-wide logging; basicConfig applies process-wide defaults.
logging.basicConfig(level=logging.INFO)
# Module-level logger used throughout this module.
logger = logging.getLogger(__name__)
10
+
11
+
12
class OpenRouterFreeAdapter:
    """Adapter for accessing only free LLMs through the OpenRouter.ai API.

    On construction the adapter queries OpenRouter for the list of free
    models, ranks them for document-QA suitability, and stores the winner
    in ``self.model``. ``generate()`` then calls the OpenAI-compatible
    ``/chat/completions`` endpoint with that model.
    """

    # Seconds to wait on any HTTP call. The original code passed no
    # timeout, so a stalled connection could hang the caller forever.
    REQUEST_TIMEOUT = 60

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: str = "https://openrouter.ai/api/v1"
    ):
        """
        Initialize the OpenRouter adapter for free models only.

        Args:
            api_key: OpenRouter API key. If None, will try to load it from
                the OPENROUTER_API_KEY environment variable.
            base_url: Base URL for the OpenRouter API.
        """
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        if not self.api_key:
            logger.warning(
                "No OpenRouter API key provided. Using limited free access.")

        self.base_url = base_url
        self.app_url = ""

        # App info sent as headers so OpenRouter can attribute traffic.
        self.app_name = os.getenv("APP_NAME", "AskMyDocs")

        # Selects and stores the best free model in self.model.
        self.update_best_free_model()

    def update_best_free_model(self) -> bool:
        """
        Find and set the best available free model.

        Returns:
            True if a model was selected from the live model list,
            False if the hard-coded fallback had to be used.
        """
        free_models = self.list_free_models()

        if not free_models:
            # API call failed or returned nothing: use a known free model.
            logger.warning(
                "Could not retrieve free models list. Using fallback models.")
            self.model = self._get_fallback_model()
            return False

        # Rank by family preference, then context length (see _rank_free_models).
        ranked_models = self._rank_free_models(free_models)

        if ranked_models:
            self.model = ranked_models[0]["id"]
            logger.info(f"Selected free model: {self.model}")
            return True

        self.model = self._get_fallback_model()
        logger.warning(
            f"No suitable free models found. Using fallback: {self.model}")
        return False

    def _rank_free_models(self, free_models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Rank free models by preference for document QA tasks.

        Preference tiers: Llama 4 > Gemini/Claude > Mistral/Mixtral >
        DeepSeek > other free models; IDs without a free tag sort last.
        Ties within a tier are broken by descending context length.

        Args:
            free_models: List of free model dictionaries.

        Returns:
            Sorted list of models by preference.
        """
        # Ordered so the first matching tier wins (tier 1 is best).
        tiers = [
            ["llama-4", "llama4"],
            ["gemini", "claude"],
            ["mistral", "mixtral"],
            ["deepseek"],
        ]

        def get_model_tier(model_id: str) -> int:
            """Map a model ID to its preference tier (lower is better)."""
            model_id_lower = model_id.lower()

            # Deprioritize anything not explicitly tagged as free.
            if ":free" not in model_id_lower and "-free" not in model_id_lower:
                return 99

            for tier, patterns in enumerate(tiers, start=1):
                if any(pattern in model_id_lower for pattern in patterns):
                    return tier

            return 5  # Other free models

        # Sort by tier, then by context length (longer context first).
        return sorted(
            free_models,
            key=lambda m: (get_model_tier(m["id"]), -m.get("context_length", 0)),
        )

    def _get_fallback_model(self) -> str:
        """
        Get a fallback model if API calls fail.

        Returns:
            Model ID string for a known free model (the most preferred
            entry of the hard-coded list below).
        """
        # Ordered by preference; only the first is used today, the rest
        # document known-good alternatives for manual substitution.
        fallback_models = [
            "meta-llama/llama-4-scout:free",
            "google/gemini-2.5-pro-exp-03-25:free",
            "mistralai/mistral-small-3.1-24b-instruct:free",
            "deepseek/deepseek-v3-base:free",
            "nousresearch/deephermes-3-llama-3-8b-preview:free",
            "huggingfaceh4/zephyr-7b-beta"  # Always fallback to this older but reliable one
        ]

        return fallback_models[0]

    def _get_headers(self) -> Dict[str, str]:
        """
        Get headers for OpenRouter API requests.

        Returns:
            Dictionary of headers including auth (when an API key is set)
            and app-attribution headers.
        """
        headers = {
            "Content-Type": "application/json"
        }

        # Add API key if available
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        # OpenRouter uses these for traffic attribution / app rankings.
        headers["HTTP-Referer"] = self.app_url
        headers["X-Title"] = self.app_name

        return headers

    def list_models(self) -> List[Dict[str, Any]]:
        """
        List available models on OpenRouter.

        Returns:
            List of model information dictionaries; empty list on any error.
        """
        try:
            response = requests.get(
                f"{self.base_url}/models",
                headers=self._get_headers(),
                # Bound the wait so a stalled connection cannot hang us.
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code == 200:
                return response.json().get("data", [])

            logger.error(
                f"Error listing models: {response.status_code} - {response.text}"
            )
            return []

        except Exception as e:
            # Best-effort: callers treat an empty list as "use fallback".
            logger.error(f"Exception listing models: {str(e)}")
            return []

    def list_free_models(self) -> List[Dict[str, Any]]:
        """
        List models that are free to use on OpenRouter.

        Returns:
            List of free model information dictionaries.
        """
        # Get all models
        models = self.list_models()

        # Filter for free models - looking for multiple indicators
        free_models = []
        for model in models:
            model_id = model.get("id", "").lower()
            pricing = model.get("pricing", {})

            # Check various indicators that a model is free
            is_free = False

            # Check for explicit free tag in model ID
            if ":free" in model_id or "-free" in model_id:
                is_free = True

            # Check for zero pricing
            elif (pricing.get("prompt") == 0 and pricing.get("completion") == 0):
                is_free = True

            # Check for free_tier indicator if present
            elif model.get("free_tier", False):
                is_free = True

            if is_free:
                free_models.append(model)

        # Log the number of free models found
        logger.info(f"Found {len(free_models)} free models on OpenRouter")

        return free_models

    def _handle_streaming_response(self, response):
        """
        Handle a server-sent-events streaming response from OpenRouter.

        Args:
            response: Response object from requests (opened with stream=True).

        Returns:
            Combined text from all streamed delta chunks.
        """
        result = ""

        for line in response.iter_lines():
            if not line:
                continue  # keep-alive / blank lines

            line_text = line.decode('utf-8')

            # Remove the SSE "data: " prefix
            if line_text.startswith("data: "):
                line_text = line_text[6:]

            # Sentinel marking the end of the stream
            if line_text.strip() == "[DONE]":
                break

            try:
                json_data = json.loads(line_text)

                # Accumulate incremental content deltas
                if "choices" in json_data and json_data["choices"]:
                    delta = json_data["choices"][0].get("delta", {})
                    if "content" in delta:
                        result += delta["content"]
            except json.JSONDecodeError:
                # Non-JSON lines (comments, partial frames) are ignored
                pass

        return result

    def generate(
        self,
        prompt: str,
        temperature: float = 0.0,
        max_tokens: int = 1000,
        stream: bool = False
    ) -> str:
        """
        Generate text using OpenRouter API with a free model.

        Args:
            prompt: The prompt to send to the model.
            temperature: Controls randomness. Lower is more deterministic.
            max_tokens: Maximum number of tokens to generate.
            stream: Whether to stream the response.

        Returns:
            Generated text from the model, or an "Error: ..." string.
        """
        # Ensure we have a model selected (getattr guards partially
        # constructed instances that never reached __init__'s selection).
        if not getattr(self, "model", None):
            self.update_best_free_model()

        # If still no model, return error
        if not self.model:
            return "Error: No free models available on OpenRouter."

        try:
            # Use OpenAI-compatible format for the request
            payload = {
                "model": self.model,
                "messages": [
                    {"role": "user", "content": prompt}
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
                "stream": stream
            }

            # stream=stream makes requests yield the body incrementally;
            # the original omitted it, downloading the whole body before
            # _handle_streaming_response ever ran.
            response = requests.post(
                f"{self.base_url}/chat/completions",
                headers=self._get_headers(),
                json=payload,
                stream=stream,
                timeout=self.REQUEST_TIMEOUT,
            )

            if response.status_code == 200:
                if stream:
                    return self._handle_streaming_response(response)

                # Parse the body once and reuse it (was parsed twice).
                data = response.json()
                content = data["choices"][0]["message"]["content"]

                # Log model usage for tracking
                usage = data.get("usage", {})
                logger.info(
                    f"Used model {self.model} - Input: {usage.get('prompt_tokens', 0)}, Output: {usage.get('completion_tokens', 0)}")
                return content

            error_info = f"Error {response.status_code}"
            try:
                error_detail = response.json()
                error_message = error_detail.get(
                    "error", {}).get("message", "Unknown error")
                error_info = f"{error_info}: {error_message}"
            except Exception:
                # Body was not JSON; fall back to the raw text.
                error_info = f"{error_info}: {response.text}"

            logger.error(f"Error generating text: {error_info}")

            # Check for specific error cases
            if "rate limit" in error_info.lower():
                return "Error: Rate limit exceeded for this free model. Please try again later or try a different model."

            # If there's an issue with the model, try to get a different one
            if "model" in error_info.lower() or "no endpoints" in error_info.lower():
                prev_model = self.model
                # Only retry when selection actually changed, preventing
                # an immediate retry loop on the same failing model.
                if self.update_best_free_model() and self.model != prev_model:
                    logger.info(
                        f"Retrying with different free model: {self.model}")
                    return self.generate(prompt, temperature, max_tokens, stream)

            return f"Error: Failed to generate response. {error_info}"

        except Exception as e:
            logger.error(f"Exception during text generation: {str(e)}")
            return f"Error: {str(e)}"
362
+
363
+
364
class OpenRouterFreeChain:
    """Chain for handling Q&A with OpenRouter free LLMs"""

    def __init__(self, adapter: OpenRouterFreeAdapter):
        """
        Initialize the OpenRouter free chain.

        Args:
            adapter: An initialized OpenRouterFreeAdapter.
        """
        self.adapter = adapter

    def create_prompt(self, query: str, context: List[str]) -> str:
        """
        Build the LLM prompt from the user's question and document context.

        Args:
            query: The user's question.
            context: List of document contents to provide as context.

        Returns:
            Formatted prompt string.
        """
        # Number each document and separate them with blank lines.
        numbered_docs = (
            f"Document {idx}:\n{text}"
            for idx, text in enumerate(context, start=1)
        )
        context_str = "\n\n".join(numbered_docs)

        return f"""You are an AI assistant answering questions based on the provided documents.

Context information:
{context_str}

Based on the above context, please answer the following question:
{query}

If the information to answer the question is not contained in the provided documents, respond with: "I don't have enough information in the provided documents to answer this question."

Answer:"""

    def run(self, query: str, context: List[str]) -> str:
        """
        Run the chain to get an answer from the adapter's model.

        Args:
            query: The user's question.
            context: List of document contents to provide as context.

        Returns:
            Answer from the model.
        """
        return self.adapter.generate(self.create_prompt(query, context))
417
+
418
+
419
def get_best_free_model() -> str:
    """
    Get the best available free model from OpenRouter.

    Returns:
        Model ID string for the recommended free model.
    """
    # OpenRouterFreeAdapter.__init__ already runs update_best_free_model(),
    # so the previous explicit second call performed a redundant extra
    # round-trip to the OpenRouter /models endpoint.
    adapter = OpenRouterFreeAdapter()
    return adapter.model