"""Gradio demo: compare an OpenAI chatbot against a locally hosted,
preference-optimized (KTO) Llama-13B model, side by side, on the same prompt."""

import json
import os
import random
from threading import Thread

import gradio as gr
import spaces
import torch
from langchain.schema import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, SecretStr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load tokenizer and model once at import time; 4-bit quantization keeps the
# 13B model within a single GPU's memory. NOTE(review): `load_in_4bit=True` is
# a deprecated shortcut in newer transformers — confirm against the pinned
# transformers version before upgrading.
tokenizer = AutoTokenizer.from_pretrained("ContextualAI/archangel_sft-kto_llama13b")
model = AutoModelForCausalLM.from_pretrained(
    "ContextualAI/archangel_sft-kto_llama13b", device_map="auto", load_in_4bit=True
)


class OAAPIKey(BaseModel):
    """Validated container for the user-supplied OpenAI API key."""

    # SecretStr keeps the key out of reprs/logs.
    openai_api_key: SecretStr


def set_openai_api_key(api_key: SecretStr) -> ChatOpenAI:
    """Export the key to the process environment and return a ChatOpenAI client.

    Args:
        api_key: The user's OpenAI key, wrapped in a SecretStr.

    Returns:
        A ChatOpenAI client configured for gpt-3.5-turbo-0125 at temperature 1.0.
    """
    os.environ["OPENAI_API_KEY"] = api_key.get_secret_value()
    llm = ChatOpenAI(temperature=1.0, model="gpt-3.5-turbo-0125")
    return llm


class StopOnSequence(StoppingCriteria):
    """Stop generation once a given string's token sequence ends any batch row."""

    def __init__(self, sequence, tokenizer):
        # Token ids of the stop string, without special tokens, so we can
        # compare against the tail of the generated ids directly.
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_len = len(self.sequence_ids)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # True when the last `sequence_len` tokens of any row match the stop ids.
        for i in range(input_ids.shape[0]):
            if input_ids[i, -self.sequence_len:].tolist() == self.sequence_ids:
                return True
        return False


@spaces.GPU(duration=54)
def spaces_model_predict(message: str, history: list[tuple[str, str]]):
    """Generate a reply from the local model, consuming a streaming worker thread.

    Args:
        message: The newest user message.
        history: Prior (user, assistant) turns.

    Returns:
        The assistant reply as a string, truncated at the first "<|user|>"
        marker if the model starts hallucinating the next user turn.
    """
    history_transformer_format = history + [[message, ""]]
    stop = StopOnSequence("<|user|>", tokenizer)

    # Flatten the conversation into the model's chat template. The final turn
    # has an empty assistant slot, which is the generation prompt.
    messages = "".join(
        [
            f"<|user|>\n{item[0]}\n<|assistant|>\n{item[1]}"
            for item in history_transformer_format
        ]
    )

    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # Run generate() in a background thread so this function can iterate the
    # streamer; the StopOnSequence criterion ends the thread's generation.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    generated_text = ""
    for new_token in streamer:
        generated_text += new_token
        if "<|user|>" in generated_text:
            # Drop the hallucinated next user turn.
            generated_text = generated_text.split("<|user|>")[0]
            break
    # Fix: strip on every exit path (the original stripped only when the
    # "<|user|>" marker appeared, leaving trailing whitespace otherwise).
    return generated_text.strip()


def predict(
    message: str,
    chat_history_openai: list[tuple[str, str]],
    chat_history_spaces: list[tuple[str, str]],
    openai_api_key: SecretStr,
):
    """Send the same message to both chatbots and append replies to their histories.

    Args:
        message: The user's new message.
        chat_history_openai: (user, assistant) turns for the OpenAI chatbot.
        chat_history_spaces: (user, assistant) turns for the local model.
        openai_api_key: The user's OpenAI API key.

    Returns:
        An empty string (clears the input box) plus the two updated histories.
    """
    openai_key_model = OAAPIKey(openai_api_key=openai_api_key)
    openai_llm = set_openai_api_key(api_key=openai_key_model.openai_api_key)

    # OpenAI: rebuild the conversation in LangChain message objects.
    history_langchain_format_openai = []
    for human, ai in chat_history_openai:
        history_langchain_format_openai.append(HumanMessage(content=human))
        history_langchain_format_openai.append(AIMessage(content=ai))
    history_langchain_format_openai.append(HumanMessage(content=message))
    openai_response = openai_llm.invoke(input=history_langchain_format_openai)

    # Local fine-tuned model.
    spaces_model_response = spaces_model_predict(message, chat_history_spaces)

    chat_history_openai.append((message, openai_response.content))
    chat_history_spaces.append((message, spaces_model_response))
    return "", chat_history_openai, chat_history_spaces


# Seed prompts for the dropdown, sampled from a baking Q&A dataset.
with open("askbakingtop.json", "r") as file:
    ask_baking_msgs = json.load(file)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            openai_api_key = gr.Textbox(
                label="Please enter your OpenAI API key",
                type="password",
                elem_id="lets-chat-openai-api-key",
            )
    with gr.Row():
        # Offer three random sample prompts, but allow free-form input too.
        options = [ask["history"] for ask in random.sample(ask_baking_msgs, k=3)]
        msg = gr.Dropdown(
            options,
            label="Please enter your message",
            interactive=True,
            multiselect=False,
            allow_custom_value=True,
        )
    with gr.Row():
        with gr.Column(scale=1):
            chatbot_openai = gr.Chatbot(label="OpenAI Chatbot 🏢")
        with gr.Column(scale=1):
            chatbot_spaces = gr.Chatbot(
                label="Your own fine-tuned preference optimized Chatbot 💪"
            )
    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        clear = gr.ClearButton([msg])

    def respond(
        message: str,
        chat_history_openai: list[tuple[str, str]],
        chat_history_spaces: list[tuple[str, str]],
        openai_api_key: SecretStr,
    ):
        """Thin event-handler wrapper delegating to predict()."""
        return predict(
            message=message,
            chat_history_openai=chat_history_openai,
            chat_history_spaces=chat_history_spaces,
            openai_api_key=openai_api_key,
        )

    submit_button.click(
        fn=respond,
        inputs=[
            msg,
            chatbot_openai,
            chatbot_spaces,
            openai_api_key,
        ],
        outputs=[msg, chatbot_openai, chatbot_spaces],
    )

demo.launch()