"""
Model utilities for working with Qwen/Qwen3-Coder-30B-A3B-Instruct model
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import logging
from typing import Generator

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7


class ModelManager:
    """Manage Qwen model loading and inference."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self) -> None:
        """Load the Qwen tokenizer and model."""
        try:
            logger.info(f"Loading model {MODEL_NAME} on {self.device}")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto",
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def generate_response(self, prompt: str, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE) -> str:
        """Generate a complete (non-streaming) response from the model."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            # Generate without streaming for a simple, single-shot response
            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # Decode only the newly generated tokens; slicing the decoded string by
            # len(prompt) is fragile because tokenization does not round-trip the
            # prompt text exactly.
            new_tokens = generated[0][inputs["input_ids"].shape[1]:]
            response = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    def generate_streaming_response(self, prompt: str, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE) -> Generator[str, None, None]:
        """Generate a streaming response from the model, yielding text chunks."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            # Streamer that skips the prompt and special tokens in its output
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            # Run generation in a background thread so the streamer can be consumed here
            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            # Yield text chunks as they are generated
            for new_text in streamer:
                yield new_text
            # The streamer is exhausted once generation finishes, so this returns promptly
            thread.join()
        except Exception as e:
            logger.error(f"Error generating streaming response: {e}")
            yield f"Error: {str(e)}"