grahamwhiteuk committed
Commit da09cca · Parent: 05fd483

feat: granite 3.1 with model selection


Signed-off-by: Graham White <[email protected]>

Files changed (4):
  1. pyproject.toml +7 -3
  2. src/app.css +14 -0
  3. src/app.py +65 -26
  4. src/app_head.html +4 -0
pyproject.toml CHANGED
@@ -1,8 +1,12 @@
 [tool.poetry]
-name = "huggingface-gradio-template"
+name = "granite-3.1-8b-instruct"
 version = "0.1.0"
-description = "A boilerplate template for an IBM Granite Huggingface Spaces Gradio Demo"
-authors = ["James Sutton <[email protected]>"]
+description = "A demo of the IBM Granite 3.1 8b instruct model"
+authors = [
+    "James Sutton <[email protected]>",
+    "Graham White <[email protected]>",
+    "Michael Desmond <[email protected]>",
+]
 license = "Apache-2.0"
 readme = "README.md"
 package-mode = false
src/app.css CHANGED
@@ -1,3 +1,17 @@
 footer {
   display: none !important;
 }
+.gr_docs_link {
+  float: right;
+  font-size: var(--text-xs);
+  margin-top: -8px;
+}
+.gr_title {
+  display: flex;
+  align-items: center;
+}
+.gr_title img {
+  max-height: 40px;
+  margin-right: 1rem;
+  margin-bottom: -10px;
+}
src/app.py CHANGED
@@ -14,25 +14,28 @@ from themes.carbon import carbon_theme
 
 today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
 
-MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"
 SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
 Today's Date: {today_date}.
 You are Granite, developed by IBM. You are a helpful AI assistant"""
 TITLE = "IBM Granite 3.1 8b Instruct"
 DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, just like developers, \
 AI models can make mistakes."
-MAX_INPUT_TOKEN_LENGTH = 4096
+MAX_INPUT_TOKEN_LENGTH = 128_000
 MAX_NEW_TOKENS = 1024
 TEMPERATURE = 0.7
 TOP_P = 0.85
 TOP_K = 50
 REPETITION_PENALTY = 1.05
 
+model_list = ["granite-3.1-8b-instruct", "granite-3.1-2b-instruct"]
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\nThis demo does not work on CPU."
 
-model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+model = AutoModelForCausalLM.from_pretrained(
+    "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
 tokenizer.use_default_system_prompt = False
 
 
@@ -46,11 +49,13 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
     conversation.append({"role": "user", "content": message})
 
     # Convert messages to prompt format
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
-
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt",
+        add_generation_prompt=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH,
+    )
 
     input_ids = input_ids.to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
@@ -75,28 +80,62 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
     yield "".join(outputs)
 
 
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    stop_btn=None,
-    examples=[
-        ["Explain quantum computing"],
-        ["What is OpenShift?"],
-        ["Importance of low latency inference"],
-        ["Boosting productivity habits"],
-    ],
-    cache_examples=False,
-    type="messages",
-)
-
 css_file_path = Path(Path(__file__).parent / "app.css")
 head_file_path = Path(Path(__file__).parent / "app_head.html")
 
+
+def on_model_dropdown_change(model_name: str) -> list:
+    """Event handler for dropdown."""
+    global model
+    global tokenizer
+
+    model = AutoModelForCausalLM.from_pretrained(
+        f"ibm-granite/{model_name}", torch_dtype=torch.float16, device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(f"ibm-granite/{model_name}")
+    tokenizer.use_default_system_prompt = False
+
+    # clear the chat interface when the model dropdown is changed
+    # works around https://github.com/gradio-app/gradio/issues/10343
+    return [None, []]
+
+
 with gr.Blocks(
     fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
 ) as demo:
-    gr.Markdown(f"# {TITLE}")
-    gr.Markdown(DESCRIPTION)
-    chat_interface.render()
+    gr.HTML(
+        f"<img src='https://www.ibm.com/granite/docs/images/granite-cubes-352x368.webp'/><h1>{TITLE}</h1>",
+        elem_classes=["gr_title"],
+    )
+    gr.HTML(DESCRIPTION)
+    model_dropdown = gr.Dropdown(
+        choices=model_list,
+        value="granite-3.1-8b-instruct",
+        interactive=True,
+        label="Model",
+        filterable=False,
+    )
+    gr.HTML(
+        value='<a href="https://www.ibm.com/granite/docs/">View Documentation</a> <i class="fa fa-external-link"></i>',
+        elem_classes=["gr_docs_link"],
+    )
+    chat_interface = gr.ChatInterface(
+        fn=generate,
+        examples=[
+            ["Explain quantum computing"],
+            ["What is OpenShift?"],
+            ["Importance of low latency inference"],
+            ["Boosting productivity habits"],
+        ],
+        cache_examples=False,
+        type="messages",
+    )
+
+    model_dropdown.change(
+        fn=on_model_dropdown_change,
+        inputs=model_dropdown,
+        outputs=[chat_interface.chatbot, chat_interface.chatbot_state],
+    )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue().launch()
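Two notes on the src/app.py changes above (reviewer commentary, not part of the commit):

First, apply_chat_template(..., truncation=True, max_length=...) hands trimming to the tokenizer, which clips from the side named by tokenizer.truncation_side (commonly "right", i.e. the newest tokens), whereas the deleted code kept the most recent MAX_INPUT_TOKEN_LENGTH tokens. At a 128,000-token limit this should rarely trigger, but a minimal sketch of preserving the old keep-the-most-recent behavior, assuming the surrounding app.py scope:

    tokenizer.truncation_side = "left"  # drop the oldest tokens first, as the old code did
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
        truncation=True,
        max_length=MAX_INPUT_TOKEN_LENGTH,
    )

Second, on_model_dropdown_change calls from_pretrained on every switch, so each dropdown change blocks while the checkpoint reloads. A sketch of one alternative, memoizing both variants, assuming the Space has enough GPU memory to hold the 8b and 2b models at once (the helper name load_granite is hypothetical):

    from functools import lru_cache

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer


    @lru_cache(maxsize=2)  # one cache slot per Granite 3.1 variant
    def load_granite(model_name: str) -> tuple:
        """Load a Granite checkpoint once and reuse it on later switches."""
        model = AutoModelForCausalLM.from_pretrained(
            f"ibm-granite/{model_name}", torch_dtype=torch.float16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(f"ibm-granite/{model_name}")
        tokenizer.use_default_system_prompt = False
        return model, tokenizer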
src/app_head.html CHANGED
@@ -1,3 +1,7 @@
+<link
+  rel="stylesheet"
+  href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css"
+/>
 <script
   async
   src="https://www.googletagmanager.com/gtag/js?id=G-C6LFT227RC"