DauroCamilo's picture
Update main.py
59e2f12 verified
raw
history blame
1.73 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import torch
import threading
app = FastAPI()
# Cargar modelo y tokenizer de Phi-2 (usa el modelo de Hugging Face Hub)
model_id = "HuggingFaceTB/SmolLM2-135M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Modelo de entrada
class ChatRequest(BaseModel):
message: str
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
prompt = f"""Responde en español de forma clara y breve como un asistente IA.
Usuario: {request.message}
IA:"""
# Tokenizar entrada
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Streamer para obtener tokens generados poco a poco
streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=False)
# Iniciar la generación en un hilo aparte
generation_kwargs = dict(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=48, # Puedes ajustar este valor para más/menos tokens
temperature=0.7,
top_p=0.9,
do_sample=True,
streamer=streamer,
pad_token_id=tokenizer.eos_token_id
)
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# StreamingResponse espera un generador que devuelva texto
async def event_generator():
for new_text in streamer:
yield new_text
return StreamingResponse(event_generator(), media_type="text/plain")