File size: 6,755 Bytes
b376f12
 
 
2e11c33
c3fb36e
 
3c2d17d
b376f12
c3fb36e
 
 
 
b376f12
 
c3fb36e
 
 
b376f12
 
 
 
 
 
c3fb36e
b376f12
c3fb36e
1ae9a3e
c3fb36e
815fe38
c3fb36e
 
 
 
 
1ae9a3e
c3fb36e
fed41be
c3fb36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fed41be
c3fb36e
53cb438
7fff69d
b376f12
3c2d17d
03c2ae6
 
2e11c33
03c2ae6
db24877
03c2ae6
 
 
 
c3fb36e
2e11c33
c3fb36e
 
 
2e11c33
c3fb36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e11c33
c3fb36e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e11c33
c3fb36e
 
b376f12
 
c3fb36e
 
b376f12
 
03c2ae6
edd0bac
c3fb36e
edd0bac
c3fb36e
edd0bac
c3fb36e
03c2ae6
edd0bac
 
 
db24877
edd0bac
 
 
 
 
 
 
 
 
 
03c2ae6
 
da09cca
ead1131
cbef7a0
da09cca
 
 
 
edb6fb6
3ba38dc
da09cca
19342c6
 
3ba38dc
 
58cc8dc
3ba38dc
 
 
 
da09cca
 
 
03c2ae6
 
d7174fa
03c2ae6
 
 
 
 
da09cca
 
b376f12
c3fb36e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from typing import Iterator, List, Dict
from huggingface_hub import hf_hub_download
from themes.research_monochrome import ResearchMonochrome
import spaces
import gradio as gr
from llama_cpp import Llama # <-- Neu: Llama-Klasse importieren
import os

# --- Configuration ---

# Render the date as "Month D, YYYY" without the "%-d" strftime directive:
# the "-" flag is a glibc extension and raises ValueError on platforms whose C
# runtime does not support it (e.g. Windows).
_today = datetime.today()  # noqa: DTZ002 - local wall-clock date is intended
today_date = f"{_today:%B} {_today.day}, {_today:%Y}"
# System message sent at the start of every conversation. A space separates the
# date sentence from the next one (previously they ran together: "2025.You").
SYS_PROMPT = f"""Today's Date: {today_date}. You are Granite, developed by IBM. You are a helpful AI assistant"""
TITLE = "IBM Granite 4 Tiny Preview served via llama-cpp-python"
DESCRIPTION = """<p>Granite 4 Tiny is an open-source LLM supporting a 128k context window. This demo uses only 2K context.<span class="gr_docs_link"><a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a></span></p>"""

# Default generation parameters; the UI sliders start at these values.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05
# Context window size in tokens, passed to llama.cpp as n_ctx.
CONTEXT_WINDOW = 2048

# --- Model setup ---

# Download the quantized GGUF weights from the Hugging Face Hub into the
# current directory; hf_hub_download returns the local file path.
gguf_name = "granite-4.0-tiny-preview-Q4_K_M.gguf"
model_path = hf_hub_download(
    repo_id="ibm-granite/granite-4.0-tiny-preview-GGUF",
    filename=gguf_name,
    local_dir=".",
)
print(f"Model downloaded to: {model_path}")

# Load the model with llama-cpp-python.
# - n_ctx: context window size in tokens.
# - n_gpu_layers: a value larger than the model's layer count offloads every
#   layer to the GPU.
# - chat_format is intentionally NOT set. Granite uses its own role-token prompt
#   template rather than ChatML, so forcing "chatml" would feed the model
#   <|im_start|>/<|im_end|> markers. Without chat_format/chat_handler,
#   llama-cpp-python formats messages with the `tokenizer.chat_template`
#   stored in the GGUF metadata.
#   NOTE(review): if a GGUF ever lacks an embedded template, llama-cpp-python
#   falls back to its llama-2 format — check `llama_model.metadata`.
try:
    llama_model = Llama(
        model_path=model_path,
        n_ctx=CONTEXT_WINDOW,
        n_gpu_layers=999,
        verbose=False,
    )
    print("Llama model initialized successfully.")
except Exception as e:
    # Keep the app importable; generate() reports an error while this is None.
    print(f"Error initializing Llama model: {e}")
    llama_model = None

# --- Gradio functions ---

# Instantiate the project's ResearchMonochrome theme once at import time; it is
# handed to gr.Blocks(theme=...) when the UI is built.
custom_theme = ResearchMonochrome()

@spaces.GPU(duration=30)
def generate(
    message: str,
    chat_history: List[Dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a model reply for the chat demo using llama-cpp-python.

    Args:
        message: The new user message.
        chat_history: Prior turns as dicts with "role" and "content" keys
            (Gradio "messages" format). Only "user" and "assistant" turns are
            forwarded to the model; other roles are dropped.
        temperature: Sampling temperature.
        repetition_penalty: Passed to llama.cpp as ``repeat_penalty``.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff; coerced to ``int`` because llama.cpp expects an
            integer and the UI slider may deliver a float.
        max_new_tokens: Maximum number of tokens to generate; coerced to ``int``.

    Yields:
        The accumulated reply after each streamed piece of text, so the UI can
        re-render the growing message. If the model is unavailable or
        generation fails, a single "Error: ..." string.
    """
    if llama_model is None:
        yield "Error: The model failed to initialize."
        return

    # OpenAI-style message list: system prompt, prior turns, then the new message.
    # NOTE(review): the full history is sent on every turn. Once it no longer
    # fits in the context window, llama-cpp-python presumably raises, which
    # surfaces below as an "Error: ..." reply.
    messages = [{"role": "system", "content": SYS_PROMPT}]
    messages.extend(
        {"role": item["role"], "content": item["content"]}
        for item in chat_history
        if item["role"] in ("user", "assistant")
    )
    messages.append({"role": "user", "content": message})

    full_response = ""
    try:
        # create_chat_completion(stream=True) yields plain dicts, which the
        # dict-style access below can read. create_chat_completion_openai_v1
        # instead yields pydantic ChatCompletionChunk objects (and needs the
        # `openai` package); `"choices" in chunk` is always False for those,
        # so nothing would ever be streamed.
        stream = llama_model.create_chat_completion(
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
            max_tokens=int(max_new_tokens),
            repeat_penalty=repetition_penalty,
            # NOTE(review): confirm this stop string is meaningful for Granite.
            stop=["<|file_separator|>"],
            stream=True,
        )

        for chunk in stream:
            choices = chunk.get("choices") or []
            if not choices:
                continue
            # The first chunk typically carries only the role and the last an
            # empty delta, so "content" may be absent or None.
            text = choices[0].get("delta", {}).get("content")
            if text:
                full_response += text
                yield full_response

    except Exception as e:
        print(f"An error occurred during generation: {e}")
        yield f"Error: {e}"


# --- Gradio UI setup ---

# Stylesheet located next to this script; passed to gr.Blocks(css_paths=...).
# `Path / str` already returns a Path, so no extra Path(...) wrapper is needed.
css_file_path = Path(__file__).parent / "app.css"

# Advanced generation settings, displayed inside a collapsed "Advanced Settings"
# accordion of the chat interface.


def _accordion_slider(label, *, minimum, maximum, value, step):
    """Build a gr.Slider carrying the accordion element CSS class."""
    return gr.Slider(
        minimum=minimum,
        maximum=maximum,
        value=value,
        step=step,
        label=label,
        elem_classes=["gr_accordion_element"],
    )


temperature_slider = _accordion_slider(
    "Temperature", minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1
)
top_p_slider = _accordion_slider(
    "Top P", minimum=0, maximum=1.0, value=TOP_P, step=0.05
)
top_k_slider = _accordion_slider(
    "Top K", minimum=0, maximum=100, value=TOP_K, step=1
)
# NOTE(review): minimum=0 lets users pick repeat_penalty=0; llama.cpp divides
# positive logits by this penalty — confirm whether a floor above 0 is wanted.
repetition_penalty_slider = _accordion_slider(
    "Repetition Penalty", minimum=0, maximum=2.0, value=REPETITION_PENALTY, step=0.05
)
max_new_tokens_slider = _accordion_slider(
    "Max New Tokens", minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1
)
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)

# Example prompts paired with the short labels shown on their buttons.
_EXAMPLES = [
    ("What is 1+1?", "What is 1+1?"),
    (
        "Explain the concept of quantum computing to someone with no background in "
        "physics or computer science.",
        "Explain quantum computing",
    ),
    ("What is OpenShift?", "What is OpenShift?"),
    ("What's the importance of low latency inference?", "Importance of low latency inference"),
    ("Help me boost productivity habits.", "Boosting productivity habits"),
]

with gr.Blocks(fill_height=True, css_paths=css_file_path, theme=custom_theme, title=TITLE) as demo:
    # Page header: title and short description.
    gr.HTML(f"<h2>{TITLE}</h2>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)

    # Chat UI backed by generate(). additional_inputs are listed in the same
    # order as generate()'s parameters after (message, chat_history).
    chat_interface = gr.ChatInterface(
        fn=generate,
        examples=[[prompt] for prompt, _label in _EXAMPLES],
        example_labels=[label for _prompt, label in _EXAMPLES],
        cache_examples=False,
        type="messages",
        additional_inputs=[
            temperature_slider,
            repetition_penalty_slider,
            top_p_slider,
            top_k_slider,
            max_new_tokens_slider,
        ],
        additional_inputs_accordion=chat_interface_accordion,
    )

if __name__ == "__main__":
    # Enable request queueing, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()