"""
AI Chat Application - Pure FastAPI Backend

Serves the custom frontend and exposes an OpenAI-compatible API.
"""

import os
import json
import logging
import time
from typing import Optional, Dict, Any, Generator, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from threading import Thread
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

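# Request/response schemas modeled loosely on the OpenAI chat completions API.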
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = "qwen-coder-3-30b"
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 2048
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]


# Globals populated by load_model() once the weights are available.
tokenizer = None
model = None

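# The model is loaded in a background thread at startup (see startup_event below),
# so the server starts serving requests before the weights finish loading; until
# then, the generate_* functions return a placeholder message.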
def load_model():
    """Load the Qwen model and tokenizer"""
    global tokenizer, model

    try:
        model_name = "Qwen/Qwen3-Coder-30B-A3B-Instruct"

        logger.info(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
        logger.info("Model loaded successfully")

    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.warning("Using fallback model response")

def generate_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048):
    """Generate response from the model"""
    try:
        if model is None or tokenizer is None:
            return "I'm a Qwen AI assistant. The model is currently loading, please try again in a moment."

        formatted_messages = []
        for msg in messages:
            formatted_messages.append({"role": msg.role, "content": msg.content})

        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt portion.
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return response.strip()

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error while processing your request: {str(e)}"

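# Streaming follows the OpenAI server-sent-events convention: each chunk is a
# "data: {json}\n\n" line carrying a content delta, followed by a final chunk with
# a finish_reason and a literal "data: [DONE]" sentinel.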
def generate_streaming_response(messages: List[ChatMessage], temperature: float = 0.7, max_tokens: int = 2048):
    """Generate streaming response from the model"""
    try:
        if model is None or tokenizer is None:
            response = "I'm a Qwen AI assistant. The model is currently loading, please try again in a moment."
            for char in response:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': char}}]})}\n\n"
                time.sleep(0.05)
            yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
            yield "data: [DONE]\n\n"
            return

        formatted_messages = []
        for msg in messages:
            formatted_messages.append({"role": msg.role, "content": msg.content})

        text = tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # TextIteratorStreamer yields decoded text as generate() produces tokens;
        # generation runs in a background thread so this generator can forward
        # chunks to the client as they arrive.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            if new_text:
                yield f"data: {json.dumps({'choices': [{'delta': {'content': new_text}}]})}\n\n"

        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

    except Exception as e:
        logger.error(f"Error in streaming generation: {e}")
        error_msg = f"Error: {str(e)}"
        yield f"data: {json.dumps({'choices': [{'delta': {'content': error_msg}}]})}\n\n"
        yield f"data: {json.dumps({'choices': [{'finish_reason': 'stop'}]})}\n\n"
        yield "data: [DONE]\n\n"

app = FastAPI(title="AI Chat API", description="OpenAI compatible interface for Qwen model")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/") |
|
|
async def serve_index(): |
|
|
"""Serve the main HTML file""" |
|
|
return FileResponse("public/index.html") |
|
|
|
|
|
@app.get("/health") |
|
|
async def health_check(): |
|
|
"""Health check endpoint""" |
|
|
return {"status": "healthy", "model_loaded": model is not None} |
|
|
|
|
|
@app.get("/ping") |
|
|
async def ping(): |
|
|
"""Simple ping endpoint""" |
|
|
return {"status": "pong"} |
|
|
|
|
|
@app.get("/api/models") |
|
|
async def list_models(): |
|
|
"""List available models""" |
|
|
return { |
|
|
"data": [ |
|
|
{ |
|
|
"id": "qwen-coder-3-30b", |
|
|
"object": "model", |
|
|
"created": int(time.time()), |
|
|
"owned_by": "qwen" |
|
|
} |
|
|
] |
|
|
} |
|
|
|
|
|
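# Main chat endpoint. With stream=True the body is sent as an SSE stream
# (see generate_streaming_response above); otherwise a single ChatResponse is returned.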
@app.post("/api/chat") |
|
|
async def chat_completion(request: ChatRequest): |
|
|
"""OpenAI compatible chat completion endpoint""" |
|
|
try: |
|
|
if request.stream: |
|
|
return StreamingResponse( |
|
|
generate_streaming_response( |
|
|
request.messages, |
|
|
request.temperature or 0.7, |
|
|
request.max_tokens or 2048 |
|
|
), |
|
|
media_type="text/plain" |
|
|
) |
|
|
else: |
|
|
response_content = generate_response( |
|
|
request.messages, |
|
|
request.temperature or 0.7, |
|
|
request.max_tokens or 2048 |
|
|
) |
|
|
|
|
|
return ChatResponse( |
|
|
id=f"chatcmpl-{int(time.time())}", |
|
|
created=int(time.time()), |
|
|
model=request.model or "qwen-coder-3-30b", |
|
|
choices=[{ |
|
|
"index": 0, |
|
|
"message": { |
|
|
"role": "assistant", |
|
|
"content": response_content |
|
|
}, |
|
|
"finish_reason": "stop" |
|
|
}] |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in chat completion: {e}") |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
@app.post("/v1/chat/completions") |
|
|
async def openai_chat_completion(request: ChatRequest): |
|
|
"""OpenAI API compatible endpoint""" |
|
|
return await chat_completion(request) |
|
|
|
|
|
|
|
|
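# Example non-streaming request, assuming the server is running locally on the
# default port (7860):
#   curl -s http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'
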
# Mounted after the API routes so that /api/*, /health, etc. take precedence
# over the static files.
app.mount("/", StaticFiles(directory="public", html=True), name="static")


@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup"""
    # Load the model in a daemon thread so startup is not blocked.
    thread = Thread(target=load_model)
    thread.daemon = True
    thread.start()


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        access_log=True
    )