Add Joi

app.ipynb CHANGED
@@ -31,7 +31,23 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 32,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# |export\n",
+     "def get_model_endpoint(model_id):\n",
+     "    if \"joi\" in model_id:\n",
+     "        headers = None\n",
+     "        return \"https://joi-20b.ngrok.io/generate\", headers\n",
+     "    else:\n",
+     "        headers = {\"Authorization\": f\"Bearer {HF_TOKEN}\", \"x-wait-for-model\": \"1\"}\n",
+     "        return f\"https://api-inference.huggingface.co/models/{model_id}\", headers\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 33,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -42,8 +58,7 @@
     "    temperature,\n",
     "    top_p\n",
     "):\n",
-    "
-    "    headers = {\"Authorization\": f\"Bearer {HF_TOKEN}\", \"x-wait-for-model\": \"1\"}\n",
+    "    endpoint, headers = get_model_endpoint(model_id)\n",
     "\n",
     "    payload = {\n",
     "        \"inputs\": inputs,\n",
@@ -55,7 +70,7 @@
     "        },\n",
     "    }\n",
     "\n",
-    "    response = requests.post(
+    "    response = requests.post(endpoint, json=payload, headers=headers)\n",
     "\n",
     "    if response.status_code == 200:\n",
     "        return response.json()\n",
@@ -65,23 +80,24 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 36,
     "metadata": {},
     "outputs": [
      {
       "data": {
        "text/plain": [
-        "
+        "{'generated_text': '\\n\\nJoi: Black holes are regions of space-time where gravity is so strong that nothing'}"
        ]
       },
-      "execution_count":
+      "execution_count": 36,
       "metadata": {},
       "output_type": "execute_result"
      }
     ],
     "source": [
-     "model_id = \"google/flan-t5-xl\"\n",
-     "
+     "# model_id = \"google/flan-t5-xl\"\n",
+     "model_id = \"Rallio67/joi_20B_instruct_alpha\"\n",
+     "query = \"What can you tell me about black holes?\"\n",
      "query_chat_api(model_id, query, 1, 0.95)"
     ]
    },
@@ -101,7 +117,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 37,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -121,7 +137,10 @@
     "    inputs = prompt_template[\"prompt\"].format(human_input=text_input)\n",
     "\n",
     "    output = query_chat_api(model_id, inputs, temperature, top_p)\n",
-    "
+    "    # TODO: remove this hack when inference backend schema is updated\n",
+    "    if isinstance(output, list):\n",
+    "        output = output[0]\n",
+    "    history.append(\" \" + output[\"generated_text\"])\n",
     "\n",
     "    chat = [\n",
     "        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)\n",
@@ -695,14 +714,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 38,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-        "Running on local URL: http://127.0.0.1:
+        "Running on local URL: http://127.0.0.1:7861\n",
        "\n",
        "To create a public link, set `share=True` in `launch()`.\n"
       ]
@@ -710,7 +729,7 @@
      {
       "data": {
        "text/html": [
-        "<div><iframe src=\"http://127.0.0.1:
+        "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
        ],
        "text/plain": [
         "<IPython.core.display.HTML object>"
@@ -723,7 +742,7 @@
       "data": {
        "text/plain": []
       },
-      "execution_count":
+      "execution_count": 38,
      "metadata": {},
       "output_type": "execute_result"
      }
@@ -744,7 +763,7 @@
     "    with gr.Row():\n",
     "        with gr.Column(scale=1):\n",
     "            model_id = gr.Dropdown(\n",
-    "                choices=[\"google/flan-t5-xl\"],\n",
+    "                choices=[\"google/flan-t5-xl\", \"Rallio67/joi_20B_instruct_alpha\"],\n",
     "                value=\"google/flan-t5-xl\",\n",
     "                label=\"Model\",\n",
     "                interactive=True,\n",
@@ -846,7 +865,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count":
+    "execution_count": 15,
     "metadata": {},
     "outputs": [],
     "source": [
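Note on the change above: the new `get_model_endpoint` helper routes any model id containing "joi" to a temporary ngrok tunnel (with no auth headers) and every other model to the hosted Inference API. A minimal sketch of the resulting call path, assuming `HF_TOKEN` is set in the environment and the `joi-20b.ngrok.io` tunnel from this commit is still live (both assumptions, not guaranteed by the diff):

    import os
    import requests

    HF_TOKEN = os.getenv("HF_TOKEN")

    def get_model_endpoint(model_id):
        # Joi is served from a temporary ngrok tunnel that takes no auth header.
        if "joi" in model_id:
            return "https://joi-20b.ngrok.io/generate", None
        # Everything else goes through the Hugging Face Inference API with a bearer token.
        headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
        return f"https://api-inference.huggingface.co/models/{model_id}", headers

    endpoint, headers = get_model_endpoint("Rallio67/joi_20B_instruct_alpha")
    response = requests.post(
        endpoint,
        json={"inputs": "What can you tell me about black holes?"},
        headers=headers,
    )
    print(response.json())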
app.py CHANGED
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
 
 # %% auto 0
-__all__ = ['HF_TOKEN', 'title', 'description', 'query_chat_api', 'inference_chat']
+__all__ = ['HF_TOKEN', 'title', 'description', 'get_model_endpoint', 'query_chat_api', 'inference_chat']
 
 # %% app.ipynb 0
 import gradio as gr
@@ -21,14 +21,23 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 # %% app.ipynb 2
+def get_model_endpoint(model_id):
+    if "joi" in model_id:
+        headers = None
+        return "https://joi-20b.ngrok.io/generate", headers
+    else:
+        headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
+        return f"https://api-inference.huggingface.co/models/{model_id}", headers
+
+
+# %% app.ipynb 3
 def query_chat_api(
     model_id,
     inputs,
     temperature,
     top_p
 ):
-
-    headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
+    endpoint, headers = get_model_endpoint(model_id)
 
     payload = {
         "inputs": inputs,
@@ -40,7 +49,7 @@ def query_chat_api(
         },
     }
 
-    response = requests.post(
+    response = requests.post(endpoint, json=payload, headers=headers)
 
     if response.status_code == 200:
         return response.json()
@@ -48,7 +57,7 @@ def query_chat_api(
         return "Error: " + response.text
 
 
-# %% app.ipynb
+# %% app.ipynb 6
 def inference_chat(
     model_id,
     prompt_template,
@@ -64,7 +73,10 @@ def inference_chat(
     inputs = prompt_template["prompt"].format(human_input=text_input)
 
     output = query_chat_api(model_id, inputs, temperature, top_p)
-
+    # TODO: remove this hack when inference backend schema is updated
+    if isinstance(output, list):
+        output = output[0]
+    history.append(" " + output["generated_text"])
 
     chat = [
         (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
@@ -73,7 +85,7 @@ def inference_chat(
     return {chatbot: chat, state: history}
 
 
-# %% app.ipynb
+# %% app.ipynb 16
 title = """<h1 align="center">Chatty Language Models</h1>"""
 description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
 
@@ -98,7 +110,7 @@ So far, the following prompts are available:
 As you can see, most of these prompts exceed the maximum context size of models like Flan-T5, so an error usually means the Inference API has timed out.
 """
 
-# %% app.ipynb
+# %% app.ipynb 17
 with gr.Blocks(
     css="""
     .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
@@ -113,7 +125,7 @@ with gr.Blocks(
     with gr.Row():
         with gr.Column(scale=1):
             model_id = gr.Dropdown(
-                choices=["google/flan-t5-xl"],
+                choices=["google/flan-t5-xl", "Rallio67/joi_20B_instruct_alpha"],
                 value="google/flan-t5-xl",
                 label="Model",
                 interactive=True,
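Note on the `inference_chat` change: the hosted Inference API returns a list such as `[{"generated_text": ...}]`, while the Joi backend appears to return a bare dict (see the `{'generated_text': ...}` output in the notebook diff above), hence the `isinstance` hack flagged with a TODO. A small sketch of that normalization, under the same assumption about the Joi response shape:

    def extract_generated_text(output):
        # Inference API shape: [{"generated_text": "..."}]
        # Assumed Joi backend shape: {"generated_text": "..."}
        if isinstance(output, list):
            output = output[0]
        return output["generated_text"]

    assert extract_generated_text([{"generated_text": "hi"}]) == "hi"
    assert extract_generated_text({"generated_text": "hi"}) == "hi"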