sachin commited on
Commit
4eda5de
·
1 Parent(s): 43e8ee0
Files changed (3) hide show
  1. Dockerfile +0 -1
  2. src/server/main-v2.py +855 -0
  3. src/server/main.py +713 -352
Dockerfile CHANGED
@@ -6,6 +6,5 @@ COPY . .
6
  ENV HF_HOME=/data/huggingface
7
  # Expose port
8
  EXPOSE 7860
9
- RUN pip install tenacity
10
  # Start the server
11
  CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
 
6
  ENV HF_HOME=/data/huggingface
7
  # Expose port
8
  EXPOSE 7860
 
9
  # Start the server
10
  CMD ["python", "/app/src/server/main.py", "--host", "0.0.0.0", "--port", "7860", "--config", "config_two"]
src/server/main-v2.py ADDED
@@ -0,0 +1,855 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ from time import time
5
+ from typing import List, Dict
6
+ import tempfile
7
+ import uvicorn
8
+ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form, APIRouter
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
+ from PIL import Image
12
+ from pydantic import BaseModel, field_validator
13
+ from pydantic_settings import BaseSettings
14
+ from slowapi import Limiter
15
+ from slowapi.util import get_remote_address
16
+ import torch
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
18
+ from IndicTransToolkit import IndicProcessor
19
+ import json
20
+ import asyncio
21
+ from contextlib import asynccontextmanager
22
+ import soundfile as sf
23
+ import numpy as np
24
+ import requests
25
+ from starlette.responses import StreamingResponse
26
+ from logging_config import logger
27
+ from tts_config import SPEED, ResponseFormat, config as tts_config
28
+ import torchaudio
29
+ from tenacity import retry, stop_after_attempt, wait_exponential
30
+ from torch.cuda.amp import autocast
31
+
32
+ # Device setup
33
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
34
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
35
+ logger.info(f"{'GPU' if device != 'cpu' else 'CPU'} will be used for inference")
36
+
37
+ # Check CUDA availability and version
38
+ cuda_available = torch.cuda.is_available()
39
+ cuda_version = torch.version.cuda if cuda_available else None
40
+ if cuda_available:
41
+ device_idx = torch.cuda.current_device()
42
+ capability = torch.cuda.get_device_capability(device_idx)
43
+ compute_capability_float = float(f"{capability[0]}.{capability[1]}")
44
+ print(f"CUDA version: {cuda_version}")
45
+ print(f"CUDA Compute Capability: {compute_capability_float}")
46
+ else:
47
+ print("CUDA is not available on this system.")
48
+
49
+ # Settings
50
+ class Settings(BaseSettings):
51
+ llm_model_name: str = "google/gemma-3-4b-it"
52
+ max_tokens: int = 512
53
+ host: str = "0.0.0.0"
54
+ port: int = 7860
55
+ chat_rate_limit: str = "100/minute"
56
+ speech_rate_limit: str = "5/minute"
57
+
58
+ @field_validator("chat_rate_limit", "speech_rate_limit")
59
+ def validate_rate_limit(cls, v):
60
+ if not v.count("/") == 1 or not v.split("/")[0].isdigit():
61
+ raise ValueError("Rate limit must be in format 'number/period' (e.g., '5/minute')")
62
+ return v
63
+
64
+ class Config:
65
+ env_file = ".env"
66
+
67
+ settings = Settings()
68
+
69
+ # Quantization config for LLM
70
+ quantization_config = BitsAndBytesConfig(
71
+ load_in_4bit=True,
72
+ bnb_4bit_quant_type="nf4",
73
+ bnb_4bit_use_double_quant=True,
74
+ bnb_4bit_compute_dtype=torch.bfloat16
75
+ )
76
+
77
+ # Request queue for concurrency control
78
+ request_queue = asyncio.Queue(maxsize=10)
79
+
80
+ # LLM Manager with batching
81
+ class LLMManager:
82
+ def __init__(self, model_name: str, device: str = device):
83
+ self.model_name = model_name
84
+ self.device = torch.device(device)
85
+ self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
+ self.model = None
87
+ self.processor = None
88
+ self.is_loaded = False
89
+ self.token_cache = {}
90
+ self.load()
91
+ logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
92
+
93
+ def load(self):
94
+ if not self.is_loaded:
95
+ try:
96
+ self.model = Gemma3ForConditionalGeneration.from_pretrained(
97
+ self.model_name,
98
+ device_map="auto",
99
+ quantization_config=quantization_config,
100
+ torch_dtype=self.torch_dtype
101
+ ).eval()
102
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
103
+ dummy_input = self.processor("test", return_tensors="pt").to(self.device)
104
+ with torch.no_grad():
105
+ self.model.generate(**dummy_input, max_new_tokens=10)
106
+ self.is_loaded = True
107
+ logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
108
+ except Exception as e:
109
+ logger.error(f"Failed to load LLM: {str(e)}")
110
+ raise
111
+
112
+ def unload(self):
113
+ if self.is_loaded:
114
+ del self.model
115
+ del self.processor
116
+ if self.device.type == "cuda":
117
+ torch.cuda.empty_cache()
118
+ logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
119
+ self.is_loaded = False
120
+ self.token_cache.clear()
121
+ logger.info(f"LLM {self.model_name} unloaded from {self.device}")
122
+
123
+ async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
124
+ if not self.is_loaded:
125
+ self.load()
126
+
127
+ cache_key = f"{prompt}:{max_tokens}:{temperature}"
128
+ if cache_key in self.token_cache:
129
+ logger.info("Using cached response")
130
+ return self.token_cache[cache_key]
131
+
132
+ future = asyncio.Future()
133
+ await request_queue.put({"prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "future": future})
134
+ response = await future
135
+ self.token_cache[cache_key] = response
136
+ logger.info(f"Generated response: {response}")
137
+ return response
138
+
139
+ async def batch_generate(self, prompts: List[Dict]) -> List[str]:
140
+ messages_batch = [
141
+ [
142
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
143
+ {"role": "user", "content": [{"type": "text", "text": prompt["prompt"]}]}
144
+ ]
145
+ for prompt in prompts
146
+ ]
147
+ try:
148
+ inputs_vlm = self.processor.apply_chat_template(
149
+ messages_batch,
150
+ add_generation_prompt=True,
151
+ tokenize=True,
152
+ return_dict=True,
153
+ return_tensors="pt",
154
+ padding=True
155
+ ).to(self.device, dtype=torch.bfloat16)
156
+
157
+ with autocast(), torch.no_grad():
158
+ outputs = self.model.generate(
159
+ **inputs_vlm,
160
+ max_new_tokens=max(prompt["max_tokens"] for prompt in prompts),
161
+ do_sample=True,
162
+ top_p=0.9,
163
+ temperature=max(prompt["temperature"] for prompt in prompts)
164
+ )
165
+ responses = [
166
+ self.processor.decode(output[input_len:], skip_special_tokens=True)
167
+ for output, input_len in zip(outputs, inputs_vlm["input_ids"].shape[1])
168
+ ]
169
+ logger.info(f"Batch generated {len(responses)} responses")
170
+ return responses
171
+ except Exception as e:
172
+ logger.error(f"Error in batch generation: {str(e)}")
173
+ raise HTTPException(status_code=500, detail=f"Batch generation failed: {str(e)}")
174
+
175
+ async def vision_query(self, image: Image.Image, query: str) -> str:
176
+ if not self.is_loaded:
177
+ self.load()
178
+
179
+ messages_vlm = [
180
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]},
181
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
182
+ ]
183
+
184
+ try:
185
+ inputs_vlm = self.processor.apply_chat_template(
186
+ messages_vlm,
187
+ add_generation_prompt=True,
188
+ tokenize=True,
189
+ return_dict=True,
190
+ return_tensors="pt"
191
+ ).to(self.device, dtype=torch.bfloat16)
192
+ except Exception as e:
193
+ logger.error(f"Error in apply_chat_template: {str(e)}")
194
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
195
+
196
+ input_len = inputs_vlm["input_ids"].shape[-1]
197
+ with torch.inference_mode():
198
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
199
+ generation = generation[0][input_len:]
200
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
201
+ logger.info(f"Vision query response: {decoded}")
202
+ return decoded
203
+
204
+ async def chat_v2(self, image: Image.Image, query: str) -> str:
205
+ if not self.is_loaded:
206
+ self.load()
207
+
208
+ messages_vlm = [
209
+ {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]},
210
+ {"role": "user", "content": [{"type": "text", "text": query}] + ([{"type": "image", "image": image}] if image and image.size[0] > 0 and image.size[1] > 0 else [])}
211
+ ]
212
+
213
+ try:
214
+ inputs_vlm = self.processor.apply_chat_template(
215
+ messages_vlm,
216
+ add_generation_prompt=True,
217
+ tokenize=True,
218
+ return_dict=True,
219
+ return_tensors="pt"
220
+ ).to(self.device, dtype=torch.bfloat16)
221
+ except Exception as e:
222
+ logger.error(f"Error in apply_chat_template: {str(e)}")
223
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
224
+
225
+ input_len = inputs_vlm["input_ids"].shape[-1]
226
+ with torch.inference_mode():
227
+ generation = self.model.generate(**inputs_vlm, max_new_tokens=512, do_sample=True, temperature=0.7)
228
+ generation = generation[0][input_len:]
229
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
230
+ logger.info(f"Chat_v2 response: {decoded}")
231
+ return decoded
232
+
233
+ # TTS Manager
234
+ class TTSManager:
235
+ def __init__(self, device_type=device):
236
+ self.device_type = torch.device(device_type)
237
+ self.model = None
238
+ self.repo_id = "ai4bharat/IndicF5"
239
+ self.load()
240
+
241
+ def load(self):
242
+ if not self.model:
243
+ logger.info("Loading TTS model IndicF5...")
244
+ self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
245
+ logger.info("TTS model IndicF5 loaded")
246
+
247
+ def synthesize(self, text, ref_audio_path, ref_text):
248
+ if not self.model:
249
+ raise ValueError("TTS model not loaded")
250
+ with autocast():
251
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
252
+
253
+ # Translation Manager
254
+ class TranslateManager:
255
+ def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
256
+ self.device_type = torch.device(device_type)
257
+ self.tokenizer = None
258
+ self.model = None
259
+ self.src_lang = src_lang
260
+ self.tgt_lang = tgt_lang
261
+ self.use_distilled = use_distilled
262
+ self.load()
263
+
264
+ def load(self):
265
+ if not self.tokenizer or not self.model:
266
+ if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
267
+ model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
268
+ elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
269
+ model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
270
+ elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
271
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
272
+ else:
273
+ raise ValueError("Invalid language combination")
274
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
275
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
276
+ model_name,
277
+ trust_remote_code=True,
278
+ torch_dtype=torch.float16,
279
+ attn_implementation="flash_attention_2"
280
+ ).to(self.device_type)
281
+ self.model = torch.compile(self.model, mode="reduce-overhead")
282
+ logger.info(f"Translation model {model_name} loaded")
283
+
284
+ # Model Manager
285
+ class ModelManager:
286
+ def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
287
+ self.models = {}
288
+ self.device_type = device_type
289
+ self.use_distilled = use_distilled
290
+ self.is_lazy_loading = is_lazy_loading
291
+
292
+ def load_model(self, src_lang, tgt_lang, key):
293
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
294
+ translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
295
+ self.models[key] = translate_manager
296
+ logger.info(f"Loaded translation model for {key}")
297
+
298
+ def get_model(self, src_lang, tgt_lang):
299
+ key = self._get_model_key(src_lang, tgt_lang)
300
+ if key not in self.models and self.is_lazy_loading:
301
+ self.load_model(src_lang, tgt_lang, key)
302
+ return self.models.get(key) or (self.load_model(src_lang, tgt_lang, key) or self.models[key])
303
+
304
+ def _get_model_key(self, src_lang, tgt_lang):
305
+ if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
306
+ return 'eng_indic'
307
+ elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
308
+ return 'indic_eng'
309
+ elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
310
+ return 'indic_indic'
311
+ raise ValueError("Invalid language combination")
312
+
313
+ # ASR Manager
314
+ class ASRModelManager:
315
+ def __init__(self, device_type=device):
316
+ self.device_type = torch.device(device_type)
317
+ self.model = None
318
+ self.model_language = {"kannada": "kn"}
319
+ self.load()
320
+
321
+ def load(self):
322
+ if not self.model:
323
+ logger.info("Loading ASR model...")
324
+ self.model = AutoModel.from_pretrained(
325
+ "ai4bharat/indic-conformer-600m-multilingual",
326
+ trust_remote_code=True
327
+ ).to(self.device_type)
328
+ logger.info("ASR model loaded")
329
+
330
+ # Global Managers
331
+ llm_manager = LLMManager(settings.llm_model_name)
332
+ model_manager = ModelManager()
333
+ asr_manager = ASRModelManager()
334
+ tts_manager = TTSManager()
335
+ ip = IndicProcessor(inference=True)
336
+
337
+ # TTS Constants
338
+ EXAMPLES = [
339
+ {
340
+ "audio_name": "KAN_F (Happy)",
341
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
342
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
343
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
344
+ },
345
+ ]
346
+
347
+ # Pydantic Models
348
+ class SynthesizeRequest(BaseModel):
349
+ text: str
350
+ ref_audio_name: str
351
+ ref_text: str = None
352
+
353
+ class KannadaSynthesizeRequest(BaseModel):
354
+ text: str
355
+
356
+ class ChatRequest(BaseModel):
357
+ prompt: str
358
+ src_lang: str = "kan_Knda"
359
+ tgt_lang: str = "kan_Knda"
360
+
361
+ @field_validator("prompt")
362
+ def prompt_must_be_valid(cls, v):
363
+ if len(v) > 1000:
364
+ raise ValueError("Prompt cannot exceed 1000 characters")
365
+ return v.strip()
366
+
367
+ @field_validator("src_lang", "tgt_lang")
368
+ def validate_language(cls, v):
369
+ if v not in SUPPORTED_LANGUAGES:
370
+ raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
371
+ return v
372
+
373
+ class ChatResponse(BaseModel):
374
+ response: str
375
+
376
+ class TranslationRequest(BaseModel):
377
+ sentences: List[str]
378
+ src_lang: str
379
+ tgt_lang: str
380
+
381
+ class TranscriptionResponse(BaseModel):
382
+ text: str
383
+
384
+ class TranslationResponse(BaseModel):
385
+ translations: List[str]
386
+
387
+ # TTS Functions
388
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
389
+ def load_audio_from_url(url: str):
390
+ response = requests.get(url)
391
+ if response.status_code == 200:
392
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
393
+ return sample_rate, audio_data
394
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
395
+
396
+ async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
397
+ async with request_queue:
398
+ ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
399
+ if not ref_audio_url:
400
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
401
+ if not text.strip() or not ref_text.strip():
402
+ raise HTTPException(status_code=400, detail="Text or reference text cannot be empty.")
403
+
404
+ logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
405
+ loop = asyncio.get_running_loop()
406
+ sample_rate, audio_data = await loop.run_in_executor(None, load_audio_from_url, ref_audio_url)
407
+
408
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio:
409
+ await loop.run_in_executor(None, sf.write, temp_audio.name, audio_data, sample_rate, "WAV")
410
+ temp_audio.flush()
411
+ audio = tts_manager.synthesize(text, temp_audio.name, ref_text)
412
+
413
+ buffer = io.BytesIO()
414
+ await loop.run_in_executor(None, sf.write, buffer, audio.astype(np.float32) / 32768.0 if audio.dtype == np.int16 else audio, 24000, "WAV")
415
+ buffer.seek(0)
416
+ logger.info("Speech synthesis completed")
417
+ return buffer
418
+
419
+ # Supported Languages
420
+ SUPPORTED_LANGUAGES = {
421
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
422
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
423
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
424
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
425
+ "kan_Knda", "ory_Orya",
426
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
427
+ "por_Latn", "rus_Cyrl", "pol_Latn"
428
+ }
429
+
430
+ # Dependency
431
+ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
432
+ return model_manager.get_model(src_lang, tgt_lang)
433
+
434
+ # Translation Function
435
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
436
+ try:
437
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
438
+ except ValueError as e:
439
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
440
+ key = model_manager._get_model_key(src_lang, tgt_lang)
441
+ model_manager.load_model(src_lang, tgt_lang, key)
442
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
443
+
444
+ if not translate_manager.model:
445
+ translate_manager.load()
446
+
447
+ batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
448
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
449
+ with torch.no_grad(), autocast():
450
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
451
+ with translate_manager.tokenizer.as_target_tokenizer():
452
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
453
+ return ip.postprocess_batch(generated_tokens, lang=tgt_lang)
454
+
455
+ # Lifespan Event Handler
456
+ translation_configs = []
457
+
458
+ @asynccontextmanager
459
+ async def lifespan(app: FastAPI):
460
+ def load_all_models():
461
+ logger.info("Loading LLM model...")
462
+ llm_manager.load()
463
+ logger.info("Loading TTS model...")
464
+ tts_manager.load()
465
+ logger.info("Loading ASR model...")
466
+ asr_manager.load()
467
+ translation_tasks = [
468
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
469
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
470
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
471
+ ]
472
+ for config in translation_configs:
473
+ src_lang = config["src_lang"]
474
+ tgt_lang = config["tgt_lang"]
475
+ key = model_manager._get_model_key(src_lang, tgt_lang)
476
+ translation_tasks.append((src_lang, tgt_lang, key))
477
+ for src_lang, tgt_lang, key in translation_tasks:
478
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
479
+ model_manager.load_model(src_lang, tgt_lang, key)
480
+ logger.info("All models loaded successfully")
481
+
482
+ logger.info("Starting sequential model loading...")
483
+ load_all_models()
484
+ batch_task = asyncio.create_task(batch_worker())
485
+ yield
486
+ batch_task.cancel()
487
+ llm_manager.unload()
488
+ logger.info("Server shutdown complete")
489
+
490
+ # Batch Worker
491
+ async def batch_worker():
492
+ while True:
493
+ batch = []
494
+ last_request_time = time()
495
+ try:
496
+ while len(batch) < 4:
497
+ try:
498
+ request = await asyncio.wait_for(request_queue.get(), timeout=1.0)
499
+ batch.append(request)
500
+ current_time = time()
501
+ if current_time - last_request_time > 1.0 and batch:
502
+ break
503
+ last_request_time = current_time
504
+ except asyncio.TimeoutError:
505
+ if batch:
506
+ break
507
+ continue
508
+ if batch:
509
+ start_time = time()
510
+ responses = await llm_manager.batch_generate(batch)
511
+ duration = time() - start_time
512
+ logger.info(f"Batch of {len(batch)} requests processed in {duration:.3f} seconds")
513
+ for request, response in zip(batch, responses):
514
+ request["future"].set_result(response)
515
+ except Exception as e:
516
+ logger.error(f"Batch worker error: {str(e)}")
517
+ for request in batch:
518
+ request["future"].set_exception(e)
519
+
520
+ # FastAPI App
521
+ app = FastAPI(
522
+ title="Dhwani API",
523
+ description="AI Chat API supporting Indian languages",
524
+ version="1.0.0",
525
+ redirect_slashes=False,
526
+ lifespan=lifespan
527
+ )
528
+
529
+ app.add_middleware(
530
+ CORSMiddleware,
531
+ allow_origins=["*"],
532
+ allow_credentials=False,
533
+ allow_methods=["*"],
534
+ allow_headers=["*"],
535
+ )
536
+
537
+ @app.middleware("http")
538
+ async def add_request_timing(request: Request, call_next):
539
+ start_time = time()
540
+ response = await call_next(request)
541
+ duration = time() - start_time
542
+ logger.info(f"Request to {request.url.path} took {duration:.3f} seconds")
543
+ response.headers["X-Response-Time"] = f"{duration:.3f}"
544
+ return response
545
+
546
+ limiter = Limiter(key_func=get_remote_address)
547
+ app.state.limiter = limiter
548
+
549
+ # Endpoints
550
+ @app.post("/audio/speech", response_class=StreamingResponse)
551
+ async def synthesize_kannada(request: KannadaSynthesizeRequest):
552
+ if not tts_manager.model:
553
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
554
+ kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
555
+ if not request.text.strip():
556
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
557
+ audio_buffer = await synthesize_speech(tts_manager, request.text, "KAN_F (Happy)", kannada_example["ref_text"])
558
+ return StreamingResponse(
559
+ audio_buffer,
560
+ media_type="audio/wav",
561
+ headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
562
+ )
563
+
564
+ @app.post("/translate", response_model=TranslationResponse)
565
+ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
566
+ if not request.sentences:
567
+ raise HTTPException(status_code=400, detail="Input sentences are required")
568
+ batch = ip.preprocess_batch(request.sentences, src_lang=request.src_lang, tgt_lang=request.tgt_lang)
569
+ inputs = translate_manager.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt", return_attention_mask=True).to(translate_manager.device_type)
570
+ with torch.no_grad(), autocast():
571
+ generated_tokens = translate_manager.model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5, num_return_sequences=1)
572
+ with translate_manager.tokenizer.as_target_tokenizer():
573
+ generated_tokens = translate_manager.tokenizer.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
574
+ translations = ip.postprocess_batch(generated_tokens, lang=request.tgt_lang)
575
+ return TranslationResponse(translations=translations)
576
+
577
+ @app.get("/v1/health")
578
+ async def health_check():
579
+ memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0
580
+ llm_status = "unhealthy"
581
+ llm_latency = None
582
+ if llm_manager.is_loaded:
583
+ start = time()
584
+ try:
585
+ llm_test = await llm_manager.generate("What is the capital of Karnataka?", max_tokens=10)
586
+ llm_latency = time() - start
587
+ llm_status = "healthy" if llm_test else "unhealthy"
588
+ except Exception as e:
589
+ logger.error(f"LLM health check failed: {str(e)}")
590
+
591
+ tts_status = "unhealthy"
592
+ tts_latency = None
593
+ if tts_manager.model:
594
+ start = time()
595
+ try:
596
+ audio_buffer = await synthesize_speech(tts_manager, "Test", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
597
+ tts_latency = time() - start
598
+ tts_status = "healthy" if audio_buffer else "unhealthy"
599
+ except Exception as e:
600
+ logger.error(f"TTS health check failed: {str(e)}")
601
+
602
+ asr_status = "unhealthy"
603
+ asr_latency = None
604
+ if asr_manager.model:
605
+ start = time()
606
+ try:
607
+ dummy_audio = np.zeros(16000, dtype=np.float32)
608
+ wav = torch.tensor(dummy_audio).unsqueeze(0).to(device)
609
+ with autocast(), torch.no_grad():
610
+ asr_test = asr_manager.model(wav, asr_manager.model_language["kannada"], "rnnt")
611
+ asr_latency = time() - start
612
+ asr_status = "healthy" if asr_test else "unhealthy"
613
+ except Exception as e:
614
+ logger.error(f"ASR health check failed: {str(e)}")
615
+
616
+ status = {
617
+ "status": "healthy" if llm_status == "healthy" and tts_status == "healthy" and asr_status == "healthy" else "degraded",
618
+ "model": settings.llm_model_name,
619
+ "llm_status": llm_status,
620
+ "llm_latency": f"{llm_latency:.3f}s" if llm_latency else "N/A",
621
+ "tts_status": tts_status,
622
+ "tts_latency": f"{tts_latency:.3f}s" if tts_latency else "N/A",
623
+ "asr_status": asr_status,
624
+ "asr_latency": f"{asr_latency:.3f}s" if asr_latency else "N/A",
625
+ "gpu_memory_usage": f"{memory_usage:.2%}"
626
+ }
627
+ logger.info("Health check completed")
628
+ return status
629
+
630
+ @app.get("/")
631
+ async def home():
632
+ return RedirectResponse(url="/docs")
633
+
634
+ @app.post("/v1/unload_all_models")
635
+ async def unload_all_models():
636
+ try:
637
+ logger.info("Starting to unload all models...")
638
+ llm_manager.unload()
639
+ logger.info("All models unloaded successfully")
640
+ return {"status": "success", "message": "All models unloaded"}
641
+ except Exception as e:
642
+ logger.error(f"Error unloading models: {str(e)}")
643
+ raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
644
+
645
+ @app.post("/v1/load_all_models")
646
+ async def load_all_models():
647
+ try:
648
+ logger.info("Starting to load all models...")
649
+ llm_manager.load()
650
+ logger.info("All models loaded successfully")
651
+ return {"status": "success", "message": "All models loaded"}
652
+ except Exception as e:
653
+ logger.error(f"Error loading models: {str(e)}")
654
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
655
+
656
+ @app.post("/v1/translate", response_model=TranslationResponse)
657
+ async def translate_endpoint(request: TranslationRequest):
658
+ logger.info(f"Received translation request: {request.dict()}")
659
+ try:
660
+ translations = await perform_internal_translation(request.sentences, request.src_lang, request.tgt_lang)
661
+ logger.info(f"Translation successful: {translations}")
662
+ return TranslationResponse(translations=translations)
663
+ except Exception as e:
664
+ logger.error(f"Unexpected error during translation: {str(e)}")
665
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
666
+
667
+ @app.post("/v1/chat", response_model=ChatResponse)
668
+ @limiter.limit(settings.chat_rate_limit)
669
+ async def chat(request: Request, chat_request: ChatRequest):
670
+ if not chat_request.prompt:
671
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
672
+ logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
673
+ EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
674
+ try:
675
+ if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
676
+ translated_prompt = await perform_internal_translation([chat_request.prompt], chat_request.src_lang, "eng_Latn")
677
+ prompt_to_process = translated_prompt[0]
678
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
679
+ else:
680
+ prompt_to_process = chat_request.prompt
681
+ logger.info("Prompt in English or European language, no translation needed")
682
+
683
+ response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
684
+ logger.info(f"Generated English response: {response}")
685
+
686
+ if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
687
+ translated_response = await perform_internal_translation([response], "eng_Latn", chat_request.tgt_lang)
688
+ final_response = translated_response[0]
689
+ logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
690
+ else:
691
+ final_response = response
692
+ logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
693
+ return ChatResponse(response=final_response)
694
+ except Exception as e:
695
+ logger.error(f"Error processing request: {str(e)}")
696
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
697
+
698
+ @app.post("/v1/visual_query/")
699
+ async def visual_query(
700
+ file: UploadFile = File(...),
701
+ query: str = Body(...),
702
+ src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
703
+ tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
704
+ ):
705
+ try:
706
+ image = Image.open(file.file)
707
+ if image.size == (0, 0):
708
+ raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
709
+ if src_lang != "eng_Latn":
710
+ translated_query = await perform_internal_translation([query], src_lang, "eng_Latn")
711
+ query_to_process = translated_query[0]
712
+ logger.info(f"Translated query to English: {query_to_process}")
713
+ else:
714
+ query_to_process = query
715
+ logger.info("Query already in English, no translation needed")
716
+ answer = await llm_manager.vision_query(image, query_to_process)
717
+ logger.info(f"Generated English answer: {answer}")
718
+ if tgt_lang != "eng_Latn":
719
+ translated_answer = await perform_internal_translation([answer], "eng_Latn", tgt_lang)
720
+ final_answer = translated_answer[0]
721
+ logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
722
+ else:
723
+ final_answer = answer
724
+ logger.info("Answer kept in English, no translation needed")
725
+ return {"answer": final_answer}
726
+ except Exception as e:
727
+ logger.error(f"Error processing request: {str(e)}")
728
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
729
+
730
+ @app.post("/v1/chat_v2", response_model=ChatResponse)
731
+ @limiter.limit(settings.chat_rate_limit)
732
+ async def chat_v2(
733
+ request: Request,
734
+ prompt: str = Form(...),
735
+ image: UploadFile = File(default=None),
736
+ src_lang: str = Form("kan_Knda"),
737
+ tgt_lang: str = Form("kan_Knda"),
738
+ ):
739
+ if not prompt:
740
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
741
+ if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
742
+ raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
743
+ logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
744
+ try:
745
+ if image:
746
+ image_data = await image.read()
747
+ if not image_data:
748
+ raise HTTPException(status_code=400, detail="Uploaded image is empty")
749
+ img = Image.open(io.BytesIO(image_data))
750
+ if src_lang != "eng_Latn":
751
+ translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
752
+ prompt_to_process = translated_prompt[0]
753
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
754
+ else:
755
+ prompt_to_process = prompt
756
+ decoded = await llm_manager.chat_v2(img, prompt_to_process)
757
+ logger.info(f"Generated English response: {decoded}")
758
+ if tgt_lang != "eng_Latn":
759
+ translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
760
+ final_response = translated_response[0]
761
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
762
+ else:
763
+ final_response = decoded
764
+ else:
765
+ if src_lang != "eng_Latn":
766
+ translated_prompt = await perform_internal_translation([prompt], src_lang, "eng_Latn")
767
+ prompt_to_process = translated_prompt[0]
768
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
769
+ else:
770
+ prompt_to_process = prompt
771
+ decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
772
+ logger.info(f"Generated English response: {decoded}")
773
+ if tgt_lang != "eng_Latn":
774
+ translated_response = await perform_internal_translation([decoded], "eng_Latn", tgt_lang)
775
+ final_response = translated_response[0]
776
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
777
+ else:
778
+ final_response = decoded
779
+ return ChatResponse(response=final_response)
780
+ except Exception as e:
781
+ logger.error(f"Error processing request: {str(e)}")
782
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
783
+
784
+ @app.post("/transcribe/", response_model=TranscriptionResponse)
785
+ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
786
+ if not asr_manager.model:
787
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
788
+ try:
789
+ wav, sr = torchaudio.load(file.file, backend="cuda" if cuda_available else "cpu")
790
+ wav = torch.mean(wav, dim=0, keepdim=True).to(device)
791
+ target_sample_rate = 16000
792
+ if sr != target_sample_rate:
793
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate).to(device)
794
+ wav = resampler(wav)
795
+ with autocast(), torch.no_grad():
796
+ transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
797
+ return TranscriptionResponse(text=transcription_rnnt)
798
+ except Exception as e:
799
+ logger.error(f"Error in transcription: {str(e)}")
800
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
801
+
802
+ @app.post("/v1/speech_to_speech")
803
+ async def speech_to_speech(
804
+ request: Request,
805
+ file: UploadFile = File(...),
806
+ language: str = Query(..., enum=list(asr_manager.model_language.keys())),
807
+ ) -> StreamingResponse:
808
+ if not tts_manager.model:
809
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
810
+ transcription = await transcribe_audio(file, language)
811
+ logger.info(f"Transcribed text: {transcription.text}")
812
+ chat_request = ChatRequest(prompt=transcription.text, src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"), tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"))
813
+ processed_text = await chat(request, chat_request)
814
+ logger.info(f"Processed text: {processed_text.response}")
815
+ voice_request = KannadaSynthesizeRequest(text=processed_text.response)
816
+ audio_response = await synthesize_kannada(voice_request)
817
+ return audio_response
818
+
819
+ LANGUAGE_TO_SCRIPT = {"kannada": "kan_Knda"}
820
+
821
+ if __name__ == "__main__":
822
+ parser = argparse.ArgumentParser(description="Run the FastAPI server.")
823
+ parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
824
+ parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
825
+ parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
826
+ args = parser.parse_args()
827
+
828
+ def load_config(config_path="dhwani_config.json"):
829
+ with open(config_path, "r") as f:
830
+ return json.load(f)
831
+
832
+ config_data = load_config()
833
+ if args.config not in config_data["configs"]:
834
+ raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
835
+
836
+ selected_config = config_data["configs"][args.config]
837
+ global_settings = config_data["global_settings"]
838
+
839
+ settings.llm_model_name = selected_config["components"]["LLM"]["model"]
840
+ settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
841
+ settings.host = global_settings["host"]
842
+ settings.port = global_settings["port"]
843
+ settings.chat_rate_limit = global_settings["chat_rate_limit"]
844
+ settings.speech_rate_limit = global_settings["speech_rate_limit"]
845
+
846
+ llm_manager = LLMManager(settings.llm_model_name)
847
+ if selected_config["components"]["ASR"]:
848
+ asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
849
+ if selected_config["components"]["Translation"]:
850
+ translation_configs.extend(selected_config["components"]["Translation"])
851
+
852
+ host = args.host if args.host != settings.host else settings.host
853
+ port = args.port if args.port != settings.port else settings.port
854
+
855
+ uvicorn.run(app, host=host, port=port, workers=2)
src/server/main.py CHANGED
@@ -1,11 +1,11 @@
1
  import argparse
2
  import io
3
  import os
4
- import tempfile
5
  from time import time
6
  from typing import List
 
7
  import uvicorn
8
- from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form, APIRouter
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from PIL import Image
@@ -14,7 +14,7 @@ from pydantic_settings import BaseSettings
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
16
  import torch
17
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, AutoModel, Gemma3ForConditionalGeneration
18
  from IndicTransToolkit import IndicProcessor
19
  import json
20
  import asyncio
@@ -22,28 +22,32 @@ from contextlib import asynccontextmanager
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
25
- import logging
26
  from starlette.responses import StreamingResponse
27
- from logging_config import logger # Assumed external logging config
28
- from tts_config import SPEED, ResponseFormat, config as tts_config # Assumed external TTS config
29
  import torchaudio
30
- from tenacity import retry, stop_after_attempt, wait_exponential
31
- from torch.cuda.amp import autocast
32
 
33
  # Device setup
34
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
35
- torch_dtype = torch.float16 if device != "cpu" else torch.float32 # Use float16 for speed
36
- logger.info(f"Using device: {device} with dtype: {torch_dtype}")
 
 
 
 
37
 
38
  # Check CUDA availability and version
39
  cuda_available = torch.cuda.is_available()
40
  cuda_version = torch.version.cuda if cuda_available else None
41
- if cuda_available:
 
42
  device_idx = torch.cuda.current_device()
43
  capability = torch.cuda.get_device_capability(device_idx)
44
- logger.info(f"CUDA version: {cuda_version}, Compute Capability: {capability[0]}.{capability[1]}")
 
 
45
  else:
46
- logger.info("CUDA is not available; falling back to CPU.")
47
 
48
  # Settings
49
  class Settings(BaseSettings):
@@ -65,49 +69,41 @@ class Settings(BaseSettings):
65
 
66
  settings = Settings()
67
 
68
- # Request queue for concurrency control (max 10 concurrent GPU tasks)
69
- request_queue = asyncio.Queue(maxsize=10)
70
-
71
- # Logging optimization
72
- logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
 
 
73
 
74
- # LLM Manager with persistent loading and improved caching
75
  class LLMManager:
76
- def __init__(self, model_name: str, device: str = device):
77
  self.model_name = model_name
78
  self.device = torch.device(device)
79
- self.torch_dtype = torch_dtype
80
  self.model = None
81
  self.processor = None
82
  self.is_loaded = False
83
- self.token_cache = {}
84
- self.load() # Load persistently at initialization
85
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
86
 
87
  def load(self):
88
  if not self.is_loaded:
89
  try:
90
- if self.device.type == "cuda":
91
- torch.set_float32_matmul_precision('high')
92
- logger.info("Enabled TF32 matrix multiplication for improved GPU performance")
93
-
94
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
  self.model_name,
96
  device_map="auto",
97
- torch_dtype=torch.float16, # Use float16 for speed
98
- max_memory={0: "10GiB"}
99
- ).eval()
100
-
101
- self.processor = AutoProcessor.from_pretrained(self.model_name, use_fast=True)
102
- # Warm-up model
103
- dummy_input = self.processor("test", return_tensors="pt").to(self.device)
104
- with torch.no_grad():
105
- self.model.generate(**dummy_input, max_new_tokens=10)
106
  self.is_loaded = True
107
- logger.info(f"LLM {self.model_name} loaded and warmed up on {self.device}")
108
  except Exception as e:
109
  logger.error(f"Failed to load LLM: {str(e)}")
110
- self.is_loaded = False # Allow graceful degradation
111
 
112
  def unload(self):
113
  if self.is_loaded:
@@ -115,29 +111,74 @@ class LLMManager:
115
  del self.processor
116
  if self.device.type == "cuda":
117
  torch.cuda.empty_cache()
118
- logger.info(f"GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
119
  self.is_loaded = False
120
- self.token_cache.clear()
121
- logger.info(f"LLM {self.model_name} unloaded")
122
 
123
- async def generate(self, prompt: str, max_tokens: int = settings.max_tokens, temperature: float = 0.7) -> str:
124
  if not self.is_loaded:
125
- logger.warning("LLM not loaded; attempting reload")
126
  self.load()
127
- if not self.is_loaded:
128
- raise HTTPException(status_code=503, detail="LLM model unavailable")
129
 
130
- # Improved cache key with parameters
131
- cache_key = f"{prompt}:{max_tokens}:{temperature}"
132
- if cache_key in self.token_cache:
133
- logger.info("Using cached response")
134
- return self.token_cache[cache_key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  messages_vlm = [
137
- {"role": "system", "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]},
138
- {"role": "user", "content": [{"type": "text", "text": prompt}]}
 
 
 
 
 
 
139
  ]
140
 
 
 
 
 
 
 
 
141
  try:
142
  inputs_vlm = self.processor.apply_chat_template(
143
  messages_vlm,
@@ -145,160 +186,248 @@ class LLMManager:
145
  tokenize=True,
146
  return_dict=True,
147
  return_tensors="pt"
148
- ).to(self.device)
149
-
150
- with autocast(): # Mixed precision for speed
151
- generation = self.model.generate(
152
- **inputs_vlm,
153
- max_new_tokens=max_tokens,
154
- do_sample=True,
155
- top_p=0.9,
156
- temperature=temperature
157
- )
158
- generation = generation[0][inputs_vlm["input_ids"].shape[-1]:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- response = self.processor.decode(generation, skip_special_tokens=True)
161
- self.token_cache[cache_key] = response
162
- logger.info(f"Generated response: {response}")
163
- return response
 
 
 
 
164
  except Exception as e:
165
- logger.error(f"Error in generation: {str(e)}")
166
- raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- # TTS Manager with file-based synthesis
169
  class TTSManager:
170
  def __init__(self, device_type=device):
171
- self.device_type = torch.device(device_type)
172
  self.model = None
173
  self.repo_id = "ai4bharat/IndicF5"
174
- self.load() # Persistent loading
175
 
176
  def load(self):
177
  if not self.model:
178
- logger.info(f"Loading TTS model {self.repo_id} on {self.device_type}...")
179
- self.model = AutoModel.from_pretrained(self.repo_id, trust_remote_code=True).to(self.device_type)
180
- logger.info("TTS model loaded")
181
-
182
- def unload(self):
183
- if self.model:
184
- del self.model
185
- if self.device_type.type == "cuda":
186
- torch.cuda.empty_cache()
187
- logger.info(f"TTS GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
188
- self.model = None
189
- logger.info("TTS model unloaded")
190
 
191
  def synthesize(self, text, ref_audio_path, ref_text):
192
  if not self.model:
193
  raise ValueError("TTS model not loaded")
194
- with autocast(): # Mixed precision
195
- return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Translation Manager with warm-up and error handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  class TranslateManager:
199
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
200
- self.device_type = torch.device(device_type)
201
- self.tokenizer, self.model = self.initialize_model(src_lang, tgt_lang, use_distilled)
202
- if self.model:
203
- self.warm_up()
 
 
204
 
205
- def initialize_model(self, src_lang, tgt_lang, use_distilled=True):
206
- try:
207
- if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
208
- model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if use_distilled else "ai4bharat/indictrans2-en-indic-1B"
209
- elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
210
- model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if use_distilled else "ai4bharat/indictrans2-indic-en-1B"
211
- elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
212
- model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
213
  else:
214
  raise ValueError("Invalid language combination")
215
 
216
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
217
- model = AutoModelForSeq2SeqLM.from_pretrained(
 
 
 
218
  model_name,
219
  trust_remote_code=True,
220
  torch_dtype=torch.float16,
221
  attn_implementation="flash_attention_2"
222
- ).to(self.device_type)
223
- return tokenizer, model
224
- except Exception as e:
225
- logger.error(f"Failed to load translation model: {str(e)}")
226
- return None, None # Graceful degradation
227
-
228
- def warm_up(self):
229
- dummy_input = self.tokenizer("test", return_tensors="pt").to(self.device_type)
230
- with torch.no_grad(), autocast():
231
- self.model.generate(**dummy_input, max_length=10)
232
- logger.info("Translation model warmed up")
233
-
234
- def unload(self):
235
- if self.model:
236
- del self.model
237
- del self.tokenizer
238
- if self.device_type.type == "cuda":
239
- torch.cuda.empty_cache()
240
- logger.info(f"Translation GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
241
- self.model = None
242
- self.tokenizer = None
243
- logger.info("Translation model unloaded")
244
 
245
- # Model Manager with preloading
246
  class ModelManager:
247
- def __init__(self, device_type=device, use_distilled=True):
248
  self.models = {}
249
  self.device_type = device_type
250
  self.use_distilled = use_distilled
251
- self.preload_models()
252
 
253
- def preload_models(self):
254
- translation_pairs = [
255
- ('eng_Latn', 'kan_Knda', 'eng_indic'),
256
- ('kan_Knda', 'eng_Latn', 'indic_eng'),
257
- ('kan_Knda', 'hin_Deva', 'indic_indic')
258
- ]
259
- for src_lang, tgt_lang, key in translation_pairs:
260
- logger.info(f"Preloading translation model for {src_lang} -> {tgt_lang}...")
261
- self.models[key] = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
262
 
263
  def get_model(self, src_lang, tgt_lang):
 
 
 
 
 
 
 
 
 
264
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
265
- key = 'eng_indic'
266
  elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
267
- key = 'indic_eng'
268
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
269
- key = 'indic_indic'
270
- else:
271
- raise ValueError("Invalid language combination")
272
- if key not in self.models or not self.models[key].model:
273
- raise HTTPException(status_code=503, detail=f"Translation model for {key} unavailable")
274
- return self.models[key]
275
 
276
- # ASR Manager with GPU audio processing
277
  class ASRModelManager:
278
- def __init__(self, device_type=device):
279
- self.device_type = torch.device(device_type)
280
  self.model = None
281
  self.model_language = {"kannada": "kn"}
282
- self.load()
283
 
284
  def load(self):
285
  if not self.model:
286
- logger.info(f"Loading ASR model on {self.device_type}...")
287
  self.model = AutoModel.from_pretrained(
288
  "ai4bharat/indic-conformer-600m-multilingual",
289
  trust_remote_code=True
290
- ).to(self.device_type)
 
291
  logger.info("ASR model loaded")
292
 
293
- def unload(self):
294
- if self.model:
295
- del self.model
296
- if self.device_type.type == "cuda":
297
- torch.cuda.empty_cache()
298
- logger.info(f"ASR GPU memory cleared: {torch.cuda.memory_allocated()} bytes allocated")
299
- self.model = None
300
- logger.info("ASR model unloaded")
301
-
302
  # Global Managers
303
  llm_manager = LLMManager(settings.llm_model_name)
304
  model_manager = ModelManager()
@@ -306,15 +435,6 @@ asr_manager = ASRModelManager()
306
  tts_manager = TTSManager()
307
  ip = IndicProcessor(inference=True)
308
 
309
- # TTS Constants
310
- EXAMPLES = [
311
- {
312
- "audio_name": "KAN_F (Happy)",
313
- "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
314
- "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ।",
315
- },
316
- ]
317
-
318
  # Pydantic Models
319
  class ChatRequest(BaseModel):
320
  prompt: str
@@ -327,70 +447,92 @@ class ChatRequest(BaseModel):
327
  raise ValueError("Prompt cannot exceed 1000 characters")
328
  return v.strip()
329
 
 
 
 
 
 
 
 
330
  class ChatResponse(BaseModel):
331
  response: str
332
 
333
- class KannadaSynthesizeRequest(BaseModel):
334
- text: str
335
-
336
- @field_validator("text")
337
- def text_must_be_valid(cls, v):
338
- if len(v) > 500:
339
- raise ValueError("Text cannot exceed 500 characters")
340
- return v.strip()
341
 
342
  class TranscriptionResponse(BaseModel):
343
  text: str
344
 
345
- # TTS Functions
346
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
347
- def load_audio_from_url(url: str):
348
- response = requests.get(url)
349
- if response.status_code == 200:
350
- audio_data, sample_rate = sf.read(io.BytesIO(response.content))
351
- return sample_rate, audio_data
352
- raise HTTPException(status_code=500, detail="Failed to load reference audio from URL after retries")
353
-
354
- async def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str) -> io.BytesIO:
355
- async with request_queue:
356
- ref_audio_url = None
357
- for example in EXAMPLES:
358
- if example["audio_name"] == ref_audio_name:
359
- ref_audio_url = example["audio_url"]
360
- if not ref_text:
361
- ref_text = example["ref_text"]
362
- break
363
-
364
- if not ref_audio_url:
365
- raise HTTPException(status_code=400, detail=f"Invalid reference audio name: {ref_audio_name}")
366
- if not text.strip() or not ref_text.strip():
367
- raise HTTPException(status_code=400, detail="Text or reference text cannot be empty")
368
-
369
- logger.info(f"Synthesizing speech for text: {text[:50]}... with ref_audio: {ref_audio_name}")
370
- sample_rate, audio_data = load_audio_from_url(ref_audio_url)
371
-
372
- # Use temporary file since IndicF5 requires a path
373
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_ref_audio:
374
- sf.write(temp_ref_audio.name, audio_data, sample_rate, format='WAV')
375
- temp_ref_audio.flush()
376
- audio = tts_manager.synthesize(text, temp_ref_audio.name, ref_text)
377
-
378
- if audio.dtype == np.int16:
379
- audio = audio.astype(np.float32) / 32768.0
380
- output_buffer = io.BytesIO()
381
- sf.write(output_buffer, audio, 24000, format='WAV')
382
- output_buffer.seek(0)
383
- logger.info("Speech synthesis completed")
384
- return output_buffer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
  # FastAPI App
387
  app = FastAPI(
388
- title="Optimized Dhwani API",
389
- description="AI Chat API with optimized performance and robustness",
390
  version="1.0.0",
 
391
  lifespan=lifespan
392
  )
393
 
 
394
  app.add_middleware(
395
  CORSMiddleware,
396
  allow_origins=["*"],
@@ -399,6 +541,7 @@ app.add_middleware(
399
  allow_headers=["*"],
400
  )
401
 
 
402
  @app.middleware("http")
403
  async def add_request_timing(request: Request, call_next):
404
  start_time = time()
@@ -412,157 +555,375 @@ async def add_request_timing(request: Request, call_next):
412
  limiter = Limiter(key_func=get_remote_address)
413
  app.state.limiter = limiter
414
 
415
- # Lifespan Event Handler
416
- @asynccontextmanager
417
- async def lifespan(app: FastAPI):
418
- logger.info("Starting server with preloaded models...")
419
- yield
420
- llm_manager.unload()
421
- tts_manager.unload()
422
- asr_manager.unload()
423
- for model in model_manager.models.values():
424
- model.unload()
425
- logger.info("Server shutdown complete; all models unloaded")
426
-
427
- # Endpoints
428
- @app.post("/v1/speech_to_speech", response_class=StreamingResponse)
429
- async def speech_to_speech(
430
- request: Request,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  file: UploadFile = File(...),
432
- language: str = Query(..., enum=list(asr_manager.model_language.keys())),
 
 
433
  ):
434
- async with request_queue:
435
- if not tts_manager.model or not asr_manager.model:
436
- raise HTTPException(status_code=503, detail="TTS or ASR model not loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
437
 
438
- audio_data = await file.read()
439
- if not audio_data:
440
- raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
441
- if len(audio_data) > 10 * 1024 * 1024:
442
- raise HTTPException(status_code=400, detail="Audio file exceeds 10MB limit")
443
 
444
- logger.info(f"Processing speech-to-speech for file: {file.filename} in language: {language}")
445
- try:
446
- # GPU-accelerated transcription
447
- wav, sr = torchaudio.load(io.BytesIO(audio_data), backend="cuda" if cuda_available else "cpu")
448
- wav = torch.mean(wav, dim=0, keepdim=True).to(device)
449
- target_sample_rate = 16000
450
- if sr != target_sample_rate:
451
- resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
452
- wav = resampler(wav)
453
- with autocast(), torch.no_grad():
454
- transcription = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
455
- logger.info(f"Transcribed text: {transcription[:50]}...")
456
-
457
- chat_request = ChatRequest(
458
- prompt=transcription,
459
- src_lang="kan_Knda",
460
- tgt_lang="kan_Knda"
461
  )
462
- translate_mgr = model_manager.get_model(chat_request.src_lang, "eng_Latn")
463
- if translate_mgr.model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  translated_prompt = await perform_internal_translation(
465
- [chat_request.prompt], chat_request.src_lang, "eng_Latn"
 
 
466
  )
467
  prompt_to_process = translated_prompt[0]
 
468
  else:
469
- prompt_to_process = chat_request.prompt
470
-
471
- response = await llm_manager.generate(prompt_to_process)
472
- if chat_request.tgt_lang != "eng_Latn":
473
- translate_mgr = model_manager.get_model("eng_Latn", chat_request.tgt_lang)
474
- if translate_mgr.model:
475
- translated_response = await perform_internal_translation(
476
- [response], "eng_Latn", chat_request.tgt_lang
477
- )
478
- final_response = translated_response[0]
479
- else:
480
- final_response = response
481
- else:
482
- final_response = response
483
- logger.info(f"Processed text: {final_response[:50]}...")
484
-
485
- audio_buffer = await synthesize_speech(tts_manager, final_response, "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
486
- logger.info("Speech-to-speech processing completed")
487
- return StreamingResponse(
488
- audio_buffer,
489
- media_type="audio/wav",
490
- headers={"Content-Disposition": "attachment; filename=speech_to_speech_output.wav"}
491
- )
492
- except Exception as e:
493
- logger.error(f"Error in speech-to-speech pipeline: {str(e)}")
494
- raise HTTPException(status_code=500, detail=f"Speech-to-speech failed: {str(e)}")
495
 
496
- @app.post("/v1/chat", response_model=ChatResponse)
497
- @limiter.limit(settings.chat_rate_limit)
498
- async def chat(request: Request, chat_request: ChatRequest):
499
- async with request_queue:
500
- logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
501
- try:
502
- if chat_request.src_lang != "eng_Latn":
503
- translate_mgr = model_manager.get_model(chat_request.src_lang, "eng_Latn")
504
- if translate_mgr.model:
505
- translated_prompt = await perform_internal_translation(
506
- [chat_request.prompt], chat_request.src_lang, "eng_Latn"
507
- )
508
- prompt_to_process = translated_prompt[0]
509
- logger.info(f"Translated prompt to English: {prompt_to_process}")
510
- else:
511
- prompt_to_process = chat_request.prompt
512
  else:
513
- prompt_to_process = chat_request.prompt
514
-
515
- response = await llm_manager.generate(prompt_to_process)
516
- logger.info(f"Generated English response: {response}")
517
-
518
- if chat_request.tgt_lang != "eng_Latn":
519
- translate_mgr = model_manager.get_model("eng_Latn", chat_request.tgt_lang)
520
- if translate_mgr.model:
521
- translated_response = await perform_internal_translation(
522
- [response], "eng_Latn", chat_request.tgt_lang
523
- )
524
- final_response = translated_response[0]
525
- logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
526
- else:
527
- final_response = response
528
  else:
529
- final_response = response
530
- return ChatResponse(response=final_response)
531
- except Exception as e:
532
- logger.error(f"Error in chat: {str(e)}")
533
- raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}")
534
-
535
- async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
536
- translate_mgr = model_manager.get_model(src_lang, tgt_lang)
537
- if not translate_mgr.model:
538
- raise HTTPException(status_code=503, detail="Translation model unavailable")
539
- batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
540
- inputs = translate_mgr.tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(device)
541
- with torch.no_grad(), autocast():
542
- tokens = translate_mgr.model.generate(**inputs, max_length=256, num_beams=5)
543
- translations = translate_mgr.tokenizer.batch_decode(tokens, skip_special_tokens=True)
544
- return ip.postprocess_batch(translations, lang=tgt_lang)
545
 
546
- @app.get("/v1/health")
547
- async def health_check():
548
- memory_usage = torch.cuda.memory_allocated() / (24 * 1024**3) if cuda_available else 0 # 24GB VRAM
549
- if memory_usage > 0.9:
550
- logger.warning("GPU memory usage exceeds 90%; consider unloading models")
551
- status = {
552
- "status": "healthy",
553
- "llm_loaded": llm_manager.is_loaded,
554
- "tts_loaded": bool(tts_manager.model),
555
- "asr_loaded": bool(asr_manager.model),
556
- "translation_models": list(model_manager.models.keys()),
557
- "gpu_memory_usage": f"{memory_usage:.2%}"
558
- }
559
- return status
560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  if __name__ == "__main__":
562
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
563
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
564
  parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
 
565
  args = parser.parse_args()
566
 
567
- # Uvicorn tuning: 2 workers for 8 vCPUs and 24GB VRAM
568
- uvicorn.run(app, host=args.host, port=args.port, workers=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import io
3
  import os
 
4
  from time import time
5
  from typing import List
6
+ import tempfile
7
  import uvicorn
8
+ from fastapi import Depends, FastAPI, File, HTTPException, Query, Request, UploadFile, Body, Form
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
11
  from PIL import Image
 
14
  from slowapi import Limiter
15
  from slowapi.util import get_remote_address
16
  import torch
17
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig, AutoModel, Gemma3ForConditionalGeneration
18
  from IndicTransToolkit import IndicProcessor
19
  import json
20
  import asyncio
 
22
  import soundfile as sf
23
  import numpy as np
24
  import requests
 
25
  from starlette.responses import StreamingResponse
26
+ from logging_config import logger
27
+ from tts_config import SPEED, ResponseFormat, config as tts_config
28
  import torchaudio
 
 
29
 
30
  # Device setup
31
+ if torch.cuda.is_available():
32
+ device = "cuda:0"
33
+ logger.info("GPU will be used for inference")
34
+ else:
35
+ device = "cpu"
36
+ logger.info("CPU will be used for inference")
37
+ torch_dtype = torch.bfloat16 if device != "cpu" else torch.float32
38
 
39
  # Check CUDA availability and version
40
  cuda_available = torch.cuda.is_available()
41
  cuda_version = torch.version.cuda if cuda_available else None
42
+
43
+ if torch.cuda.is_available():
44
  device_idx = torch.cuda.current_device()
45
  capability = torch.cuda.get_device_capability(device_idx)
46
+ compute_capability_float = float(f"{capability[0]}.{capability[1]}")
47
+ print(f"CUDA version: {cuda_version}")
48
+ print(f"CUDA Compute Capability: {compute_capability_float}")
49
  else:
50
+ print("CUDA is not available on this system.")
51
 
52
  # Settings
53
  class Settings(BaseSettings):
 
69
 
70
  settings = Settings()
71
 
72
+ # Quantization config for LLM
73
+ quantization_config = BitsAndBytesConfig(
74
+ load_in_4bit=True,
75
+ bnb_4bit_quant_type="nf4",
76
+ bnb_4bit_use_double_quant=True,
77
+ bnb_4bit_compute_dtype=torch.bfloat16
78
+ )
79
 
80
+ # LLM Manager
81
  class LLMManager:
82
+ def __init__(self, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
83
  self.model_name = model_name
84
  self.device = torch.device(device)
85
+ self.torch_dtype = torch.bfloat16 if self.device.type != "cpu" else torch.float32
86
  self.model = None
87
  self.processor = None
88
  self.is_loaded = False
 
 
89
  logger.info(f"LLMManager initialized with model {model_name} on {self.device}")
90
 
91
  def load(self):
92
  if not self.is_loaded:
93
  try:
 
 
 
 
94
  self.model = Gemma3ForConditionalGeneration.from_pretrained(
95
  self.model_name,
96
  device_map="auto",
97
+ quantization_config=quantization_config,
98
+ torch_dtype=self.torch_dtype
99
+ )
100
+ self.model.eval()
101
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
 
 
 
 
102
  self.is_loaded = True
103
+ logger.info(f"LLM {self.model_name} loaded on {self.device}")
104
  except Exception as e:
105
  logger.error(f"Failed to load LLM: {str(e)}")
106
+ raise
107
 
108
  def unload(self):
109
  if self.is_loaded:
 
111
  del self.processor
112
  if self.device.type == "cuda":
113
  torch.cuda.empty_cache()
114
+ logger.info(f"GPU memory allocated after unload: {torch.cuda.memory_allocated()}")
115
  self.is_loaded = False
116
+ logger.info(f"LLM {self.model_name} unloaded from {self.device}")
 
117
 
118
+ async def generate(self, prompt: str, max_tokens: int = 512, temperature: float = 0.7) -> str:
119
  if not self.is_loaded:
 
120
  self.load()
 
 
121
 
122
+ messages_vlm = [
123
+ {
124
+ "role": "system",
125
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state. Provide a concise response in one sentence maximum."}]
126
+ },
127
+ {
128
+ "role": "user",
129
+ "content": [{"type": "text", "text": prompt}]
130
+ }
131
+ ]
132
+
133
+ try:
134
+ inputs_vlm = self.processor.apply_chat_template(
135
+ messages_vlm,
136
+ add_generation_prompt=True,
137
+ tokenize=True,
138
+ return_dict=True,
139
+ return_tensors="pt"
140
+ ).to(self.device, dtype=torch.bfloat16)
141
+ except Exception as e:
142
+ logger.error(f"Error in tokenization: {str(e)}")
143
+ raise HTTPException(status_code=500, detail=f"Tokenization failed: {str(e)}")
144
+
145
+ input_len = inputs_vlm["input_ids"].shape[-1]
146
+
147
+ with torch.inference_mode():
148
+ generation = self.model.generate(
149
+ **inputs_vlm,
150
+ max_new_tokens=max_tokens,
151
+ do_sample=True,
152
+ temperature=temperature
153
+ )
154
+ generation = generation[0][input_len:]
155
+
156
+ response = self.processor.decode(generation, skip_special_tokens=True)
157
+ logger.info(f"Generated response: {response}")
158
+ return response
159
+
160
+ async def vision_query(self, image: Image.Image, query: str) -> str:
161
+ if not self.is_loaded:
162
+ self.load()
163
 
164
  messages_vlm = [
165
+ {
166
+ "role": "system",
167
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Summarize your answer in maximum 1 sentence."}]
168
+ },
169
+ {
170
+ "role": "user",
171
+ "content": []
172
+ }
173
  ]
174
 
175
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
176
+ if image and image.size[0] > 0 and image.size[1] > 0:
177
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
178
+ logger.info(f"Received valid image for processing")
179
+ else:
180
+ logger.info("No valid image provided, processing text only")
181
+
182
  try:
183
  inputs_vlm = self.processor.apply_chat_template(
184
  messages_vlm,
 
186
  tokenize=True,
187
  return_dict=True,
188
  return_tensors="pt"
189
+ ).to(self.device, dtype=torch.bfloat16)
190
+ except Exception as e:
191
+ logger.error(f"Error in apply_chat_template: {str(e)}")
192
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
193
+
194
+ input_len = inputs_vlm["input_ids"].shape[-1]
195
+
196
+ with torch.inference_mode():
197
+ generation = self.model.generate(
198
+ **inputs_vlm,
199
+ max_new_tokens=512,
200
+ do_sample=True,
201
+ temperature=0.7
202
+ )
203
+ generation = generation[0][input_len:]
204
+
205
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
206
+ logger.info(f"Vision query response: {decoded}")
207
+ return decoded
208
+
209
+ async def chat_v2(self, image: Image.Image, query: str) -> str:
210
+ if not self.is_loaded:
211
+ self.load()
212
+
213
+ messages_vlm = [
214
+ {
215
+ "role": "system",
216
+ "content": [{"type": "text", "text": "You are Dhwani, a helpful assistant. Answer questions considering India as base country and Karnataka as base state."}]
217
+ },
218
+ {
219
+ "role": "user",
220
+ "content": []
221
+ }
222
+ ]
223
+
224
+ messages_vlm[1]["content"].append({"type": "text", "text": query})
225
+ if image and image.size[0] > 0 and image.size[1] > 0:
226
+ messages_vlm[1]["content"].insert(0, {"type": "image", "image": image})
227
+ logger.info(f"Received valid image for processing")
228
+ else:
229
+ logger.info("No valid image provided, processing text only")
230
 
231
+ try:
232
+ inputs_vlm = self.processor.apply_chat_template(
233
+ messages_vlm,
234
+ add_generation_prompt=True,
235
+ tokenize=True,
236
+ return_dict=True,
237
+ return_tensors="pt"
238
+ ).to(self.device, dtype=torch.bfloat16)
239
  except Exception as e:
240
+ logger.error(f"Error in apply_chat_template: {str(e)}")
241
+ raise HTTPException(status_code=500, detail=f"Failed to process input: {str(e)}")
242
+
243
+ input_len = inputs_vlm["input_ids"].shape[-1]
244
+
245
+ with torch.inference_mode():
246
+ generation = self.model.generate(
247
+ **inputs_vlm,
248
+ max_new_tokens=512,
249
+ do_sample=True,
250
+ temperature=0.7
251
+ )
252
+ generation = generation[0][input_len:]
253
+
254
+ decoded = self.processor.decode(generation, skip_special_tokens=True)
255
+ logger.info(f"Chat_v2 response: {decoded}")
256
+ return decoded
257
 
258
+ # TTS Manager
259
  class TTSManager:
260
  def __init__(self, device_type=device):
261
+ self.device_type = device_type
262
  self.model = None
263
  self.repo_id = "ai4bharat/IndicF5"
 
264
 
265
  def load(self):
266
  if not self.model:
267
+ logger.info("Loading TTS model IndicF5...")
268
+ self.model = AutoModel.from_pretrained(
269
+ self.repo_id,
270
+ trust_remote_code=True
271
+ )
272
+ self.model = self.model.to(self.device_type)
273
+ logger.info("TTS model IndicF5 loaded")
 
 
 
 
 
274
 
275
  def synthesize(self, text, ref_audio_path, ref_text):
276
  if not self.model:
277
  raise ValueError("TTS model not loaded")
278
+ return self.model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)
279
+
280
+ # TTS Constants
281
+ EXAMPLES = [
282
+ {
283
+ "audio_name": "KAN_F (Happy)",
284
+ "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
285
+ "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್��ಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
286
+ "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
287
+ },
288
+ ]
289
+
290
+ # Pydantic models for TTS
291
+ class SynthesizeRequest(BaseModel):
292
+ text: str
293
+ ref_audio_name: str
294
+ ref_text: str = None
295
 
296
+ class KannadaSynthesizeRequest(BaseModel):
297
+ text: str
298
+
299
+ # TTS Functions
300
+ def load_audio_from_url(url: str):
301
+ response = requests.get(url)
302
+ if response.status_code == 200:
303
+ audio_data, sample_rate = sf.read(io.BytesIO(response.content))
304
+ return sample_rate, audio_data
305
+ raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
306
+
307
+ def synthesize_speech(tts_manager: TTSManager, text: str, ref_audio_name: str, ref_text: str):
308
+ ref_audio_url = None
309
+ for example in EXAMPLES:
310
+ if example["audio_name"] == ref_audio_name:
311
+ ref_audio_url = example["audio_url"]
312
+ if not ref_text:
313
+ ref_text = example["ref_text"]
314
+ break
315
+
316
+ if not ref_audio_url:
317
+ raise HTTPException(status_code=400, detail="Invalid reference audio name.")
318
+ if not text.strip():
319
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
320
+ if not ref_text or not ref_text.strip():
321
+ raise HTTPException(status_code=400, detail="Reference text cannot be empty.")
322
+
323
+ sample_rate, audio_data = load_audio_from_url(ref_audio_url)
324
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
325
+ sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
326
+ temp_audio.flush()
327
+ audio = tts_manager.synthesize(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
328
+
329
+ if audio.dtype == np.int16:
330
+ audio = audio.astype(np.float32) / 32768.0
331
+ buffer = io.BytesIO()
332
+ sf.write(buffer, audio, 24000, format='WAV')
333
+ buffer.seek(0)
334
+ return buffer
335
+
336
+ # Supported languages
337
+ SUPPORTED_LANGUAGES = {
338
+ "asm_Beng", "kas_Arab", "pan_Guru", "ben_Beng", "kas_Deva", "san_Deva",
339
+ "brx_Deva", "mai_Deva", "sat_Olck", "doi_Deva", "mal_Mlym", "snd_Arab",
340
+ "eng_Latn", "mar_Deva", "snd_Deva", "gom_Deva", "mni_Beng", "tam_Taml",
341
+ "guj_Gujr", "mni_Mtei", "tel_Telu", "hin_Deva", "npi_Deva", "urd_Arab",
342
+ "kan_Knda", "ory_Orya",
343
+ "deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn",
344
+ "por_Latn", "rus_Cyrl", "pol_Latn"
345
+ }
346
+
347
+ # Translation Manager
348
  class TranslateManager:
349
  def __init__(self, src_lang, tgt_lang, device_type=device, use_distilled=True):
350
+ self.device_type = device_type
351
+ self.tokenizer = None
352
+ self.model = None
353
+ self.src_lang = src_lang
354
+ self.tgt_lang = tgt_lang
355
+ self.use_distilled = use_distilled
356
 
357
+ def load(self):
358
+ if not self.tokenizer or not self.model:
359
+ if self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
360
+ model_name = "ai4bharat/indictrans2-en-indic-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-en-indic-1B"
361
+ elif not self.src_lang.startswith("eng") and self.tgt_lang.startswith("eng"):
362
+ model_name = "ai4bharat/indictrans2-indic-en-dist-200M" if self.use_distilled else "ai4bharat/indictrans2-indic-en-1B"
363
+ elif not self.src_lang.startswith("eng") and not self.tgt_lang.startswith("eng"):
364
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M" if self.use_distilled else "ai4bharat/indictrans2-indic-indic-1B"
365
  else:
366
  raise ValueError("Invalid language combination")
367
 
368
+ self.tokenizer = AutoTokenizer.from_pretrained(
369
+ model_name,
370
+ trust_remote_code=True
371
+ )
372
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
373
  model_name,
374
  trust_remote_code=True,
375
  torch_dtype=torch.float16,
376
  attn_implementation="flash_attention_2"
377
+ )
378
+ self.model = self.model.to(self.device_type)
379
+ self.model = torch.compile(self.model, mode="reduce-overhead")
380
+ logger.info(f"Translation model {model_name} loaded")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
 
382
  class ModelManager:
383
+ def __init__(self, device_type=device, use_distilled=True, is_lazy_loading=False):
384
  self.models = {}
385
  self.device_type = device_type
386
  self.use_distilled = use_distilled
387
+ self.is_lazy_loading = is_lazy_loading
388
 
389
+ def load_model(self, src_lang, tgt_lang, key):
390
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}")
391
+ translate_manager = TranslateManager(src_lang, tgt_lang, self.device_type, self.use_distilled)
392
+ translate_manager.load()
393
+ self.models[key] = translate_manager
394
+ logger.info(f"Loaded translation model for {key}")
 
 
 
395
 
396
  def get_model(self, src_lang, tgt_lang):
397
+ key = self._get_model_key(src_lang, tgt_lang)
398
+ if key not in self.models:
399
+ if self.is_lazy_loading:
400
+ self.load_model(src_lang, tgt_lang, key)
401
+ else:
402
+ raise ValueError(f"Model for {key} is not preloaded and lazy loading is disabled.")
403
+ return self.models.get(key)
404
+
405
+ def _get_model_key(self, src_lang, tgt_lang):
406
  if src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
407
+ return 'eng_indic'
408
  elif not src_lang.startswith("eng") and tgt_lang.startswith("eng"):
409
+ return 'indic_eng'
410
  elif not src_lang.startswith("eng") and not tgt_lang.startswith("eng"):
411
+ return 'indic_indic'
412
+ raise ValueError("Invalid language combination")
 
 
 
 
413
 
414
+ # ASR Manager
415
  class ASRModelManager:
416
+ def __init__(self, device_type="cuda"):
417
+ self.device_type = device_type
418
  self.model = None
419
  self.model_language = {"kannada": "kn"}
 
420
 
421
  def load(self):
422
  if not self.model:
423
+ logger.info("Loading ASR model...")
424
  self.model = AutoModel.from_pretrained(
425
  "ai4bharat/indic-conformer-600m-multilingual",
426
  trust_remote_code=True
427
+ )
428
+ self.model = self.model.to(self.device_type)
429
  logger.info("ASR model loaded")
430
 
 
 
 
 
 
 
 
 
 
431
  # Global Managers
432
  llm_manager = LLMManager(settings.llm_model_name)
433
  model_manager = ModelManager()
 
435
  tts_manager = TTSManager()
436
  ip = IndicProcessor(inference=True)
437
 
 
 
 
 
 
 
 
 
 
438
  # Pydantic Models
439
  class ChatRequest(BaseModel):
440
  prompt: str
 
447
  raise ValueError("Prompt cannot exceed 1000 characters")
448
  return v.strip()
449
 
450
+ @field_validator("src_lang", "tgt_lang")
451
+ def validate_language(cls, v):
452
+ if v not in SUPPORTED_LANGUAGES:
453
+ raise ValueError(f"Unsupported language code: {v}. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
454
+ return v
455
+
456
+
457
  class ChatResponse(BaseModel):
458
  response: str
459
 
460
+ class TranslationRequest(BaseModel):
461
+ sentences: List[str]
462
+ src_lang: str
463
+ tgt_lang: str
 
 
 
 
464
 
465
  class TranscriptionResponse(BaseModel):
466
  text: str
467
 
468
+ class TranslationResponse(BaseModel):
469
+ translations: List[str]
470
+
471
+ # Dependency
472
+ def get_translate_manager(src_lang: str, tgt_lang: str) -> TranslateManager:
473
+ return model_manager.get_model(src_lang, tgt_lang)
474
+
475
+ # Lifespan Event Handler
476
+ translation_configs = []
477
+
478
+ @asynccontextmanager
479
+ async def lifespan(app: FastAPI):
480
+ def load_all_models():
481
+ try:
482
+ # Load LLM model
483
+ logger.info("Loading LLM model...")
484
+ llm_manager.load()
485
+ logger.info("LLM model loaded successfully")
486
+
487
+ # Load TTS model
488
+ logger.info("Loading TTS model...")
489
+ tts_manager.load()
490
+ logger.info("TTS model loaded successfully")
491
+
492
+ # Load ASR model
493
+ logger.info("Loading ASR model...")
494
+ asr_manager.load()
495
+ logger.info("ASR model loaded successfully")
496
+
497
+ # Load translation models
498
+ translation_tasks = [
499
+ ('eng_Latn', 'kan_Knda', 'eng_indic'),
500
+ ('kan_Knda', 'eng_Latn', 'indic_eng'),
501
+ ('kan_Knda', 'hin_Deva', 'indic_indic'),
502
+ ]
503
+
504
+ for config in translation_configs:
505
+ src_lang = config["src_lang"]
506
+ tgt_lang = config["tgt_lang"]
507
+ key = model_manager._get_model_key(src_lang, tgt_lang)
508
+ translation_tasks.append((src_lang, tgt_lang, key))
509
+
510
+ for src_lang, tgt_lang, key in translation_tasks:
511
+ logger.info(f"Loading translation model for {src_lang} -> {tgt_lang}...")
512
+ model_manager.load_model(src_lang, tgt_lang, key)
513
+ logger.info(f"Translation model for {key} loaded successfully")
514
+
515
+ logger.info("All models loaded successfully")
516
+ except Exception as e:
517
+ logger.error(f"Error loading models: {str(e)}")
518
+ raise
519
+
520
+ logger.info("Starting sequential model loading...")
521
+ load_all_models()
522
+ yield
523
+ llm_manager.unload()
524
+ logger.info("Server shutdown complete")
525
 
526
  # FastAPI App
527
  app = FastAPI(
528
+ title="Dhwani API",
529
+ description="AI Chat API supporting Indian languages",
530
  version="1.0.0",
531
+ redirect_slashes=False,
532
  lifespan=lifespan
533
  )
534
 
535
+ # Add CORS Middleware
536
  app.add_middleware(
537
  CORSMiddleware,
538
  allow_origins=["*"],
 
541
  allow_headers=["*"],
542
  )
543
 
544
+ # Add Timing Middleware
545
  @app.middleware("http")
546
  async def add_request_timing(request: Request, call_next):
547
  start_time = time()
 
555
  limiter = Limiter(key_func=get_remote_address)
556
  app.state.limiter = limiter
557
 
558
+ # API Endpoints
559
+ @app.post("/audio/speech", response_class=StreamingResponse)
560
+ async def synthesize_kannada(request: KannadaSynthesizeRequest):
561
+ if not tts_manager.model:
562
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
563
+ kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
564
+ if not request.text.strip():
565
+ raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
566
+
567
+ audio_buffer = synthesize_speech(
568
+ tts_manager,
569
+ text=request.text,
570
+ ref_audio_name="KAN_F (Happy)",
571
+ ref_text=kannada_example["ref_text"]
572
+ )
573
+
574
+ return StreamingResponse(
575
+ audio_buffer,
576
+ media_type="audio/wav",
577
+ headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
578
+ )
579
+
580
+ @app.post("/translate", response_model=TranslationResponse)
581
+ async def translate(request: TranslationRequest, translate_manager: TranslateManager = Depends(get_translate_manager)):
582
+ input_sentences = request.sentences
583
+ src_lang = request.src_lang
584
+ tgt_lang = request.tgt_lang
585
+
586
+ if not input_sentences:
587
+ raise HTTPException(status_code=400, detail="Input sentences are required")
588
+
589
+ batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
590
+ inputs = translate_manager.tokenizer(
591
+ batch,
592
+ truncation=True,
593
+ padding="longest",
594
+ return_tensors="pt",
595
+ return_attention_mask=True,
596
+ ).to(translate_manager.device_type)
597
+
598
+ with torch.no_grad():
599
+ generated_tokens = translate_manager.model.generate(
600
+ **inputs,
601
+ use_cache=True,
602
+ min_length=0,
603
+ max_length=256,
604
+ num_beams=5,
605
+ num_return_sequences=1,
606
+ )
607
+
608
+ with translate_manager.tokenizer.as_target_tokenizer():
609
+ generated_tokens = translate_manager.tokenizer.batch_decode(
610
+ generated_tokens.detach().cpu().tolist(),
611
+ skip_special_tokens=True,
612
+ clean_up_tokenization_spaces=True,
613
+ )
614
+
615
+ translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
616
+ return TranslationResponse(translations=translations)
617
+
618
+ async def perform_internal_translation(sentences: List[str], src_lang: str, tgt_lang: str) -> List[str]:
619
+ try:
620
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
621
+ except ValueError as e:
622
+ logger.info(f"Model not preloaded: {str(e)}, loading now...")
623
+ key = model_manager._get_model_key(src_lang, tgt_lang)
624
+ model_manager.load_model(src_lang, tgt_lang, key)
625
+ translate_manager = model_manager.get_model(src_lang, tgt_lang)
626
+
627
+ if not translate_manager.model:
628
+ translate_manager.load()
629
+
630
+ request = TranslationRequest(sentences=sentences, src_lang=src_lang, tgt_lang=tgt_lang)
631
+ response = await translate(request, translate_manager)
632
+ return response.translations
633
+
634
+ @app.get("/v1/health")
635
+ async def health_check():
636
+ return {"status": "healthy", "model": settings.llm_model_name}
637
+
638
+ @app.get("/")
639
+ async def home():
640
+ return RedirectResponse(url="/docs")
641
+
642
+ @app.post("/v1/unload_all_models")
643
+ async def unload_all_models():
644
+ try:
645
+ logger.info("Starting to unload all models...")
646
+ llm_manager.unload()
647
+ logger.info("All models unloaded successfully")
648
+ return {"status": "success", "message": "All models unloaded"}
649
+ except Exception as e:
650
+ logger.error(f"Error unloading models: {str(e)}")
651
+ raise HTTPException(status_code=500, detail=f"Failed to unload models: {str(e)}")
652
+
653
+ @app.post("/v1/load_all_models")
654
+ async def load_all_models():
655
+ try:
656
+ logger.info("Starting to load all models...")
657
+ llm_manager.load()
658
+ logger.info("All models loaded successfully")
659
+ return {"status": "success", "message": "All models loaded"}
660
+ except Exception as e:
661
+ logger.error(f"Error loading models: {str(e)}")
662
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {str(e)}")
663
+
664
+ @app.post("/v1/translate", response_model=TranslationResponse)
665
+ async def translate_endpoint(request: TranslationRequest):
666
+ logger.info(f"Received translation request: {request.dict()}")
667
+ try:
668
+ translations = await perform_internal_translation(
669
+ sentences=request.sentences,
670
+ src_lang=request.src_lang,
671
+ tgt_lang=request.tgt_lang
672
+ )
673
+ logger.info(f"Translation successful: {translations}")
674
+ return TranslationResponse(translations=translations)
675
+ except Exception as e:
676
+ logger.error(f"Unexpected error during translation: {str(e)}")
677
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
678
+
679
+ @app.post("/v1/chat", response_model=ChatResponse)
680
+ @limiter.limit(settings.chat_rate_limit)
681
+ async def chat(request: Request, chat_request: ChatRequest):
682
+ if not chat_request.prompt:
683
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
684
+ logger.info(f"Received prompt: {chat_request.prompt}, src_lang: {chat_request.src_lang}, tgt_lang: {chat_request.tgt_lang}")
685
+
686
+ EUROPEAN_LANGUAGES = {"deu_Latn", "fra_Latn", "nld_Latn", "spa_Latn", "ita_Latn", "por_Latn", "rus_Cyrl", "pol_Latn"}
687
+
688
+ try:
689
+ if chat_request.src_lang != "eng_Latn" and chat_request.src_lang not in EUROPEAN_LANGUAGES:
690
+ translated_prompt = await perform_internal_translation(
691
+ sentences=[chat_request.prompt],
692
+ src_lang=chat_request.src_lang,
693
+ tgt_lang="eng_Latn"
694
+ )
695
+ prompt_to_process = translated_prompt[0]
696
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
697
+ else:
698
+ prompt_to_process = chat_request.prompt
699
+ logger.info("Prompt in English or European language, no translation needed")
700
+
701
+ response = await llm_manager.generate(prompt_to_process, settings.max_tokens)
702
+ logger.info(f"Generated response: {response}")
703
+
704
+ if chat_request.tgt_lang != "eng_Latn" and chat_request.tgt_lang not in EUROPEAN_LANGUAGES:
705
+ translated_response = await perform_internal_translation(
706
+ sentences=[response],
707
+ src_lang="eng_Latn",
708
+ tgt_lang=chat_request.tgt_lang
709
+ )
710
+ final_response = translated_response[0]
711
+ logger.info(f"Translated response to {chat_request.tgt_lang}: {final_response}")
712
+ else:
713
+ final_response = response
714
+ logger.info(f"Response in {chat_request.tgt_lang}, no translation needed")
715
+
716
+ return ChatResponse(response=final_response)
717
+ except Exception as e:
718
+ logger.error(f"Error processing request: {str(e)}")
719
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
720
+
721
+ @app.post("/v1/visual_query/")
722
+ async def visual_query(
723
  file: UploadFile = File(...),
724
+ query: str = Body(...),
725
+ src_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
726
+ tgt_lang: str = Query("kan_Knda", enum=list(SUPPORTED_LANGUAGES)),
727
  ):
728
+ try:
729
+ image = Image.open(file.file)
730
+ if image.size == (0, 0):
731
+ raise HTTPException(status_code=400, detail="Uploaded image is empty or invalid")
732
+
733
+ if src_lang != "eng_Latn":
734
+ translated_query = await perform_internal_translation(
735
+ sentences=[query],
736
+ src_lang=src_lang,
737
+ tgt_lang="eng_Latn"
738
+ )
739
+ query_to_process = translated_query[0]
740
+ logger.info(f"Translated query to English: {query_to_process}")
741
+ else:
742
+ query_to_process = query
743
+ logger.info("Query already in English, no translation needed")
744
 
745
+ answer = await llm_manager.vision_query(image, query_to_process)
746
+ logger.info(f"Generated English answer: {answer}")
 
 
 
747
 
748
+ if tgt_lang != "eng_Latn":
749
+ translated_answer = await perform_internal_translation(
750
+ sentences=[answer],
751
+ src_lang="eng_Latn",
752
+ tgt_lang=tgt_lang
 
 
 
 
 
 
 
 
 
 
 
 
753
  )
754
+ final_answer = translated_answer[0]
755
+ logger.info(f"Translated answer to {tgt_lang}: {final_answer}")
756
+ else:
757
+ final_answer = answer
758
+ logger.info("Answer kept in English, no translation needed")
759
+
760
+ return {"answer": final_answer}
761
+ except Exception as e:
762
+ logger.error(f"Error processing request: {str(e)}")
763
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
764
+
765
+ @app.post("/v1/chat_v2", response_model=ChatResponse)
766
+ @limiter.limit(settings.chat_rate_limit)
767
+ async def chat_v2(
768
+ request: Request,
769
+ prompt: str = Form(...),
770
+ image: UploadFile = File(default=None),
771
+ src_lang: str = Form("kan_Knda"),
772
+ tgt_lang: str = Form("kan_Knda"),
773
+ ):
774
+ if not prompt:
775
+ raise HTTPException(status_code=400, detail="Prompt cannot be empty")
776
+ if src_lang not in SUPPORTED_LANGUAGES or tgt_lang not in SUPPORTED_LANGUAGES:
777
+ raise HTTPException(status_code=400, detail=f"Unsupported language code. Supported codes: {', '.join(SUPPORTED_LANGUAGES)}")
778
+
779
+ logger.info(f"Received prompt: {prompt}, src_lang: {src_lang}, tgt_lang: {tgt_lang}, Image provided: {image is not None}")
780
+
781
+ try:
782
+ if image:
783
+ image_data = await image.read()
784
+ if not image_data:
785
+ raise HTTPException(status_code=400, detail="Uploaded image is empty")
786
+ img = Image.open(io.BytesIO(image_data))
787
+
788
+ if src_lang != "eng_Latn":
789
  translated_prompt = await perform_internal_translation(
790
+ sentences=[prompt],
791
+ src_lang=src_lang,
792
+ tgt_lang="eng_Latn"
793
  )
794
  prompt_to_process = translated_prompt[0]
795
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
796
  else:
797
+ prompt_to_process = prompt
798
+ logger.info("Prompt already in English, no translation needed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
799
 
800
+ decoded = await llm_manager.chat_v2(img, prompt_to_process)
801
+ logger.info(f"Generated English response: {decoded}")
802
+
803
+ if tgt_lang != "eng_Latn":
804
+ translated_response = await perform_internal_translation(
805
+ sentences=[decoded],
806
+ src_lang="eng_Latn",
807
+ tgt_lang=tgt_lang
808
+ )
809
+ final_response = translated_response[0]
810
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
 
 
 
 
 
811
  else:
812
+ final_response = decoded
813
+ logger.info("Response kept in English, no translation needed")
814
+ else:
815
+ if src_lang != "eng_Latn":
816
+ translated_prompt = await perform_internal_translation(
817
+ sentences=[prompt],
818
+ src_lang=src_lang,
819
+ tgt_lang="eng_Latn"
820
+ )
821
+ prompt_to_process = translated_prompt[0]
822
+ logger.info(f"Translated prompt to English: {prompt_to_process}")
 
 
 
 
823
  else:
824
+ prompt_to_process = prompt
825
+ logger.info("Prompt already in English, no translation needed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
826
 
827
+ decoded = await llm_manager.generate(prompt_to_process, settings.max_tokens)
828
+ logger.info(f"Generated English response: {decoded}")
 
 
 
 
 
 
 
 
 
 
 
 
829
 
830
+ if tgt_lang != "eng_Latn":
831
+ translated_response = await perform_internal_translation(
832
+ sentences=[decoded],
833
+ src_lang="eng_Latn",
834
+ tgt_lang=tgt_lang
835
+ )
836
+ final_response = translated_response[0]
837
+ logger.info(f"Translated response to {tgt_lang}: {final_response}")
838
+ else:
839
+ final_response = decoded
840
+ logger.info("Response kept in English, no translation needed")
841
+
842
+ return ChatResponse(response=final_response)
843
+ except Exception as e:
844
+ logger.error(f"Error processing request: {str(e)}")
845
+ raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")
846
+
847
+ @app.post("/transcribe/", response_model=TranscriptionResponse)
848
+ async def transcribe_audio(file: UploadFile = File(...), language: str = Query(..., enum=list(asr_manager.model_language.keys()))):
849
+ if not asr_manager.model:
850
+ raise HTTPException(status_code=503, detail="ASR model not loaded")
851
+ try:
852
+ wav, sr = torchaudio.load(file.file)
853
+ wav = torch.mean(wav, dim=0, keepdim=True)
854
+ target_sample_rate = 16000
855
+ if sr != target_sample_rate:
856
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sample_rate)
857
+ wav = resampler(wav)
858
+ transcription_rnnt = asr_manager.model(wav, asr_manager.model_language[language], "rnnt")
859
+ return TranscriptionResponse(text=transcription_rnnt)
860
+ except Exception as e:
861
+ logger.error(f"Error in transcription: {str(e)}")
862
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
863
+
864
+ @app.post("/v1/speech_to_speech")
865
+ async def speech_to_speech(
866
+ request: Request,
867
+ file: UploadFile = File(...),
868
+ language: str = Query(..., enum=list(asr_manager.model_language.keys())),
869
+ ) -> StreamingResponse:
870
+ if not tts_manager.model:
871
+ raise HTTPException(status_code=503, detail="TTS model not loaded")
872
+ transcription = await transcribe_audio(file, language)
873
+ logger.info(f"Transcribed text: {transcription.text}")
874
+
875
+ chat_request = ChatRequest(
876
+ prompt=transcription.text,
877
+ src_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda"),
878
+ tgt_lang=LANGUAGE_TO_SCRIPT.get(language, "kan_Knda")
879
+ )
880
+ processed_text = await chat(request, chat_request)
881
+ logger.info(f"Processed text: {processed_text.response}")
882
+
883
+ voice_request = KannadaSynthesizeRequest(text=processed_text.response)
884
+ audio_response = await synthesize_kannada(voice_request)
885
+ return audio_response
886
+
887
+ LANGUAGE_TO_SCRIPT = {
888
+ "kannada": "kan_Knda"
889
+ }
890
+
891
+ # Main Execution
892
  if __name__ == "__main__":
893
  parser = argparse.ArgumentParser(description="Run the FastAPI server.")
894
  parser.add_argument("--port", type=int, default=settings.port, help="Port to run the server on.")
895
  parser.add_argument("--host", type=str, default=settings.host, help="Host to run the server on.")
896
+ parser.add_argument("--config", type=str, default="config_one", help="Configuration to use")
897
  args = parser.parse_args()
898
 
899
+ def load_config(config_path="dhwani_config.json"):
900
+ with open(config_path, "r") as f:
901
+ return json.load(f)
902
+
903
+ config_data = load_config()
904
+ if args.config not in config_data["configs"]:
905
+ raise ValueError(f"Invalid config: {args.config}. Available: {list(config_data['configs'].keys())}")
906
+
907
+ selected_config = config_data["configs"][args.config]
908
+ global_settings = config_data["global_settings"]
909
+
910
+ settings.llm_model_name = selected_config["components"]["LLM"]["model"]
911
+ settings.max_tokens = selected_config["components"]["LLM"]["max_tokens"]
912
+ settings.host = global_settings["host"]
913
+ settings.port = global_settings["port"]
914
+ settings.chat_rate_limit = global_settings["chat_rate_limit"]
915
+ settings.speech_rate_limit = global_settings["speech_rate_limit"]
916
+
917
+ llm_manager = LLMManager(settings.llm_model_name)
918
+
919
+ if selected_config["components"]["ASR"]:
920
+ asr_model_name = selected_config["components"]["ASR"]["model"]
921
+ asr_manager.model_language[selected_config["language"]] = selected_config["components"]["ASR"]["language_code"]
922
+
923
+ if selected_config["components"]["Translation"]:
924
+ translation_configs.extend(selected_config["components"]["Translation"])
925
+
926
+ host = args.host if args.host != settings.host else settings.host
927
+ port = args.port if args.port != settings.port else settings.port
928
+
929
+ uvicorn.run(app, host=host, port=port)