Artples committed
Commit 6f346c7 · verified · 1 Parent(s): e05cd4e

Update app.py

Files changed (1)
  1. app.py +46 -23
app.py CHANGED
@@ -2,27 +2,33 @@ import os
 import gradio as gr
 import spaces
 import torch
+from threading import Thread
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Constants for model behavior
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+# Models selection
+MODELS = {
+    "Fast-Model": "Artples/L-MChat-Small",
+    "Quality-Model": "Artples/L-MChat-7b"
+}
+
+# Description for the application
 DESCRIPTION = """\
 # L-MChat
 This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-mchat-663265a8351231c428318a8f) by L-AI.
 """
 
+# Check for GPU availability
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"
 
-model_details = {
-    "Fast-Model": "Artples/L-MChat-Small",
-    "Quality-Model": "Artples/L-MChat-7b"
-}
-
-models = {name: AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") for name, model_id in model_details.items()}
-tokenizers = {name: AutoTokenizer.from_pretrained(model_id) for name, model_id in model_details.items()}
+# Load models and tokenizers
+models = {name: AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") for name, model_id in MODELS.items()}
+tokenizers = {name: AutoTokenizer.from_pretrained(model_id) for name, model_id in MODELS.items()}
 
 @spaces.GPU(enable_queue=True, duration=90)
 def generate(
@@ -39,35 +45,52 @@ def generate(
     model = models[model_choice]
     tokenizer = tokenizers[model_choice]
 
-    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
-    conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}] for user, assistant in chat_history)
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
     input_ids = tokenizer(conversation, return_tensors="pt", truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).input_ids
     input_ids = input_ids.to(model.device)
 
-    output_ids = model.generate(input_ids, max_length=MAX_INPUT_TOKEN_LENGTH + max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
-    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-
-    return output_text
+    output_ids = model.generate(
+        input_ids,
+        max_length=input_ids.shape[1] + max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        num_return_sequences=1,
+    )
 
+    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
-chat_interface = gr.ChatInterface(
-    theme='ehristoforu/RE_Theme',
+# Gradio Interface
+chat_interface = gr.Interface(
     fn=generate,
-    additional_inputs=[gr.Textbox(label="System prompt", lines=6), gr.Dropdown(label="Model Choice", choices=list(model_details.keys()), value="Quality-Model")],
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    inputs=[
+        gr.Dropdown(label="Choose Model", choices=list(MODELS.keys()), default="Quality-Model"),
+        gr.ChatBox(),
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
+    outputs="text",
+    theme='ehristoforu/RE_Theme',
+    examples=[
+        ["Quality-Model", "Hello there! How are you doing?", [], "Let's start the conversation.", 1024, 0.6, 0.9, 50, 1.2]
    ]
 )
 
+# Main execution block
 with gr.Blocks(css="style.css") as demo:
     gr.Markdown(DESCRIPTION)
     chat_interface.render()
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.queue(max_size=20).launch()
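
For readers who want to try the prompt-construction flow above outside the Space: the conversation list built in the updated generate() is the message format that transformers chat templates consume. The sketch below is illustrative only — the example prompts, sampling values, and the use of tokenizer.apply_chat_template are assumptions for demonstration, not part of this commit, and it presumes the model's tokenizer defines a chat template.

# Illustrative sketch (not part of the commit): build the same conversation structure
# as the updated generate() and tokenize it via the tokenizer's chat template.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Artples/L-MChat-Small"  # one of the ids listed in MODELS above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

system_prompt = "You are a helpful assistant."            # illustrative values
chat_history = [("Hello there!", "Hi! How can I help?")]  # illustrative values
message = "Explain the plot of Cinderella in a sentence."

conversation = []
if system_prompt:
    conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
    conversation.extend([{"role": "user", "content": user},
                         {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})

# Render the messages with the model's chat template (assumes the tokenizer ships one)
# and generate a reply with sampling settings similar to the app's defaults.
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))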