Spaces:
Paused
Paused
| """ | |
| This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface. | |
| """ | |
| from __future__ import annotations | |
| import inspect | |
| from typing import AsyncGenerator, Callable, Literal, Union, cast | |
| import anyio | |
| from gradio_client.documentation import document | |
| from gradio.blocks import Blocks | |
| from gradio.components import ( | |
| Button, | |
| Chatbot, | |
| Component, | |
| Markdown, | |
| MultimodalTextbox, | |
| State, | |
| Textbox, | |
| get_component_instance, | |
| Dataset | |
| ) | |
| from gradio.events import Dependency, on | |
| from gradio.helpers import special_args | |
| from gradio.layouts import Accordion, Group, Row | |
| from gradio.routes import Request | |
| from gradio.themes import ThemeClass as Theme | |
| from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda | |
class ChatInterface(Blocks):
    """
    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
    a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which
    takes a function that governs the response of the chatbot based on the user input and chat history. Additional
    parameters can be used to control the appearance and behavior of the demo.

    This variant additionally accepts optional `pre_fn` / `post_fn` hooks that are chained as extra
    event steps before and after each chat action (submit, retry, undo, clear).

    Example:
        import gradio as gr
        def echo(message, history):
            return message
        demo = gr.ChatInterface(fn=echo, title="Echo Bot")
        demo.launch()
    Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
    Guides: creating-a-chatbot-fast, sharing-your-app
    """

    def __init__(
        self,
        fn: Callable,
        post_fn: Callable | None = None,
        pre_fn: Callable | None = None,
        chatbot: Chatbot | None = None,
        *,
        post_fn_kwargs: dict | None = None,
        pre_fn_kwargs: dict | None = None,
        multimodal: bool = False,
        textbox: Textbox | MultimodalTextbox | None = None,
        additional_inputs: str | Component | list[str | Component] | None = None,
        additional_inputs_accordion_name: str | None = None,
        additional_inputs_accordion: str | Accordion | None = None,
        examples: Dataset | None = None,
        title: str | None = None,
        description: str | None = None,
        theme: Theme | str | None = None,
        css: str | None = None,
        js: str | None = None,
        head: str | None = None,
        analytics_enabled: bool | None = None,
        submit_btn: str | None | Button = "Submit",
        stop_btn: str | None | Button = "Stop",
        retry_btn: str | None | Button = "🔄 Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️ Clear",
        autofocus: bool = True,
        concurrency_limit: int | None | Literal["default"] = "default",
        fill_height: bool = True,
        delete_cache: tuple[int, int] | None = None,
    ):
        """
        Parameters:
            fn: the function that produces the chatbot response. It is called with the message, the chat history and the values of any additional inputs. It may be a plain function, a coroutine function, a generator or an async generator.
            post_fn: optional hook chained as an extra event step at the end of every chat action. Skipped when None.
            pre_fn: optional hook chained as an extra event step right after the first step of every chat action. Skipped when None.
            chatbot: a gr.Chatbot to render inside this interface. When None, a default Chatbot is created.
            post_fn_kwargs: keyword arguments unpacked into the event listener that registers `post_fn`. Defaults to no extra arguments.
            pre_fn_kwargs: keyword arguments unpacked into the event listener that registers `pre_fn`. Defaults to no extra arguments.
            multimodal: if True, a MultimodalTextbox is used (messages are dicts with "text" and "files" keys) and no submit button is shown.
            textbox: a gr.Textbox or gr.MultimodalTextbox to render as the message input. When None, a default one is created.
            additional_inputs: a component (or its string shortcut), or a list of them, passed to `fn` after the message and history. Components that are not yet rendered are rendered inside an accordion.
            additional_inputs_accordion_name: deprecated; label of the accordion holding the additional inputs. Only used when `additional_inputs_accordion` is None.
            additional_inputs_accordion: label (str) or a gr.Accordion whose configuration is reused for the accordion holding the additional inputs.
            examples: a gr.Dataset; when provided, clicking it writes the first field of the clicked sample into the textbox.
            title: heading rendered at the top of the interface and used as the page title ("Gradio" when not given).
            description: markdown rendered below the title.
            theme: forwarded to gr.Blocks.
            css: forwarded to gr.Blocks.
            js: forwarded to gr.Blocks.
            head: forwarded to gr.Blocks.
            analytics_enabled: forwarded to gr.Blocks.
            submit_btn: label or gr.Button for the submit button, or None for no button. Ignored when `multimodal` is True.
            stop_btn: label or gr.Button for the stop button, or None for no button. The button is only wired up when `fn` is a generator.
            retry_btn: label or gr.Button for the retry button, or None for no button. A gr.Button instance is used as-is and is not rendered here.
            undo_btn: label or gr.Button for the undo button, or None for no button. A gr.Button instance is used as-is and is not rendered here.
            clear_btn: label or gr.Button for the clear button, or None for no button. A gr.Button instance is used as-is and is not rendered here.
            autofocus: passed to the default textbox when `textbox` is None.
            concurrency_limit: passed to the response events, the `post_fn` events and the "chat" API endpoint.
            fill_height: forwarded to gr.Blocks; also selects the height of the default Chatbot.
            delete_cache: forwarded to gr.Blocks.
        """
        super().__init__(
            analytics_enabled=analytics_enabled,
            mode="chat_interface",
            css=css,
            title=title or "Gradio",
            theme=theme,
            js=js,
            head=head,
            fill_height=fill_height,
            delete_cache=delete_cache,
        )
        # Both kwargs dicts are later unpacked with `**`, so they must always be mappings.
        self.post_fn = post_fn
        self.post_fn_kwargs = {} if post_fn_kwargs is None else post_fn_kwargs
        self.pre_fn = pre_fn
        self.pre_fn_kwargs = {} if pre_fn_kwargs is None else pre_fn_kwargs
        self.multimodal = multimodal
        self.concurrency_limit = concurrency_limit
        self.fn = fn
        self.is_async = inspect.iscoroutinefunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.is_generator = inspect.isgeneratorfunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i)
                for i in additional_inputs  # type: ignore
            ]
        else:
            self.additional_inputs = []

        if additional_inputs_accordion_name is not None:
            print(
                "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
            )
        if additional_inputs_accordion is None:
            # The deprecated name is only honoured when the newer parameter is
            # absent; otherwise the default label and collapsed state apply.
            if additional_inputs_accordion_name is not None:
                self.additional_inputs_accordion_params = {
                    "label": additional_inputs_accordion_name
                }
            else:
                self.additional_inputs_accordion_params = {
                    "label": "Additional Inputs",
                    "open": False,
                }
        elif isinstance(additional_inputs_accordion, str):
            self.additional_inputs_accordion_params = {
                "label": additional_inputs_accordion
            }
        elif isinstance(additional_inputs_accordion, Accordion):
            self.additional_inputs_accordion_params = (
                additional_inputs_accordion.recover_kwargs(
                    additional_inputs_accordion.get_config()
                )
            )
        else:
            raise ValueError(
                f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
            )

        with self:
            if title:
                Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
                )
            if description:
                Markdown(description)

            if chatbot is not None:
                self.chatbot = chatbot.render()
            else:
                self.chatbot = Chatbot(
                    label="Chatbot", scale=1, height=200 if fill_height else None
                )

            # Labels given as strings (the defaults) must become real Buttons,
            # because the event wiring calls `.click` on them. Button instances
            # are kept untouched so callers stay in charge of where they render.
            action_buttons = [retry_btn, undo_btn, clear_btn]
            if any(isinstance(btn, str) for btn in action_buttons):
                with Row():
                    action_buttons = [
                        Button(btn, variant="secondary", size="sm")
                        if isinstance(btn, str)
                        else btn
                        for btn in action_buttons
                    ]
            self.buttons = action_buttons

            with Group():
                with Row():
                    if textbox:
                        if self.multimodal:
                            submit_btn = None
                        else:
                            textbox.container = False
                            textbox.show_label = False
                        textbox_ = textbox.render()
                        if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
                            raise TypeError(
                                f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
                            )
                        self.textbox = textbox_
                    elif self.multimodal:
                        submit_btn = None
                        self.textbox = MultimodalTextbox(
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    else:
                        self.textbox = Textbox(
                            container=False,
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )

                    if submit_btn is not None and not multimodal:
                        if isinstance(submit_btn, Button):
                            submit_btn.render()
                        elif isinstance(submit_btn, str):
                            submit_btn = Button(
                                submit_btn,
                                variant="primary",
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
                            )

                    if stop_btn is not None:
                        if isinstance(stop_btn, Button):
                            stop_btn.visible = False
                            stop_btn.render()
                        elif isinstance(stop_btn, str):
                            stop_btn = Button(
                                stop_btn,
                                variant="stop",
                                visible=False,
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
                            )
                    self.buttons.extend([submit_btn, stop_btn])  # type: ignore

                # Hidden components that back the programmatic "chat" endpoint.
                self.fake_api_btn = Button("Fake API", visible=False)
                self.fake_response_textbox = Textbox(label="Response", visible=False)
                (
                    self.retry_btn,
                    self.undo_btn,
                    self.clear_btn,
                    self.submit_btn,
                    self.stop_btn,
                ) = self.buttons

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(**self.additional_inputs_accordion_params):  # type: ignore
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            # `saved_input` holds the message taken out of the textbox;
            # `chatbot_state` mirrors the conversation shown in the chatbot.
            self.saved_input = State()
            self.chatbot_state = (
                State(self.chatbot.value) if self.chatbot.value else State([])
            )

            self._setup_events()
            self._setup_api()

            if examples:
                # NOTE(review): the Dataset is not rendered here, so it is
                # presumably rendered by the caller - confirm.
                examples.click(
                    lambda x: x[0],
                    inputs=[examples],
                    outputs=self.textbox,
                    show_progress=False,
                    queue=False,
                )

    def _then_pre_hook(self, event: Dependency) -> Dependency:
        """Chain `pre_fn` after `event`; return `event` unchanged when no hook is set."""
        if self.pre_fn is None:
            return event
        return event.then(
            self.pre_fn,
            **self.pre_fn_kwargs,
            show_api=False,
            queue=False,
        )

    def _then_post_hook(self, event: Dependency) -> Dependency:
        """Chain `post_fn` after `event`; return `event` unchanged when no hook is set."""
        if self.post_fn is None:
            return event
        return event.then(
            self.post_fn,
            **self.post_fn_kwargs,
            show_api=False,
            concurrency_limit=cast(
                Union[int, Literal["default"], None], self.concurrency_limit
            ),
        )

    def _chain_response(
        self, event: Dependency, submit_fn: Callable
    ) -> tuple[Dependency, Dependency]:
        """
        Chain pre-hook -> display input -> response function -> post-hook after `event`.

        Returns a pair (generation_event, final_event): the step that runs the
        response function, and the last step of the whole chain.
        """
        generation_event = (
            self._then_pre_hook(event)
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                [self.chatbot, self.chatbot_state],
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )
        )
        return generation_event, self._then_post_hook(generation_event)

    def _setup_events(self) -> None:
        """Wire the submit, retry, undo and clear actions to their event chains."""
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn

        submit_triggers = [self.textbox.submit]
        if self.submit_btn:
            submit_triggers.append(self.submit_btn.click)

        # Submit: empty the textbox and stash the message, then respond.
        saved_event = on(
            submit_triggers,
            self._clear_and_save_textbox,
            [self.textbox],
            [self.textbox, self.saved_input],
            show_api=False,
            queue=False,
        )
        generation_event, submit_event = self._chain_response(saved_event, submit_fn)
        self._setup_stop_events(
            submit_triggers, submit_event, extra_cancels=[generation_event]
        )

        if self.retry_btn:
            # Retry: drop the last exchange, then respond to the same message again.
            removed_event = self.retry_btn.click(
                self._delete_prev_fn,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            generation_event, retry_event = self._chain_response(
                removed_event, submit_fn
            )
            self._setup_stop_events(
                [self.retry_btn.click], retry_event, extra_cancels=[generation_event]
            )

        if self.undo_btn:
            # Undo: drop the last exchange and put its message back in the textbox.
            removed_event = self.undo_btn.click(
                self._delete_prev_fn,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            restored_event = self._then_pre_hook(removed_event).then(
                async_lambda(lambda x: x),
                [self.saved_input],
                [self.textbox],
                show_api=False,
                queue=False,
            )
            self._then_post_hook(restored_event)

        if self.clear_btn:
            # Clear: reset the chatbot, the history state and the saved message.
            cleared_event = self.clear_btn.click(
                async_lambda(lambda: ([], [], None)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                show_api=False,
            )
            self._then_post_hook(self._then_pre_hook(cleared_event))

    def _setup_stop_events(
        self,
        event_triggers: list[Callable],
        event_to_cancel: Dependency,
        extra_cancels: list[Dependency] | None = None,
    ) -> None:
        """
        Show the stop button while a streaming response runs and let it cancel the run.

        Parameters:
            event_triggers: listeners (e.g. `textbox.submit`) that start a response.
            event_to_cancel: last event of the chain; cancelled by the stop button and followed by the step that restores the buttons.
            extra_cancels: further events the stop button cancels, e.g. the step that runs the response function when it is not the last one.
        """
        if not (self.stop_btn and self.is_generator):
            return

        if self.submit_btn:
            # Swap submit for stop when a run starts, and back when it ends.
            for event_trigger in event_triggers:
                event_trigger(
                    async_lambda(
                        lambda: (
                            Button(visible=False),
                            Button(visible=True),
                        )
                    ),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            event_to_cancel.then(
                async_lambda(lambda: (Button(visible=True), Button(visible=False))),
                None,
                [self.submit_btn, self.stop_btn],
                show_api=False,
                queue=False,
            )
        else:
            for event_trigger in event_triggers:
                event_trigger(
                    async_lambda(lambda: Button(visible=True)),
                    None,
                    [self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            event_to_cancel.then(
                async_lambda(lambda: Button(visible=False)),
                None,
                [self.stop_btn],
                show_api=False,
                queue=False,
            )

        # Compare by identity so the same event is never listed twice.
        cancels = [event_to_cancel]
        for extra in extra_cancels or []:
            if all(extra is not known for known in cancels):
                cancels.append(extra)
        self.stop_btn.click(
            None,
            None,
            None,
            cancels=cancels,
            show_api=False,
        )

    def _setup_api(self) -> None:
        """Expose the response function as the "chat" API endpoint via the hidden button."""
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
            concurrency_limit=cast(
                Union[int, Literal["default"], None], self.concurrency_limit
            ),
        )

    def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
        """Return (empty textbox value, message) so the textbox is cleared and the message saved."""
        if self.multimodal:
            return {"text": "", "files": []}, message
        else:
            return "", message

    def _append_multimodal_history(
        self,
        message: dict[str, list],
        response: str | None,
        history: list[list[str | tuple | None]],
    ):
        """
        Append a multimodal message to `history` in place.

        Each file becomes its own row `[(file,), None]`. The text, when it is a
        string, becomes one more row paired with `response`; an empty text that
        accompanies files is stored as None.
        """
        for x in message["files"]:
            history.append([(x,), None])
        if message["text"] is None or not isinstance(message["text"], str):
            return
        elif message["text"] == "" and message["files"] != []:
            history.append([None, response])
        else:
            history.append([message["text"], response])

    def _history_without_input(
        self,
        message: str | dict[str, list],
        history: list[list[str | tuple | None]],
    ) -> list[list[str | tuple | None]]:
        """
        Return a copy of `history` without the trailing rows that belong to `message`.

        A multimodal message counts one row per file plus one row when its text
        is not None; any other message counts exactly one row.
        """
        if self.multimodal and isinstance(message, dict):
            rows = len(message["files"])
            if message["text"] is not None:
                rows += 1
        else:
            rows = 1
        # `history[:-0]` would be empty, so zero rows is handled explicitly.
        return history[:-rows] if rows else history[:]

    async def _display_input(
        self, message: str | dict[str, list], history: list[list[str | tuple | None]]
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        """Append the user message (without a response yet) to `history` and return it twice."""
        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, None, history)
        elif isinstance(message, str):
            history.append([message, None])
        return history, history

    async def _submit_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        """
        Run a non-streaming `fn` and return the updated history (for the chatbot and the state).

        `history_with_input` already ends with the rows added by `_display_input`;
        they are removed before calling `fn` and re-added together with the response.
        """
        history = self._history_without_input(message, history_with_input)
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, response, history)
        elif isinstance(message, str):
            history.append([message, response])
        return history, history

    async def _stream_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> AsyncGenerator:
        """
        Run a streaming `fn`, yielding the updated history after every chunk.

        Each yielded value is the pair (chatbot value, history state). When `fn`
        yields nothing, a single update with a None response is produced.
        """
        history = self._history_without_input(message, history_with_input)
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        is_multimodal_message = self.multimodal and isinstance(message, dict)
        try:
            first_response = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # `async_iteration` is defined elsewhere; both exhaustion signals are
            # caught because either one escaping an async generator turns into
            # a RuntimeError.
            if is_multimodal_message:
                self._append_multimodal_history(message, None, history)
                yield history, history
            else:
                update = history + [[message, None]]
                yield update, update
            return

        if is_multimodal_message:
            # File rows are added once; later updates only replace the text row.
            for x in message["files"]:
                history.append([(x,), None])
            update = history + [[message["text"], first_response]]
        else:
            update = history + [[message, first_response]]
        yield update, update

        async for response in generator:
            if is_multimodal_message:
                update = history + [[message["text"], response]]
            else:
                update = history + [[message, response]]
            yield update, update

    async def _api_submit_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> tuple[str, list[list[str | None]]]:
        """API variant of `_submit_fn`: return (response, history with the new exchange appended)."""
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
        history.append([message, response])
        return response, history

    async def _api_stream_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> AsyncGenerator:
        """API variant of `_stream_fn`: yield (response, history plus the new exchange) per chunk."""
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )

        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        try:
            first_response = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # See `_stream_fn`: an empty stream must not escape as an exception.
            yield None, history + [[message, None]]
            return
        yield first_response, history + [[message, first_response]]

        async for response in generator:
            yield response, history + [[message, response]]

    async def _delete_prev_fn(
        self,
        message: str | dict[str, list],
        history: list[list[str | tuple | None]],
    ) -> tuple[
        list[list[str | tuple | None]],
        str | dict[str, list],
        list[list[str | tuple | None]],
    ]:
        """
        Remove the rows of the last exchange from `history`.

        Returns (chatbot value, saved message, history state); a falsy message
        is returned as an empty string.
        """
        history = self._history_without_input(message, history)
        return history, message or "", history