"""
This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface.
"""
from __future__ import annotations
import inspect
from typing import AsyncGenerator, Callable, Literal, Union, cast
import anyio
from gradio_client.documentation import document
from gradio.blocks import Blocks
from gradio.components import (
    Button,
    Chatbot,
    Component,
    Markdown,
    MultimodalTextbox,
    State,
    Textbox,
    get_component_instance,
    Dataset,
)
from gradio.events import Dependency, on
from gradio.helpers import special_args
from gradio.layouts import Accordion, Group, Row
from gradio.routes import Request
from gradio.themes import ThemeClass as Theme
from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
@document()
class ChatInterface(Blocks):
    """
    ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create
    a web-based demo around a chatbot model in a few lines of code. Four parameters are required: fn, which
    takes a function that governs the response of the chatbot based on the user input and chat history; pre_fn
    and post_fn, which are callbacks chained into every chat event before and after the main step; and chatbot,
    the gr.Chatbot component to render. Additional parameters can be used to control the appearance and
    behavior of the demo.
    Example:
        import gradio as gr
        def echo(message, history):
            return message
        demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot")
        demo.launch()
    Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo
    Guides: creating-a-chatbot-fast, sharing-your-app
    """
    def __init__(
        self,
        fn: Callable,
        post_fn: Callable,
        pre_fn: Callable,
        chatbot: Chatbot,
        *,
        show_stop_button: bool = True,
        post_fn_kwargs: dict | None = None,
        pre_fn_kwargs: dict | None = None,
        multimodal: bool = False,
        textbox: Textbox | MultimodalTextbox | None = None,
        additional_inputs: str | Component | list[str | Component] | None = None,
        additional_inputs_accordion_name: str | None = None,
        additional_inputs_accordion: str | Accordion | None = None,
        examples: Dataset | None = None,
        title: str | None = None,
        description: str | None = None,
        theme: Theme | str | None = None,
        css: str | None = None,
        js: str | None = None,
        head: str | None = None,
        analytics_enabled: bool | None = None,
        submit_btn: str | None | Button = "Submit",
        stop_btn: str | None | Button = "Stop",
        retry_btn: str | None | Button = "🔄  Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️  Clear",
        autofocus: bool = True,
        concurrency_limit: int | None | Literal["default"] = "default",
        fill_height: bool = True,
        delete_cache: tuple[int, int] | None = None,
    ):
        """
        Parameters:
            fn: The chat function. It is called with the message, the history and the values of any additional inputs. It may be sync or async, and a plain function or a generator; generator functions are streamed.
            post_fn: Callable registered as the last step of the submit, retry, undo and clear event chains.
            pre_fn: Callable registered right after the first step of the submit, retry, undo and clear event chains.
            chatbot: The gr.Chatbot component to render inside this interface.
            show_stop_button: Whether the stop button is made visible while a generator `fn` is running.
            post_fn_kwargs: Extra keyword arguments forwarded to the `.then()` call that registers `post_fn`. Defaults to no extra arguments.
            pre_fn_kwargs: Extra keyword arguments forwarded to the `.then()` call that registers `pre_fn`. Defaults to no extra arguments.
            multimodal: If True, the default textbox is a gr.MultimodalTextbox and no separate submit button is created.
            textbox: A gr.Textbox or gr.MultimodalTextbox to render instead of the default one.
            additional_inputs: A component (or its string shortcut), or a list of them, whose values are passed to `fn` after the message and history. Inputs that are not already rendered are rendered inside an accordion.
            additional_inputs_accordion_name: Deprecated, use `additional_inputs_accordion`. Used as the accordion label when `additional_inputs_accordion` is not given.
            additional_inputs_accordion: Either the accordion label, or a gr.Accordion whose configuration is reused for the accordion that holds the additional inputs.
            examples: A gr.Dataset; clicking one of its samples copies the sample's first field into the textbox.
            title: Shown as a heading at the top of the interface and used as the Blocks title ("Gradio" if omitted).
            description: Markdown shown below the title.
            theme: Forwarded to the Blocks constructor.
            css: Forwarded to the Blocks constructor.
            js: Forwarded to the Blocks constructor.
            head: Forwarded to the Blocks constructor.
            analytics_enabled: Forwarded to the Blocks constructor.
            submit_btn: Label for a new submit button, an existing gr.Button to render, or None for no button.
            stop_btn: Label for a new stop button, an existing gr.Button to render (it is hidden initially), or None for no button.
            retry_btn: Label for a new retry button, an existing gr.Button, or None for no button.
            undo_btn: Label for a new undo button, an existing gr.Button, or None for no button.
            clear_btn: Label for a new clear button, an existing gr.Button, or None for no button.
            autofocus: Passed to the default textbox.
            concurrency_limit: Concurrency limit used for the events that run `fn` and `post_fn`.
            fill_height: Forwarded to the Blocks constructor.
            delete_cache: Forwarded to the Blocks constructor.
        Raises:
            ValueError: If a button argument or `additional_inputs_accordion` has an unsupported type.
            TypeError: If `textbox` does not render to a gr.Textbox or gr.MultimodalTextbox.
        """
        super().__init__(
            analytics_enabled=analytics_enabled,
            mode="chat_interface",
            css=css,
            title=title or "Gradio",
            theme=theme,
            js=js,
            head=head,
            fill_height=fill_height,
            delete_cache=delete_cache,
        )
        self.post_fn = post_fn
        self.pre_fn = pre_fn
        # Both kwargs objects are `**`-unpacked when the events are wired up, so
        # they must be mappings: a missing value becomes an empty dict (not a
        # list, and not None).
        self.post_fn_kwargs = {} if post_fn_kwargs is None else post_fn_kwargs
        self.pre_fn_kwargs = {} if pre_fn_kwargs is None else pre_fn_kwargs
        self.show_stop_button = show_stop_button
        self.interrupter = State(None)
        self.multimodal = multimodal
        self.concurrency_limit = concurrency_limit
        self.fn = fn
        self.is_async = inspect.iscoroutinefunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)
        self.is_generator = inspect.isgeneratorfunction(
            self.fn
        ) or inspect.isasyncgenfunction(self.fn)

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i)
                for i in additional_inputs  # type: ignore
            ]
        else:
            self.additional_inputs = []

        if additional_inputs_accordion_name is not None:
            print(
                "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
            )
        if additional_inputs_accordion is None:
            # The deprecated name is only honoured when the new parameter is
            # absent; previously it was unconditionally overwritten here.
            if additional_inputs_accordion_name is not None:
                self.additional_inputs_accordion_params = {
                    "label": additional_inputs_accordion_name
                }
            else:
                self.additional_inputs_accordion_params = {
                    "label": "Additional Inputs",
                    "open": False,
                }
        elif isinstance(additional_inputs_accordion, str):
            self.additional_inputs_accordion_params = {
                "label": additional_inputs_accordion
            }
        elif isinstance(additional_inputs_accordion, Accordion):
            self.additional_inputs_accordion_params = (
                additional_inputs_accordion.recover_kwargs(
                    additional_inputs_accordion.get_config()
                )
            )
        else:
            raise ValueError(
                f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}"
            )

        # Validate the action buttons up front so a bad argument fails with a
        # clear message instead of an AttributeError during event wiring.
        for btn_name, btn in (
            ("retry_btn", retry_btn),
            ("undo_btn", undo_btn),
            ("clear_btn", clear_btn),
        ):
            if btn is not None and not isinstance(btn, (str, Button)):
                raise ValueError(
                    f"The {btn_name} parameter must be a gr.Button, string, or None, not {type(btn)}"
                )

        def _prepare_action_button(btn: str | None | Button) -> Button | None:
            # Turn a label into a Button and render a Button that has not been
            # rendered yet; already-rendered Buttons and None pass through.
            if isinstance(btn, str):
                return Button(btn, variant="secondary")
            if isinstance(btn, Button) and not btn.is_rendered:
                btn.render()
            return btn

        with self:
            if title:
                # The heading literal was split across lines (a syntax error);
                # it is restored as a single-line centered <h1>.
                Markdown(
                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>"
                )
            if description:
                Markdown(description)
            self.chatbot = chatbot.render()

            with Group():
                with Row():
                    if textbox:
                        if self.multimodal:
                            submit_btn = None
                        else:
                            textbox.container = False
                        textbox.show_label = False
                        textbox_ = textbox.render()
                        if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
                            raise TypeError(
                                f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}"
                            )
                        self.textbox = textbox_
                    elif self.multimodal:
                        submit_btn = None
                        self.textbox = MultimodalTextbox(
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    else:
                        self.textbox = Textbox(
                            container=False,
                            show_label=False,
                            label="Message",
                            placeholder="Type a message...",
                            scale=7,
                            autofocus=autofocus,
                        )
                    if submit_btn is not None and not multimodal:
                        if isinstance(submit_btn, Button):
                            submit_btn.render()
                        elif isinstance(submit_btn, str):
                            submit_btn = Button(
                                submit_btn,
                                variant="primary",
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}"
                            )
                    if stop_btn is not None:
                        if isinstance(stop_btn, Button):
                            stop_btn.visible = False
                            stop_btn.render()
                        elif isinstance(stop_btn, str):
                            stop_btn = Button(
                                stop_btn,
                                variant="stop",
                                visible=False,
                                scale=1,
                                min_width=150,
                            )
                        else:
                            raise ValueError(
                                f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}"
                            )
                self.fake_api_btn = Button("Fake API", visible=False)
                self.fake_response_textbox = Textbox(label="Response", visible=False)

            # Only open a Row when at least one action button actually has to be
            # created or rendered here, so pre-rendered buttons add no layout.
            action_btns = (retry_btn, undo_btn, clear_btn)
            if any(
                isinstance(btn, str)
                or (isinstance(btn, Button) and not btn.is_rendered)
                for btn in action_btns
            ):
                with Row():
                    retry_btn, undo_btn, clear_btn = (
                        _prepare_action_button(btn) for btn in action_btns
                    )

            self.buttons = [retry_btn, undo_btn, clear_btn, submit_btn, stop_btn]  # type: ignore
            (
                self.retry_btn,
                self.undo_btn,
                self.clear_btn,
                self.submit_btn,
                self.stop_btn,
            ) = self.buttons

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(**self.additional_inputs_accordion_params):  # type: ignore
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            self.saved_input = State()
            self.chatbot_state = (
                State(self.chatbot.value) if self.chatbot.value else State([])
            )
            self._setup_events()
            self._setup_api()

        if examples:
            # Clicking a sample puts its first field into the textbox.
            examples.click(
                lambda x: x[0],
                inputs=[examples],
                outputs=self.textbox,
                show_progress=False,
                queue=False,
            )
    def _setup_events(self) -> None:
        """
        Wire the submit, retry, undo and clear event chains.

        Submit and retry share one pipeline: `pre_fn`, then `_display_input`,
        then the chat function (`_stream_fn` for generators, `_submit_fn`
        otherwise), then `post_fn`. Undo restores the saved message into the
        textbox and clear resets the chatbot, its state and the saved input;
        both also run `pre_fn` and `post_fn`. The resulting submit and retry
        events are handed to `_setup_stop_events` so they can be cancelled.
        """
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        limit = cast(Union[int, Literal["default"], None], self.concurrency_limit)

        # `or {}` tolerates a missing/empty kwargs object, and placing the
        # defaults first lets caller-supplied kwargs override them instead of
        # raising a duplicate-keyword TypeError.
        pre_kwargs = {
            "show_api": False,
            "queue": False,
            **(self.pre_fn_kwargs or {}),
        }
        post_kwargs = {
            "show_api": False,
            "concurrency_limit": limit,
            **(self.post_fn_kwargs or {}),
        }

        # The number of outputs must match what the chat function produces:
        # `_submit_fn` returns two values, and `_stream_fn` yields a third one
        # (the interrupter) only on its non-multimodal path.
        chat_outputs = [self.chatbot, self.chatbot_state]
        if self.is_generator and not self.multimodal:
            chat_outputs.append(self.interrupter)

        def then_pre(event: Dependency) -> Dependency:
            return event.then(self.pre_fn, **pre_kwargs)

        def then_post(event: Dependency) -> Dependency:
            return event.then(self.post_fn, **post_kwargs)

        def then_chat(event: Dependency) -> Dependency:
            # Shared tail of the submit and retry chains.
            shown = then_pre(event).then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            answered = shown.then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                list(chat_outputs),
                show_api=False,
                concurrency_limit=limit,
            )
            return then_post(answered)

        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        submit_event = then_chat(
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                show_api=False,
                queue=False,
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            retry_event = then_chat(
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            undone = then_pre(
                self.undo_btn.click(
                    self._delete_prev_fn,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
            )
            # Put the removed message back into the textbox for editing.
            restored = undone.then(
                async_lambda(lambda x: x),
                [self.saved_input],
                [self.textbox],
                show_api=False,
                queue=False,
            )
            then_post(restored)

        if self.clear_btn:
            cleared = self.clear_btn.click(
                async_lambda(lambda: ([], [], None)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                show_api=False,
            )
            then_post(then_pre(cleared))
    def _setup_stop_events(
        self, event_triggers: list[Callable], event_to_cancel: Dependency
    ) -> None:
        def perform_interrupt(ipc):
            if ipc is not None:
                ipc()
            return
        if self.stop_btn and self.is_generator:
            if self.submit_btn:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(
                            lambda: (
                                Button(visible=False),
                                Button(visible=self.show_stop_button),
                            )
                        ),
                        None,
                        [self.submit_btn, self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: (Button(visible=True), Button(visible=False))),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            else:
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(lambda: Button(visible=self.show_stop_button)),
                        None,
                        [self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: Button(visible=False)),
                    None,
                    [self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            self.stop_btn.click(
                fn=perform_interrupt,
                inputs=[self.interrupter],
                cancels=event_to_cancel,
                show_api=False,
            )
    def _setup_api(self) -> None:
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
            concurrency_limit=cast(
                Union[int, Literal["default"], None], self.concurrency_limit
            ),
        )
    def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]:
        if self.multimodal:
            return {"text": "", "files": []}, message
        else:
            return "", message
    def _append_multimodal_history(
        self,
        message: dict[str, list],
        response: str | None,
        history: list[list[str | tuple | None]],
    ):
        for x in message["files"]:
            history.append([(x,), None])
        if message["text"] is None or not isinstance(message["text"], str):
            return
        elif message["text"] == "" and message["files"] != []:
            history.append([None, response])
        else:
            history.append([message["text"], response])
    async def _display_input(
        self, message: str | dict[str, list], history: list[list[str | tuple | None]]
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        if self.multimodal and isinstance(message, dict):
            self._append_multimodal_history(message, None, history)
        elif isinstance(message, str):
            history.append([message, None])
        return history, history
    async def _submit_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
        """
        Run a non-generator `fn` for one message and return the new history.

        `history_with_input` already contains the rows that display the
        pending message; they are stripped before `fn` is called and re-added
        together with the response afterwards.

        Parameters:
            message: The saved user message (a dict with "text" and "files" in multimodal mode).
            history_with_input: The chat history including the pending message rows.
            request: The request, made available to `fn` through `special_args`.
            *args: Values of the additional inputs.
        Returns:
            The updated history, twice.
        """
        is_multimodal = self.multimodal and isinstance(message, dict)
        if is_multimodal:
            # One row per file, plus one row for the text when there is any.
            remove_input = len(message["files"]) + (
                1 if message["text"] is not None else 0
            )
        else:
            remove_input = 1
        # Slice by an explicit end index: `[:-remove_input]` would drop the
        # whole history when remove_input is 0.
        keep = max(0, len(history_with_input) - remove_input)
        history = history_with_input[:keep]

        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )
        if self.is_async:
            response = await self.fn(*inputs)
        else:
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        if is_multimodal:
            self._append_multimodal_history(message, response, history)
        elif isinstance(message, str):
            history.append([message, response])
        return history, history
    async def _stream_fn(
        self,
        message: str | dict[str, list],
        history_with_input: list[list[str | tuple | None]],
        request: Request,
        *args,
    ) -> AsyncGenerator:
        """
        Stream a generator `fn` for one message.

        Every item produced by `fn` is unpacked as a `(response, interrupter)`
        pair. For dict messages in multimodal mode each step yields
        `(history, history)`; otherwise each step yields
        `(history, history, interrupter)`. If `fn` produces nothing, a single
        step with a None response (and a None interrupter) is yielded.

        Parameters:
            message: The saved user message (a dict with "text" and "files" in multimodal mode).
            history_with_input: The chat history including the pending message rows.
            request: The request, made available to `fn` through `special_args`.
            *args: Values of the additional inputs.
        """
        is_multimodal = self.multimodal and isinstance(message, dict)
        if is_multimodal:
            # One row per file, plus one row for the text when there is any.
            remove_input = len(message["files"]) + (
                1 if message["text"] is not None else 0
            )
        else:
            remove_input = 1
        # Slice by an explicit end index: `[:-remove_input]` would drop the
        # whole history when remove_input is 0.
        keep = max(0, len(history_with_input) - remove_input)
        history = history_with_input[:keep]

        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )
        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        try:
            first_response, first_interrupter = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # `fn` produced nothing. An exhausted async iterator signals this
            # with StopAsyncIteration; StopIteration is kept for compatibility.
            if is_multimodal:
                self._append_multimodal_history(message, None, history)
                yield history, history
            else:
                update = history + [[message, None]]
                # There is no interrupter to report for an empty generator.
                yield update, update, None
            return

        if is_multimodal:
            for x in message["files"]:
                history.append([(x,), None])
            update = history + [[message["text"], first_response]]
            yield update, update
        else:
            update = history + [[message, first_response]]
            yield update, update, first_interrupter

        async for response, interrupter in generator:
            if is_multimodal:
                update = history + [[message["text"], response]]
                yield update, update
            else:
                update = history + [[message, response]]
                yield update, update, interrupter
    async def _api_submit_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> tuple[str, list[list[str | None]]]:
        """
        API counterpart of the chat step for a non-generator `fn`.

        Calls `fn` with the message, the history and the additional inputs,
        appends `[message, response]` to `history` in place and returns the
        response together with that history.
        """
        call_args, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )
        if not self.is_async:
            # Sync functions run in a worker thread.
            reply = await anyio.to_thread.run_sync(
                self.fn, *call_args, limiter=self.limiter
            )
        else:
            reply = await self.fn(*call_args)
        exchange = [message, reply]
        history.append(exchange)
        return reply, history
    async def _api_stream_fn(
        self, message: str, history: list[list[str | None]], request: Request, *args
    ) -> AsyncGenerator:
        """
        API counterpart of the chat step for a generator `fn`.

        Yields `(response, history + [[message, response]])` for every item
        `fn` produces, or a single `(None, history + [[message, None]])` if it
        produces nothing. `history` itself is not modified.

        NOTE(review): items are unpacked as `(response, interrupter)` pairs to
        agree with `_stream_fn`, which requires that shape from the same `fn`;
        the interrupter is not exposed through the API. Confirm this is the
        intended contract for API-only callers.
        """
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *args], request=request
        )
        if self.is_async:
            generator = self.fn(*inputs)
        else:
            generator = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )
            generator = SyncToAsyncIterator(generator, self.limiter)

        try:
            first_response, _ = await async_iteration(generator)
        except (StopIteration, StopAsyncIteration):
            # `fn` produced nothing. An exhausted async iterator signals this
            # with StopAsyncIteration; StopIteration is kept for compatibility.
            yield None, history + [[message, None]]
            return

        yield first_response, history + [[message, first_response]]
        async for response, _ in generator:
            yield response, history + [[message, response]]
    async def _delete_prev_fn(
        self,
        message: str | dict[str, list],
        history: list[list[str | tuple | None]],
    ) -> tuple[
        list[list[str | tuple | None]],
        str | dict[str, list],
        list[list[str | tuple | None]],
    ]:
        if self.multimodal and isinstance(message, dict):
            remove_input = (
                len(message["files"]) + 1
                if message["text"] is not None
                else len(message["files"])
            )
            history = history[:-remove_input]
        else:
            while history:
                deleted_a, deleted_b = history[-1]
                history = history[:-1]
                if isinstance(deleted_a, str) and isinstance(deleted_b, str):
                    break
        return history, message or "", history