# Nile-Chat-12B Space: app.py
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\
# 🏞️🏞️ JAIS Initiative: Nile-Chat-12B 🏞️🏞️

Disclaimer: This research demonstration of Nile-Chat-12B is not intended for end-user applications. Because it is trained on diverse internet data, the model may generate biased, offensive, or inaccurate content. The developers do not endorse any views expressed by the model and assume no responsibility for the consequences of its use. Users should critically evaluate the generated responses and use the tool at their own risk.

Note: The model is expected to take input and generate output in Egyptian Arabic, in both Arabic and Latin scripts.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # unused here; placement is handled by device_map="auto"
model_id = "MBZUAI-Paris/Nile-Chat-12B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
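# spaces.GPU reserves GPU time on Hugging Face Spaces (up to 90 s per call).
# generate() is a Python generator: gr.ChatInterface streams each yielded
# string to the UI as a progressively growing assistant reply.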
@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    do_sample: bool = False,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.1,
) -> Iterator[str]:
    conversation = chat_history.copy()
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed the conversation input because it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
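    # Streaming pattern: model.generate() runs in a background thread and
    # pushes decoded text into the TextIteratorStreamer, which the loop below
    # drains. The 20 s timeout guards against a stalled generation.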
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Checkbox(label="Do Sample", value=True),
        gr.Slider(
            label="Temperature",
            minimum=0.0,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,
        ),
    ],
    stop_btn=None,
    examples=[
        ["مين اللي عملك؟"],
        ["اسمك ايه؟"],
        ["Esmak eh?"],
        ["ترجم للمصرية:\nWith a total length of about 6,650 km between the region of Lake Victoria and the Mediterranean Sea, the Nile is among the longest rivers on Earth."],
    ],
    cache_examples=False,
    type="messages",
)
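# Wrap the chat UI in Blocks so the description, a "duplicate this Space"
# button, and the repo's custom stylesheet (style.css) render around it.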
with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
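# queue(max_size=20) bounds how many requests may wait for a worker at once.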
if __name__ == "__main__":
    demo.queue(max_size=20).launch()
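# Client-side sketch (commented out). This assumes the app is deployed as the
# "MBZUAI-Paris/Nile-Chat-12B" Space and that ChatInterface's default "/chat"
# endpoint name applies; both are assumptions, not guaranteed by this file.
#
#     from gradio_client import Client
#     client = Client("MBZUAI-Paris/Nile-Chat-12B")
#     print(client.predict(message="Esmak eh?", api_name="/chat"))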