import argparse
import json
import logging
import sys
import time

import torch
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from googletrans import Translator
from gtts import gTTS
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    M2M100ForConditionalGeneration,
    M2M100Tokenizer,
    pipeline,
)


class ChatAPIApp:
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_langs(self):
        # Load the language list; the context manager closes the file handle
        # (the original opened it without ever closing it).
        with open("apis/lang_name.json", "r") as f:
            self.available_langs = json.load(f)
        return self.available_langs
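
    # Assumed shape of apis/lang_name.json, inferred from how this file reads it
    # (each entry must at least carry a "code" key; the "name" key is a guess):
    # [{"code": "en", "name": "English"}, {"code": "fa", "name": "Persian"}, ...]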

    class TranslateCompletionsPostItem(BaseModel):
        from_language: str = Field(
            default="en",
            description="(str) Source language code, or `auto` to detect",
        )
        to_language: str = Field(
            default="fa",
            description="(str) Target language code, e.g. `en`",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) Text to translate",
        )

    def translate_completions(self, item: TranslateCompletionsPostItem):
        translator = Translator()
        with open("apis/lang_name.json", "r") as f:
            available_langs = json.load(f)
        # Fall back to English if the requested target language is unknown;
        # googletrans detects the source language on its own.
        to_lang = "en"
        for lang_item in available_langs:
            if item.to_language == lang_item["code"]:
                to_lang = item.to_language
                break
        translated = translator.translate(item.input_text, dest=to_lang)
        item_response = {
            "from_language": translated.src,
            "to_language": translated.dest,
            "text": item.input_text,
            "translate": translated.text,
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)
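
    # Example request (assumes the default port 23333 from ArgParser below):
    # curl -X POST http://localhost:23333/translate \
    #   -H "Content-Type: application/json" \
    #   -d '{"from_language": "auto", "to_language": "fa", "input_text": "Hello"}'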

    def translate_ai_completions(self, item: TranslateCompletionsPostItem):
        translator = Translator()
        with open("apis/lang_name.json", "r") as f:
            available_langs = json.load(f)
        # Validate both language codes against the known list; fall back to English.
        from_lang = "en"
        to_lang = "en"
        for lang_item in available_langs:
            if item.to_language == lang_item["code"]:
                to_lang = item.to_language
            if item.from_language == lang_item["code"]:
                from_lang = item.from_language
        if to_lang == "auto":
            to_lang = "en"
        if from_lang == "auto":
            from_lang = translator.detect(item.input_text).lang

        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
            logging.warning("GPU not found, using CPU, translation will be very slow.")

        time_start = time.time()
        # Note: the model is reloaded on every request; caching it across
        # requests would avoid paying the load cost each time.
        pretrained_model = "facebook/m2m100_1.2B"
        cache_dir = "models/"
        tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
        model = M2M100ForConditionalGeneration.from_pretrained(
            pretrained_model, cache_dir=cache_dir
        ).to(device)
        model.eval()

        tokenizer.src_lang = from_lang
        with torch.no_grad():
            encoded_input = tokenizer(item.input_text, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(to_lang)
            )
            translated_text = tokenizer.batch_decode(
                generated_tokens, skip_special_tokens=True
            )[0]
        time_end = time.time()

        item_response = {
            "from_language": from_lang,
            "to_language": to_lang,
            "text": item.input_text,
            "translate": translated_text,
            "start": str(time_start),
            "end": str(time_end),
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)
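
    # Illustrative response shape (values not captured from a real run):
    # {"from_language": "en", "to_language": "fa", "text": "Hello",
    #  "translate": "...", "start": "<unix time>", "end": "<unix time>"}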

    class TranslateAiPostItem(BaseModel):
        model: str = Field(
            default="t5-base",
            description="(str) Model name",
        )
        from_language: str = Field(
            default="en",
            description="(str) Language to translate from",
        )
        to_language: str = Field(
            default="fa",
            description="(str) Language to translate to",
        )
        input_text: str = Field(
            default="Hello",
            description="(str) Text to translate",
        )

    def ai_translate(self, item: TranslateAiPostItem):
        MODEL_MAP = {
            "t5-base": "t5-base",
            "t5-small": "t5-small",
            "t5-large": "t5-large",
            "t5-3b": "t5-3b",
            "mbart-large-50-many-to-many-mmt": "facebook/mbart-large-50-many-to-many-mmt",
            "nllb-200-distilled-600M": "facebook/nllb-200-distilled-600M",
            "madlad400-3b-mt": "jbochi/madlad400-3b-mt",
            "default": "t5-base",
        }
        target_model = item.model if item.model in MODEL_MAP else "default"
        real_name = MODEL_MAP[target_model]
        read_model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        # Build a task string such as "translation_en_to_fa" and run the
        # requested model through the translation pipeline. The original loaded
        # the model and tokenizer but never passed them in, so the pipeline
        # silently fell back to its default checkpoint.
        translate_query = f"translation_{item.from_language}_to_{item.to_language}"
        translator = pipeline(translate_query, model=read_model, tokenizer=tokenizer)
        result = translator(item.input_text)
        item_response = {
            "status": 200,
            "result": result,
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)
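
    # Example request (assumes the default port; note that plain t5-* checkpoints
    # only cover a few pairs, e.g. en->de/fr/ro):
    # curl -X POST http://localhost:23333/translate/ai \
    #   -H "Content-Type: application/json" \
    #   -d '{"model": "t5-base", "from_language": "en", "to_language": "de", "input_text": "Hello"}'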

    class DetectLanguagePostItem(BaseModel):
        input_text: str = Field(
            default="Hello, how are you?",
            description="(str) Text for detection",
        )

    def detect_language(self, item: DetectLanguagePostItem):
        translator = Translator()
        detected = translator.detect(item.input_text)
        item_response = {
            "lang": detected.lang,
            "confidence": detected.confidence,
        }
        json_compatible_item_data = jsonable_encoder(item_response)
        return JSONResponse(content=json_compatible_item_data)
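
    # Example request (assumes the default port):
    # curl -X POST http://localhost:23333/detect \
    #   -H "Content-Type: application/json" -d '{"input_text": "Bonjour"}'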

    class TTSPostItem(BaseModel):
        input_text: str = Field(
            default="Hello",
            description="(str) Text for TTS",
        )
        from_language: str = Field(
            default="en",
            description="(str) TTS language",
        )

    def text_to_speech(self, item: TTSPostItem):
        try:
            audioobj = gTTS(text=item.input_text, lang=item.from_language, slow=False)
            # Stream the generated MP3 back without writing it to disk.
            return StreamingResponse(audioobj.stream(), media_type="audio/mpeg")
        except Exception:
            item_response = {"status": 400}
            json_compatible_item_data = jsonable_encoder(item_response)
            return JSONResponse(content=json_compatible_item_data, status_code=400)
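
    # Example request, saving the streamed MP3 (assumes the default port):
    # curl -X POST http://localhost:23333/tts \
    #   -H "Content-Type: application/json" \
    #   -d '{"input_text": "Hello", "from_language": "en"}' --output speech.mp3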

    def setup_routes(self):
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/langs",
                summary="Get available languages",
            )(self.get_available_langs)
            self.app.post(
                prefix + "/translate",
                summary="Translate text",
            )(self.translate_completions)
            self.app.post(
                prefix + "/translate/ai",
                summary="Translate text with AI",
            )(self.translate_ai_completions)
            self.app.post(
                prefix + "/detect",
                summary="Detect language",
            )(self.detect_language)
            self.app.post(
                prefix + "/tts",
                summary="Text to speech",
            )(self.text_to_speech)


class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-s",
            "--server",
            type=str,
            default="0.0.0.0",
            help="Server IP for HF LLM Chat API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=23333,
            help="Server Port for HF LLM Chat API",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )
        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# The original defined this handler without registering it on the app; the
# route path below is an assumption so the endpoint is actually reachable.
@app.post("/whisper", summary="Transcribe audio with Whisper")
async def whisper_transcribe(
    audio_file: UploadFile = File(description="Audio file to transcribe"),
    language: str = Form(),
    model: str = Form(),
):
    MODEL_MAP = {
        "whisper-small": "openai/whisper-small",
        "whisper-medium": "openai/whisper-medium",
        "whisper-large": "openai/whisper-large",
        "default": "openai/whisper-small",
    }
    ACCEPTED_AUDIO_TYPES = {"audio/wav", "audio/mpeg", "audio/x-flac"}
    item_response = {
        "status": 200,
        "result": "",
        "start": 0,
        "end": 0,
    }
    if audio_file.content_type in ACCEPTED_AUDIO_TYPES:
        target_model = model if model in MODEL_MAP else "default"
        real_name = MODEL_MAP[target_model]
        device = 0 if torch.cuda.is_available() else "cpu"
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=real_name,
            chunk_length_s=30,
            device=device,
        )
        time_start = time.time()
        pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
            language=language, task="transcribe"
        )
        # Feed the raw audio bytes straight to the pipeline; the original
        # base64-encoded them first, which the pipeline cannot decode.
        file_data = await audio_file.read()
        text = pipe(file_data)["text"]
        time_end = time.time()
        item_response["result"] = text
        item_response["start"] = time_start
        item_response["end"] = time_end
    else:
        item_response["status"] = 400
        item_response["result"] = "Acceptable files: audio/wav, audio/mpeg, audio/x-flac"
    return item_response
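
# Example multipart request (hypothetical: depends on the assumed /whisper
# route registered above):
# curl -X POST http://localhost:23333/whisper \
#   -F "audio_file=@sample.wav;type=audio/wav" -F "language=en" -F "model=whisper-small"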


if __name__ == "__main__":
    args = ArgParser().args
    # Auto-reload only in dev mode; otherwise the two branches were identical.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

# python -m apis.chat_api     # [Docker] on product mode
# python -m apis.chat_api -d  # [Dev] on develop mode