import os

# Configure the Hugging Face cache before importing transformers so the
# environment variables take effect. TRANSFORMERS_CACHE is deprecated in
# recent transformers releases; HF_HOME alone is usually sufficient.
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

import threading

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

app = FastAPI()

# Load the model and tokenizer once at startup, not per request.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")


class ChatRequest(BaseModel):
    message: str


@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
    # Use the chat template, with a clear system instruction in Spanish.
    messages = [
        {
            "role": "system",
            "content": "Eres un asistente IA amigable y responde siempre en español, de forma breve y clara.",
        },
        {"role": "user", "content": request.message},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # TextIteratorStreamer yields decoded text chunks as generate() produces tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=48,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread; the streamer is consumed below.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # A plain (sync) generator: Starlette iterates it in a thread pool, so the
    # blocking streamer iteration does not stall the event loop. An `async def`
    # generator doing `for new_text in streamer` would block the loop while
    # waiting for each token.
    def event_generator():
        for new_text in streamer:
            yield new_text

    return StreamingResponse(event_generator(), media_type="text/plain")
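
# Usage sketch (assumptions: the file is saved as app.py and uvicorn is
# installed; the module path and port below are placeholders, not part of
# the original script):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -N -X POST http://localhost:8000/chat/stream \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hola, ¿qué es FastAPI?"}'
#
# `curl -N` disables output buffering so the streamed tokens appear as they
# are generated rather than all at once.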