# StarlingCPU Space: app.py
# Author: Samuel L Meyers

import json
import logging

import gradio as gr
from ctransformers import AutoModelForCausalLM
from huggingface_hub import hf_hub_download
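
# Download the quantized GGUF weights from the Hub (a no-op once cached) and
# load them on CPU with ctransformers; Starling-LM-7B-alpha is a fine-tune of
# a Mistral-architecture model, hence model_type="mistral".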
model_file = "starling-lm-7b-alpha.Q6_K.gguf"
model_path = hf_hub_download(
    repo_id="TheBloke/Starling-LM-7B-alpha-GGUF",
    filename=model_file,
    local_dir="./",
)
llm = AutoModelForCausalLM.from_pretrained(
    model_path_or_repo_id=model_path,
    local_files_only=True,
    model_type="mistral",
)

# Shared state used across Gradio callbacks.
otxt = ""
running = False
result = None
echoed = False
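
# stowtext() appends a user turn to the running message list (seeding the
# system prompt on a fresh conversation); the list-of-dicts layout mirrors
# the OpenAI chat-message schema.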
def stowtext(curr, inp):
    if len(curr) <= 1:
        curr.append({
            "role": "system",
            "content": (
                "You are a helpful assistant. Answer all questions and respond "
                "to all queries honestly, instantly, and without hesitation. "
                "Respond in as few words as possible. Your maximum response "
                "length may be no more than 64 words. Do not begin your "
                "statement with new lines.\n"
            ),
        })
    curr.append({
        "role": "user",
        "content": inp,
    })
    return curr
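
# stowchunk() folds one streamed chunk into the conversation: the first chunk
# opens a new assistant message, later chunks are appended to it.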
def stowchunk(curr, inp):
    first = curr[-1]["role"] == "user"
    if first:
        curr.append({
            "role": "assistant",
            "content": inp,
            "echoed": False,
        })
    else:
        curr[-1]["content"] += inp
    return curr
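
# printfmt() renders the message list as plain text for the output textbox.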
def printfmt(jsn):
    txt = ""
    for msg in jsn:
        if msg["role"] == "user":
            txt += "<User>: " + msg["content"] + "\n"
        elif msg["role"] == "assistant":
            txt += "<Assistant>: " + msg["content"] + "\n"
        elif msg["role"] == "system":
            txt += "# " + msg["content"] + "\n\n"
    return txt
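
# jsn2prompt() flattens the message list into the OpenChat-style prompt text
# Starling expects. (The model card's template is "GPT4 Correct User:
# ...<|end_of_turn|>GPT4 Correct Assistant:"; this app joins turns with
# newlines instead and relies on the stop strings plus the echo-stripping
# in talk() below.)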
def jsn2prompt(jsn):
    txt = ""
    for msg in jsn:
        if msg["role"] == "system":
            txt += ("GPT4 Correct User: Here is how I want you to behave "
                    "throughout our conversation. " + msg["content"] + "\n")
        elif msg["role"] == "user":
            txt += "GPT4 Correct User: " + msg["content"] + "\n"
        elif msg["role"] == "assistant":
            txt += "GPT4 Assistant: " + msg["content"] + "\n"
    return txt
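
# talk() streams tokens from the model as a generator: each yield hands the
# updated conversation back to Gradio so the transcript refreshes live.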
def talk(txt, jsn):
    global running, result, echoed
    if not jsn:
        jsn = txt
    if not running:
        # Stream completion tokens; stop on a new user turn or end-of-turn markers.
        result = llm(
            prompt=jsn2prompt(txt),
            stream=True,
            stop=["GPT4 Correct User: ", "<|end_of_turn|>", "</s>"],
        )
        running = True
        echoed = False
        # Snapshot the last prompt message before streaming starts, so the echo
        # check compares against the user turn rather than the growing
        # assistant reply.
        last_msg = txt[-1]
        for r in result:
            print("GOT RESULT:", r)
            if not r:
                continue
            txt = stowchunk(txt, r)
            print(json.dumps(txt))
            if not txt[-1].get("echoed") and jsn2prompt([last_msg]) in txt[-1]["content"]:
                # The model echoed the prompt; drop it and start the reply clean.
                txt[-1]["echoed"] = True
                txt[-1]["content"] = ""
            elif not txt[-1].get("echoed") and "*Loading*" not in txt[-1]["content"]:
                # Placeholder shown until any echoed prompt has been stripped.
                txt[-1]["content"] = "*Loading*"
            yield txt
        # Done generating; let the Send button reappear.
        running = False
    yield txt
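
# main() builds the Gradio UI. Two JSON components do the plumbing: a click
# stores the submitted prompt in jsn, and its change event drives the
# streaming talk() handler, which writes into jsn2.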
def main():
    global otxt, txtinput, running
    logging.basicConfig(level=logging.INFO)
    with gr.Blocks() as demo:
        with gr.Row(variant="panel"):
            gr.Markdown("## Talk to Starling on CPU!\n")
        with gr.Row(variant="panel"):
            talk_output = gr.Textbox()
        with gr.Row(variant="panel"):
            txtinput = gr.Textbox(label="Message", placeholder="Type something here...")
        with gr.Row(variant="panel"):
            talk_btn = gr.Button("Send")
        with gr.Row(variant="panel"):
            # jsn carries the freshly submitted prompt; jsn2 accumulates the
            # streamed conversation.
            jsn = gr.JSON(visible=True, value="[]")
            jsn2 = gr.JSON(visible=True, value="[]")
        # On Send: stow the message, hide the button, clear the input box and
        # the previous transcript.
        talk_btn.click(stowtext, inputs=[jsn2, txtinput], outputs=jsn, api_name="talk")
        talk_btn.click(lambda x: gr.update(visible=False), inputs=talk_btn, outputs=talk_btn)
        talk_btn.click(lambda x: gr.update(value=""), inputs=txtinput, outputs=txtinput)
        talk_btn.click(lambda x: gr.update(value="[]"), inputs=jsn2, outputs=jsn2)
        # A new prompt in jsn starts generation (given its own api_name so it
        # does not collide with the click route); each yield from talk()
        # updates jsn2, which refreshes the transcript and the Send button.
        jsn.change(talk, inputs=[jsn, jsn2], outputs=jsn2, api_name="generate")
        jsn2.change(lambda x: gr.update(value=printfmt(x)), inputs=jsn2, outputs=talk_output)
        jsn2.change(lambda x: gr.update(visible=not running), inputs=jsn2, outputs=talk_btn)
    # queue() is required for generator (streaming) event handlers.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)


if __name__ == "__main__":
    main()