"""Model utilities for working with the Qwen/Qwen3-Coder-30B-A3B-Instruct model."""

import logging
from threading import Thread
from typing import Generator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
DEFAULT_MAX_TOKENS = 1024
DEFAULT_TEMPERATURE = 0.7
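

# Hedged sketch, not part of the original module: Qwen instruct checkpoints
# generally expect prompts built with the tokenizer's chat template rather than
# raw strings. `build_chat_prompt` is a hypothetical helper added for
# illustration only.
def build_chat_prompt(tokenizer, user_message: str) -> str:
    """Wrap a single user message in the model's chat template."""
    messages = [{"role": "user", "content": user_message}]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )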


class ModelManager:
    """Manage Qwen model loading and inference."""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Prefer GPU when available; used for logging and dtype selection.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self) -> None:
        """Load the tokenizer and model weights."""
        try:
            logger.info(f"Loading model {MODEL_NAME} on {self.device}")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                # Half precision halves GPU memory; keep full precision on CPU.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                low_cpu_mem_usage=True,
                # Let accelerate place the weights, spreading across devices if needed.
                device_map="auto",
            )
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def generate_response(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> str:
        """Generate a complete response for `prompt` and return it as a string."""
        try:
            # With device_map="auto" the weights may not sit on self.device,
            # so follow the device the model actually reports.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            generated = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Strip the prompt at the token level; slicing the decoded string
            # by len(prompt) misbehaves whenever detokenization is not lossless.
            new_tokens = generated[0][inputs["input_ids"].shape[1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    def generate_streaming_response(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> Generator[str, None, None]:
        """Yield the response incrementally as the model generates it."""
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # skip_prompt=True keeps the echoed prompt out of the stream.
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )

            generation_kwargs = dict(
                inputs,
                streamer=streamer,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # generate() blocks until completion, so run it on a worker thread
            # and drain the streamer from this one.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            for new_text in streamer:
                yield new_text
            thread.join()
        except Exception as e:
            logger.error(f"Error generating streaming response: {e}")
            yield f"Error: {e}"
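

if __name__ == "__main__":
    # Minimal usage sketch, assuming the hypothetical build_chat_prompt helper
    # above and enough accelerator memory for the 30B checkpoint (roughly
    # 60 GB in float16).
    manager = ModelManager()
    prompt = build_chat_prompt(manager.tokenizer, "Write a Python hello world.")

    # Blocking generation.
    print(manager.generate_response(prompt, max_tokens=128))

    # Streaming generation: chunks are printed as they arrive.
    for chunk in manager.generate_streaming_response(prompt, max_tokens=128):
        print(chunk, end="", flush=True)
    print()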