# hf_model_adapter.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import traceback

class HFLocalModelAdapter:
    """
    Minimal Hugging Face model adapter for text generation with simple
    error handling and device selection.
    """

    def __init__(self, model_name="stabilityai/stablelm-3b-4e1t", device=None):
        self.model_name = model_name
        # choose device
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading model {model_name} on device: {self.device}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            use_device_map = "cuda" in self.device
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=(torch.float16 if use_device_map else torch.float32),
                low_cpu_mem_usage=True,
                device_map="auto" if use_device_map else None,
            )
            # device_map="auto" already places the weights; only move manually otherwise
            if not use_device_map:
                self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print("Error loading model:", e)
            traceback.print_exc()
            raise

    def generate(self, prompt, max_new_tokens=250, temperature=0.7, top_p=0.95):
        """
        Returns generated text (string). Tries to trim the prompt from the output.
        """
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            # run generation without gradient tracking to save memory
            with torch.no_grad():
                out = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            decoded = self.tokenizer.decode(out[0], skip_special_tokens=True)
            # If model echoes the prompt, trim it
            if decoded.startswith(prompt):
                return decoded[len(prompt):].strip()
            return decoded.strip()
        except Exception as e:
            # return error message so UI doesn't crash
            return f"[Generation error] {e}"