# robert
# Refactoring the model to be a function not a generator
# c88df11
# raw
# history blame
# No virus
# 5.25 kB
import json
import os
import random
from threading import Thread
import gradio as gr
import spaces
import torch
from langchain.schema import AIMessage, HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, SecretStr
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
)
# Load the fine-tuned preference-optimized chat model and its tokenizer once at
# module import time. device_map="auto" lets accelerate place layers across
# available devices; load_in_4bit quantizes weights to reduce GPU memory.
tokenizer = AutoTokenizer.from_pretrained("ContextualAI/archangel_sft-kto_llama13b")
model = AutoModelForCausalLM.from_pretrained(
    "ContextualAI/archangel_sft-kto_llama13b", device_map="auto", load_in_4bit=True
)
class OAAPIKey(BaseModel):
    """Pydantic model validating the user-supplied OpenAI API key."""

    # SecretStr keeps the key masked in reprs/logs until explicitly unwrapped.
    openai_api_key: SecretStr
def set_openai_api_key(api_key: SecretStr):
    """Expose the API key via the environment and build a GPT-3.5-turbo client.

    NOTE(review): the key is written into the process environment, so it
    persists for the lifetime of the process (and is visible to any code in
    it) — confirm this is acceptable in a shared deployment.
    """
    # ChatOpenAI reads OPENAI_API_KEY from the environment at construction.
    os.environ["OPENAI_API_KEY"] = api_key.get_secret_value()
    chat_client = ChatOpenAI(temperature=1.0, model="gpt-3.5-turbo-0125")
    return chat_client
class StopOnSequence(StoppingCriteria):
    """Stopping criterion that halts generation once a token sequence appears.

    The target string is tokenized once up front; at each generation step the
    tail of every batch row is compared against those token ids.
    """

    def __init__(self, sequence, tokenizer):
        # Encode without special tokens so the ids match mid-sequence output.
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_len = len(self.sequence_ids)

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Stop as soon as any row in the batch ends with the target ids.
        tail = self.sequence_len
        return any(
            row[-tail:].tolist() == self.sequence_ids for row in input_ids
        )
@spaces.GPU(duration=54)
def spaces_model_predict(message: str, history: list[tuple[str, str]]):
    """Generate a reply from the local model, streaming tokens off a worker thread.

    The chat history plus the new message are flattened into the model's
    ``<|user|>``/``<|assistant|>`` prompt format; generation is cut off when
    the model begins a new ``<|user|>`` turn.
    """
    turns = history + [[message, ""]]
    halt_on_user_turn = StopOnSequence("<|user|>", tokenizer)
    prompt = "".join(
        f"<|user|>\n{user_text}\n<|assistant|>\n{assistant_text}"
        for user_text, assistant_text in turns
    )
    encoded = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = dict(
        encoded,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([halt_on_user_turn]),
    )
    # Run generation in a background thread so this thread can drain the streamer.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()
    reply = ""
    for token_text in streamer:
        reply += token_text
        # The model marks the start of a new user turn with "<|user|>"; cut there.
        if "<|user|>" in reply:
            reply = reply.split("<|user|>")[0].strip()
            break
    return reply
def predict(
    message: str,
    chat_history_openai: list[tuple[str, str]],
    chat_history_spaces: list[tuple[str, str]],
    openai_api_key: SecretStr,
):
    """Send the message to both chatbots and append each reply to its history.

    Returns an empty string (to clear the input box) plus the two updated
    history lists, matching the Gradio output signature.
    """
    validated_key = OAAPIKey(openai_api_key=openai_api_key)
    openai_llm = set_openai_api_key(api_key=validated_key.openai_api_key)
    # --- OpenAI side: replay the history as LangChain message objects ---
    langchain_messages = []
    for user_text, assistant_text in chat_history_openai:
        langchain_messages.append(HumanMessage(content=user_text))
        langchain_messages.append(AIMessage(content=assistant_text))
    langchain_messages.append(HumanMessage(content=message))
    openai_response = openai_llm.invoke(input=langchain_messages)
    # --- Local Spaces model side ---
    spaces_model_response = spaces_model_predict(message, chat_history_spaces)
    chat_history_openai.append((message, openai_response.content))
    chat_history_spaces.append((message, spaces_model_response))
    return "", chat_history_openai, chat_history_spaces
# Load sample AskBaking prompts used to seed the message dropdown.
# Explicit UTF-8 avoids depending on the platform's default text encoding,
# which differs across OSes (e.g. cp1252 on Windows) and can corrupt emoji.
with open("askbakingtop.json", "r", encoding="utf-8") as file:
    ask_baking_msgs = json.load(file)
# --- Gradio UI: API-key box, prompt dropdown, and two side-by-side chatbots ---
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Masked input; the raw value is forwarded to predict() on submit.
            openai_api_key = gr.Textbox(
                label="Please enter your OpenAI API key",
                type="password",
                elem_id="lets-chat-openai-api-key",
            )
    with gr.Row():
        # Offer three random sample prompts while still allowing free text.
        options = [ask["history"] for ask in random.sample(ask_baking_msgs, k=3)]
        msg = gr.Dropdown(
            options,
            label="Please enter your message",
            interactive=True,
            multiselect=False,
            allow_custom_value=True,
        )
    with gr.Row():
        # Two chat panes rendered side by side for direct comparison.
        with gr.Column(scale=1):
            chatbot_openai = gr.Chatbot(label="OpenAI Chatbot 🏒")
        with gr.Column(scale=1):
            chatbot_spaces = gr.Chatbot(
                label="Your own fine-tuned preference optimized Chatbot πŸ’ͺ"
            )
    with gr.Row():
        submit_button = gr.Button("Submit")
    with gr.Row():
        # Clears only the message dropdown, not the chat histories.
        clear = gr.ClearButton([msg])
    def respond(
        message: str,
        chat_history_openai: list[tuple[str, str]],
        chat_history_spaces: list[tuple[str, str]],
        openai_api_key: SecretStr,
    ):
        # Thin wrapper mapping Gradio's positional inputs onto predict()'s
        # keyword arguments.
        return predict(
            message=message,
            chat_history_openai=chat_history_openai,
            chat_history_spaces=chat_history_spaces,
            openai_api_key=openai_api_key,
        )
    # Wire the submit button: inputs in the same order respond() expects;
    # outputs clear the dropdown and refresh both chat panes.
    submit_button.click(
        fn=respond,
        inputs=[
            msg,
            chatbot_openai,
            chatbot_spaces,
            openai_api_key,
        ],
        outputs=[msg, chatbot_openai, chatbot_spaces],
    )
demo.launch()