#load package from fastapi import FastAPI from pydantic import BaseModel import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer ) from typing import List, Tuple from threading import Thread import os from pydantic import BaseModel import logging import uvicorn # Configurer les répertoires de cache os.environ['TRANSFORMERS_CACHE'] = '/app/.cache' os.environ['HF_HOME'] = '/app/.cache' # Charger le modèle et le tokenizer model = AutoModelForCausalLM.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True, device_map='auto') tokenizer = AutoTokenizer.from_pretrained("THUDM/longwriter-glm4-9b", trust_remote_code=True) #Additional information Informations = """ -text : Texte à resumé output: - Text summary : texte resumé """ app =FastAPI( title='Text Summary', description =Informations ) #class to define the input text logging.basicConfig(level=logging.INFO) logger =logging.getLogger(__name__) class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: stop_ids = model.config.eos_token_id for stop_id in stop_ids: if input_ids[0][-1] == stop_id: return True return False default_prompt = """Bonjour, En tant qu’expert en gestion des plaintes réseaux, rédige un descriptif clair de la plainte ci-dessous. Résume la situation en 4 ou 5 phrases concises, en mettant l'accent sur : 1. **Informations Client** : Indique des détails pertinents sur le client. 2. **Dates et Délais** : Mentionne les dates clés et les délais (prise en charge, résolution, etc.). 3. **Contexte et Détails** : Inclut les éléments essentiels de la plainte (titre, détails, états d’avancement, qualification, fichiers joints). Ajoute une recommandation importante pour éviter le mécontentement du client, par exemple, en cas de service non fourni malgré le paiement. Adapte le ton pour qu'il soit humain et engageant. Merci ! """ class PredictionRequest(BaseModel): history: List[Tuple[str, str]] = [] prompt: str = default_prompt max_length: int = 10240 top_p: float = 0.8 temperature: float = 0.6 @app.post("/predict/") async def predict(request: PredictionRequest): history = request.history prompt = request.prompt max_length = request.max_length top_p = request.top_p temperature = request.temperature stop = StopOnTokens() messages = [] if prompt: messages.append({"role": "system", "content": prompt}) for idx, (user_msg, model_msg) in enumerate(history): if prompt and idx == 0: continue if idx == len(history) - 1 and not model_msg: query = user_msg break if user_msg: messages.append({"role": "user", "content": user_msg}) if model_msg: messages.append({"role": "assistant", "content": model_msg}) model_inputs = tokenizer.build_chat_input(query, history=messages, role='user').input_ids.to( next(model.parameters()).device) streamer = TextIteratorStreamer(tokenizer, timeout=600, skip_prompt=True) eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), tokenizer.get_command("<|observation|>")] generate_kwargs = { "input_ids": model_inputs, "streamer": streamer, "max_new_tokens": max_length, "do_sample": True, "top_p": top_p, "temperature": temperature, "stopping_criteria": StoppingCriteriaList([stop]), "repetition_penalty": 1, "eos_token_id": eos_token_id, } t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() generated_text = "" for new_token in streamer: if new_token and '<|user|>' in new_token: new_token = new_token.split('<|user|>')[0] if new_token: generated_text += new_token history[-1][1] = generated_text return {"history": history} if __name__ == "__main__": uvicorn.run("app:app",reload=True)