nisten committed on
Commit
5598c41
1 Parent(s): 247d769

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -3,11 +3,9 @@ import spaces
3
  import torch
4
  import subprocess
5
  import sys
6
- from threading import Thread
7
- from transformers import TextIteratorStreamer
8
 
9
  # Install required packages
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "transformers", "sentencepiece", "torch"])
11
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
 
13
  from transformers import OlmoeForCausalLM, AutoTokenizer
@@ -20,7 +18,7 @@ try:
20
  model = OlmoeForCausalLM.from_pretrained(
21
  model_name,
22
  trust_remote_code=True,
23
- torch_dtype=torch.bfloat16, # Using float16 for lower precision
24
  low_cpu_mem_usage=True,
25
  device_map="auto",
26
  _attn_implementation="flash_attention_2" # Enable Flash Attention 2
@@ -53,26 +51,32 @@ def generate_response(message, history, temperature, max_new_tokens):
53
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
54
 
55
  try:
56
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
57
  generation_kwargs = dict(
58
  inputs=inputs,
59
- streamer=streamer,
60
  max_new_tokens=max_new_tokens,
61
  do_sample=True,
62
  temperature=temperature,
63
  eos_token_id=tokenizer.eos_token_id,
 
64
  )
65
 
66
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
67
  thread.start()
68
 
69
- partial_message = ""
70
  for new_text in streamer:
71
- partial_message += new_text
72
- yield partial_message.strip()
73
-
 
 
 
 
 
 
74
  except Exception as e:
75
- yield f"An error occurred: {str(e)}"
76
 
77
  css = """
78
  #output {
@@ -83,12 +87,12 @@ css = """
83
  """
84
 
85
  with gr.Blocks(css=css) as demo:
86
- gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2 and Streaming!)")
87
  chatbot = gr.Chatbot(elem_id="output")
88
  msg = gr.Textbox(label="Meow")
89
  with gr.Row():
90
  temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
91
- max_new_tokens = gr.Slider(minimum=50, maximum=8000, value=2000, step=50, label="Max New Tokens")
92
  clear = gr.Button("Clear")
93
 
94
  def user(user_message, history):
 
3
  import torch
4
  import subprocess
5
  import sys
 
 
6
 
7
  # Install required packages
8
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "einops", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
  from transformers import OlmoeForCausalLM, AutoTokenizer
 
18
  model = OlmoeForCausalLM.from_pretrained(
19
  model_name,
20
  trust_remote_code=True,
21
+ torch_dtype=torch.float16, # Using float16 for lower precision
22
  low_cpu_mem_usage=True,
23
  device_map="auto",
24
  _attn_implementation="flash_attention_2" # Enable Flash Attention 2
 
51
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
52
 
53
  try:
54
+ streamer = gr.TextIteratorStreamer(tokenizer, skip_special_tokens=True)
55
  generation_kwargs = dict(
56
  inputs=inputs,
 
57
  max_new_tokens=max_new_tokens,
58
  do_sample=True,
59
  temperature=temperature,
60
  eos_token_id=tokenizer.eos_token_id,
61
+ streamer=streamer
62
  )
63
 
64
+ thread = torch.multiprocessing.Process(target=model.generate, kwargs=generation_kwargs)
65
  thread.start()
66
 
67
+ generated_text = ""
68
  for new_text in streamer:
69
+ generated_text += new_text
70
+ yield generated_text.strip()
71
+
72
+ thread.join()
73
+ except RuntimeError as e:
74
+ if "CUDA out of memory" in str(e):
75
+ yield "GPU memory exceeded. Try reducing the max tokens or using a smaller model."
76
+ else:
77
+ yield f"An error occurred: {str(e)}"
78
  except Exception as e:
79
+ yield f"An unexpected error occurred: {str(e)}"
80
 
81
  css = """
82
  #output {
 
87
  """
88
 
89
  with gr.Blocks(css=css) as demo:
90
+ gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2!)")
91
  chatbot = gr.Chatbot(elem_id="output")
92
  msg = gr.Textbox(label="Meow")
93
  with gr.Row():
94
  temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
95
+ max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens")
96
  clear = gr.Button("Clear")
97
 
98
  def user(user_message, history):