# hf_model_adapter.py
import traceback

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class HFLocalModelAdapter:
    """
    Minimal Hugging Face model adapter for text generation with simple
    error handling and device selection.
    """

    def __init__(self, model_name="stabilityai/stablelm-3b-4e1t", device=None):
        self.model_name = model_name
        # Choose device: prefer CUDA when available, otherwise fall back to CPU.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model {model_name} on device: {self.device}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=(torch.float16 if "cuda" in self.device else torch.float32),
                low_cpu_mem_usage=True,
                device_map="auto" if "cuda" in self.device else None,
            )
            # Move the model to the chosen device if device_map did not place it.
            try:
                self.model.to(self.device)
            except Exception:
                # device_map="auto" may already have placed the model.
                pass
            print("Model loaded successfully.")
        except Exception as e:
            print("Error loading model:", e)
            traceback.print_exc()
            raise

    def generate(self, prompt, max_new_tokens=250, temperature=0.7, top_p=0.95):
        """
        Returns generated text (string). Tries to trim the prompt from the output.
        """
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            out = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            decoded = self.tokenizer.decode(out[0], skip_special_tokens=True)
            # If the model echoes the prompt, trim it from the output.
            if decoded.startswith(prompt):
                return decoded[len(prompt):].strip()
            return decoded.strip()
        except Exception as e:
            # Return the error message as text so the calling UI does not crash.
            return f"[Generation error] {e}"
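

# Minimal usage sketch for a local smoke test. It assumes the default
# StableLM checkpoint above can be downloaded and fits in available memory;
# the prompt below is only illustrative.
if __name__ == "__main__":
    # Instantiate the adapter with its defaults; pass device="cpu" to force CPU.
    adapter = HFLocalModelAdapter()
    # Generate a short completion for a toy prompt and print it.
    sample = adapter.generate(
        "Write a one-sentence summary of what a model adapter does.",
        max_new_tokens=64,
    )
    print(sample)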