DauroCamilo's picture
microsoft/phi-2
b55152c verified
raw
history blame
1.72 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from fastapi.responses import StreamingResponse
import torch
import threading
app = FastAPI()
# Cargar modelo y tokenizer de Phi-2 (usa el modelo de Hugging Face Hub)
model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
# Modelo de entrada
class ChatRequest(BaseModel):
message: str
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest):
prompt = f"""Responde en español de forma clara y breve como un asistente IA.
Usuario: {request.message}
IA:"""
# Tokenizar entrada
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
# Streamer para obtener tokens generados poco a poco
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Iniciar la generación en un hilo aparte
generation_kwargs = dict(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=48, # Puedes ajustar este valor para más/menos tokens
temperature=0.7,
top_p=0.9,
do_sample=True,
streamer=streamer,
pad_token_id=tokenizer.eos_token_id
)
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# StreamingResponse espera un generador que devuelva texto
async def event_generator():
for new_text in streamer:
yield new_text
return StreamingResponse(event_generator(), media_type="text/plain")