Commit da09cca · Parent(s): 05fd483
feat: granite 3.1 with model selection
Signed-off-by: Graham White <[email protected]>
Files changed:
- pyproject.toml +7 -3
- src/app.css +14 -0
- src/app.py +65 -26
- src/app_head.html +4 -0
pyproject.toml CHANGED

```diff
@@ -1,8 +1,12 @@
 [tool.poetry]
-name = "…
+name = "granite-3.1-8b-instruct"
 version = "0.1.0"
-description = "A …
-authors = […
+description = "A demo of the IBM Granite 3.1 8b instruct model"
+authors = [
+    "James Sutton <[email protected]>",
+    "Graham White <[email protected]>",
+    "Michael Desmond <[email protected]>",
+]
 license = "Apache-2.0"
 readme = "README.md"
 package-mode = false
```
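A metadata-only change: the project gets a proper name, a description, and three listed authors. Since `package-mode = false` is already set (unchanged here), Poetry uses this pyproject.toml purely as a dependency manifest, so none of these fields affect packaging.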
    	
src/app.css CHANGED

```diff
@@ -1,3 +1,17 @@
 footer {
     display: none !important;
 }
+.gr_docs_link {
+    float: right;
+    font-size: var(--text-xs);
+    margin-top: -8px;
+}
+.gr_title {
+    display: flex;
+    align-items: center;
+}
+.gr_title img {
+    max-height: 40px;
+    margin-right: 1rem;
+    margin-bottom: -10px;
+}
```
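These selectors pair with the `elem_classes` values added in src/app.py below: `.gr_title` lays the Granite logo image and the `<h1>` heading out on one row, and `.gr_docs_link` floats the documentation link to the right at a reduced size.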
    	
src/app.py CHANGED

```diff
@@ -14,25 +14,28 @@ from themes.carbon import carbon_theme
 
 today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
 
-MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"
 SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
 Today's Date: {today_date}.
 You are Granite, developed by IBM. You are a helpful AI assistant"""
 TITLE = "IBM Granite 3.1 8b Instruct"
 DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, just like developers, \
                AI models can make mistakes."
-MAX_INPUT_TOKEN_LENGTH = …
+MAX_INPUT_TOKEN_LENGTH = 128_000
 MAX_NEW_TOKENS = 1024
 TEMPERATURE = 0.7
 TOP_P = 0.85
 TOP_K = 50
 REPETITION_PENALTY = 1.05
 
+model_list = ["granite-3.1-8b-instruct", "granite-3.1-2b-instruct"]
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\nThis demo does not work on CPU."
 
-model = AutoModelForCausalLM.from_pretrained(…
-…
+model = AutoModelForCausalLM.from_pretrained(
+    "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
 tokenizer.use_default_system_prompt = False
 
 
@@ -46,11 +49,13 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
     conversation.append({"role": "user", "content": message})
 
     # Convert messages to prompt format
-    input_ids = tokenizer.apply_chat_template(…
-…
-…
-        …
-        …
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt",
+        add_generation_prompt=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKEN_LENGTH,
+    )
 
     input_ids = input_ids.to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
@@ -75,28 +80,62 @@ def generate(message: str, chat_history: list[dict]) -> Iterator[str]:
         yield "".join(outputs)
 
 
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    stop_btn=None,
-    examples=[
-        ["Explain quantum computing"],
-        ["What is OpenShift?"],
-        ["Importance of low latency inference"],
-        ["Boosting productivity habits"],
-    ],
-    cache_examples=False,
-    type="messages",
-)
-
 css_file_path = Path(Path(__file__).parent / "app.css")
 head_file_path = Path(Path(__file__).parent / "app_head.html")
 
+
+def on_model_dropdown_change(model_name: str) -> list:
+    """Event handler for dropdown."""
+    global model
+    global tokenizer
+
+    model = AutoModelForCausalLM.from_pretrained(
+        f"ibm-granite/{model_name}", torch_dtype=torch.float16, device_map="auto"
+    )
+    tokenizer = AutoTokenizer.from_pretrained(f"ibm-granite/{model_name}")
+    tokenizer.use_default_system_prompt = False
+
+    # clear the chat interface when the model dropdown is changed
+    # works around https://github.com/gradio-app/gradio/issues/10343
+    return [None, []]
+
+
 with gr.Blocks(
     fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=carbon_theme, title=TITLE
 ) as demo:
-    gr.…
-…
-…
+    gr.HTML(
+        f"<img src='https://www.ibm.com/granite/docs/images/granite-cubes-352x368.webp'/><h1>{TITLE}</h1>",
+        elem_classes=["gr_title"],
+    )
+    gr.HTML(DESCRIPTION)
+    model_dropdown = gr.Dropdown(
+        choices=model_list,
+        value="granite-3.1-8b-instruct",
+        interactive=True,
+        label="Model",
+        filterable=False,
+    )
+    gr.HTML(
+        value='<a href="https://www.ibm.com/granite/docs/">View Documentation</a> <i class="fa fa-external-link"></i>',
+        elem_classes=["gr_docs_link"],
+    )
+    chat_interface = gr.ChatInterface(
+        fn=generate,
+        examples=[
+            ["Explain quantum computing"],
+            ["What is OpenShift?"],
+            ["Importance of low latency inference"],
+            ["Boosting productivity habits"],
+        ],
+        cache_examples=False,
+        type="messages",
+    )
+
+    model_dropdown.change(
+        fn=on_model_dropdown_change,
+        inputs=model_dropdown,
+        outputs=[chat_interface.chatbot, chat_interface.chatbot_state],
+    )
 
 if __name__ == "__main__":
-    demo.queue(…
+    demo.queue().launch()
```
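The key mechanic here is the dropdown handler rebinding the module-level `model` and `tokenizer`, then returning `[None, []]` into `chat_interface.chatbot` and `chat_interface.chatbot_state`, which wipes both the visible chat and its backing history so a conversation started on one model can't silently continue on the other. Below is a minimal sketch of that clearing pattern with the model reload stubbed out; `echo` is a hypothetical stand-in for `generate`, and it assumes a Gradio version that exposes `ChatInterface.chatbot_state`, as this commit does.

```python
import gradio as gr

MODELS = ["granite-3.1-8b-instruct", "granite-3.1-2b-instruct"]


def on_model_change(model_name: str) -> list:
    # In app.py this is where AutoModelForCausalLM / AutoTokenizer
    # reload f"ibm-granite/{model_name}"; stubbed out in this sketch.
    print(f"switched to {model_name}")
    # [None, []] clears the Chatbot component and its state — the same
    # workaround for gradio-app/gradio#10343 used by the commit.
    return [None, []]


def echo(message: str, history: list[dict]) -> str:
    # hypothetical stand-in for the streaming generate() in app.py
    return f"echo: {message}"


with gr.Blocks() as demo:
    dropdown = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Model")
    chat = gr.ChatInterface(fn=echo, type="messages")
    # Route the dropdown change into both the Chatbot and its state
    # so the stored history is reset along with the display.
    dropdown.change(
        fn=on_model_change,
        inputs=dropdown,
        outputs=[chat.chatbot, chat.chatbot_state],
    )

if __name__ == "__main__":
    demo.launch()
```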
    	
src/app_head.html CHANGED

```diff
@@ -1,3 +1,7 @@
+<link
+  rel="stylesheet"
+  href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css"
+/>
 <script
   async
   src="https://www.googletagmanager.com/gtag/js?id=G-C6LFT227RC"
```
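The new `<link>` loads Font Awesome 4.7, which supplies the `fa fa-external-link` icon markup used by the documentation link that src/app.py renders via `gr.HTML`; the existing Google Tag Manager snippet is left as-is.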