"""
AI Chat Application - Pure FastAPI Backend

Serves a custom frontend with an OpenAI-compatible API.
"""
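# The stack below is what this module actually imports and uses: FastAPI served by
# uvicorn, pydantic for the request/response models, and torch plus transformers for
# local Qwen inference. Run the file directly (see the __main__ block at the bottom)
# or point uvicorn at the FastAPI `app` object.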
import os
import json
import logging
import time
from typing import Optional, Dict, Any, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from threading import Thread
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = "qwen-coder-3-30b"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False
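# Example request body accepted by /api/chat and /v1/chat/completions
# (illustrative values only):
# {
#   "messages": [{"role": "user", "content": "Hello"}],
#   "model": "qwen-coder-3-30b",
#   "temperature": 0.7,
#   "max_tokens": 2048,
#   "stream": false
# }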
class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]


# Global model state shared by all request handlers; populated by load_model().
tokenizer = None
model = None
current_model_name = None

available_models = {
    "qwen-coder-3-30b": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
    "qwen-4b-thinking": "Qwen/Qwen3-4B-Thinking-2507"
}
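# Keys are the model IDs clients pass in the request's `model` field; values are the
# Hugging Face repository names. To expose another model, add an entry here; new
# entries are assumed to be chat models whose tokenizers support apply_chat_template().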
def load_model(model_id: str = "qwen-coder-3-30b"):
    """Load the specified Qwen model and tokenizer."""
    global tokenizer, model, current_model_name

    try:
        if model_id not in available_models:
            raise ValueError(f"Unknown model ID: {model_id}")

        model_name = available_models[model_id]

        if current_model_name == model_name:
            logger.info(f"Model {model_name} is already loaded")
            return

        # Release the previously loaded model before loading a new one. Rebinding to
        # None (instead of `del`) keeps the global defined even if the load below fails.
        if model is not None:
            model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        if model_id == "qwen-4b-thinking":
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

        current_model_name = model_name
        logger.info(f"Model {model_name} loaded successfully")

    except Exception as e:
        logger.error(f"Error loading model {model_id}: {e}")
        # Leave the globals as they are; the request handlers fall back to a canned
        # response while no model is available.
        logger.warning("Using fallback model response")


def generate_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Generate a complete (non-streaming) response from the model."""
    try:
        # Load the requested model on demand if it is not the one currently in memory.
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            return f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."

        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # The smaller "thinking" model gets tighter generation limits.
        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error while processing your request: {str(e)}"


def generate_streaming_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Generate a streaming response from the model as OpenAI-style SSE frames."""
    try:
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            # Model is not ready yet: stream a canned message character by character.
            response = f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."
            for char in response:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': char}}]})}\n\n"
                time.sleep(0.05)
            yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
            yield "data: [DONE]\n\n"
            return

        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        # Run generation in a background thread and stream tokens as they arrive.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            if new_text:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"

        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        error_msg = f"Error: {str(e)}"
        yield f"data: {json.dumps({'choices': [{'delta': {'content': error_msg}}]})}\n\n"
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"
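# Wire format produced by generate_streaming_response (OpenAI-style SSE chunks;
# chunking is illustrative):
#   data: {"choices": [{"delta": {"content": "Hel"}}]}
#   data: {"choices": [{"delta": {"content": "lo"}}]}
#   data: {"choices": [{"finish_reason": "stop"}]}
#   data: [DONE]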
def generate_plain_text_stream(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048, model_id: str = "qwen-coder-3-30b"):
    """Plain-text streaming generator used by the /chat compatibility endpoint (no SSE framing)."""
    try:
        if model is None or current_model_name != available_models.get(model_id):
            load_model(model_id)

        if model is None or tokenizer is None:
            response = f"I'm a Qwen AI assistant ({model_id}). The model is currently loading, please try again in a moment."
            for ch in response:
                yield ch
                time.sleep(0.02)
            return

        formatted_messages = [{"role": m.role, "content": m.content} for m in messages]

        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        if model_id == "qwen-4b-thinking":
            max_tokens = min(max_tokens, 1024)
            temperature = min(temperature, 0.8)

        # Same background-thread streaming as above, but yields raw text chunks only.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            if new_text:
                yield new_text
    except Exception as e:
        logger.error(f"Error in plain streaming generation: {e}")
        yield f"[error] {str(e)}"


app = FastAPI(title="AI Chat API", description="OpenAI-compatible interface for Qwen models")

# Allow the browser frontend to call the API from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def serve_index():
    """Serve the main HTML file"""
    return FileResponse("public/index.html")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/ping")
async def ping():
    """Simple ping endpoint"""
    return {"status": "pong"}


@app.head("/ping")
async def ping_head():
    """HEAD ping for health checks"""
    return Response(status_code=200)


@app.get("/api/models")
async def list_models():
    """List available models"""
    return {
        "data": [
            {
                "id": "qwen-coder-3-30b",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "qwen",
                "name": "Qwen 3 Coder 30B",
                "description": "A powerful model for programming"
            },
            {
                "id": "qwen-4b-thinking",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "qwen",
                "name": "Qwen 4B Thinking",
                "description": "A faster, lightweight model"
            }
        ]
    }
@app.post("/api/chat")
async def chat_completion(request: ChatRequest):
    """OpenAI-compatible chat completion endpoint"""
    try:
        model_id = request.model or "qwen-coder-3-30b"

        if model_id not in available_models:
            raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}")

        if request.stream:
            # The generator yields Server-Sent Events frames, so advertise them as such.
            return StreamingResponse(
                generate_streaming_response(
                    request.messages,
                    request.temperature or 0.7,
                    request.max_tokens or 2048,
                    model_id
                ),
                media_type="text/event-stream"
            )
        else:
            response_content = generate_response(
                request.messages,
                request.temperature or 0.7,
                request.max_tokens or 2048,
                model_id
            )

            return ChatResponse(
                id=f"chatcmpl-{int(time.time())}",
                created=int(time.time()),
                model=model_id,
                choices=[{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content
                    },
                    "finish_reason": "stop"
                }]
            )

    except HTTPException:
        # Re-raise client errors (e.g. unknown model) instead of masking them as 500s.
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))
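# Example call (assumes the server is running locally on the default port 7860):
#   curl -N -X POST http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}'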
@app.post("/v1/chat/completions")
async def openai_chat_completion(request: ChatRequest):
    """OpenAI API compatible endpoint"""
    return await chat_completion(request)


@app.post("/chat")
async def chat_stream_compat(payload: Dict[str, Any]):
    """Compatibility endpoint for frontend streaming via /chat (plain-text stream)."""
    try:
        message = str(payload.get("message", "") or "").strip()
        history_raw = payload.get("history", []) or []
        model_id = payload.get("model", "qwen-coder-3-30b")

        # Fall back to the default model if the client sent an unknown ID.
        if model_id not in available_models:
            model_id = "qwen-coder-3-30b"

        history_msgs: List[ChatMessage] = []
        for item in history_raw:
            role = item.get("role")
            content = item.get("content")
            if role and content is not None:
                history_msgs.append(ChatMessage(role=role, content=str(content)))

        if message:
            history_msgs.append(ChatMessage(role="user", content=message))

        return StreamingResponse(
            generate_plain_text_stream(
                history_msgs,
                temperature=0.7,
                max_tokens=2048,
                model_id=model_id
            ),
            media_type="text/plain; charset=utf-8"
        )
    except Exception as e:
        logger.error(f"/chat compatibility error: {e}")
        raise HTTPException(status_code=400, detail="Invalid request body")
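# Example /chat payload (shape inferred from the parsing above):
#   {"message": "Hello", "history": [{"role": "user", "content": "Hi"}], "model": "qwen-coder-3-30b"}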
# Mounted last so the API routes defined above take precedence over this static catch-all.
app.mount("/", StaticFiles(directory="public", html=True), name="static")
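# Assumed layout: a `public/` directory (relative to the process working directory)
# containing index.html plus any assets it references.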
@app.on_event("startup")
async def startup_event():
    """Initialize the default model on startup"""
    # Load the default model in a background daemon thread so the server can start
    # answering health checks immediately.
    thread = Thread(target=load_model, args=("qwen-coder-3-30b",))
    thread.daemon = True
    thread.start()


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        access_log=True
    )