# bloom_demo / app.py
# Author: Narsil (HF staff)
# Last commit 513b107: "Adding auth to not get rate limited."
# File size: 3.55 kB (Hugging Face Hub scan: no virus)
import html
import json
import logging
import os

import gradio as gr
import requests

from screenshot import (
    before_prompt,
    prompt_to_generation,
    after_generation,
    js_save,
    js_load_script,
)
from spaces_info import description, examples, initial_prompt_value
# Configuration comes from the environment (e.g. Space secrets):
# the text-generation endpoint to POST to, and the Hugging Face token used
# to authenticate against it. Both are None when the variable is unset.
API_URL = os.environ.get("API_URL")
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
def query(payload, timeout=120):
    """POST ``payload`` as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up. The
            default of 120 is a generous bound chosen for a generation
            request; pass ``None`` to wait indefinitely.

    Returns:
        The decoded JSON response. Transport failures and replies that are
        not UTF-8 JSON are reported as ``{"error": <message>}``, the same
        shape the API uses for its own errors, so callers can handle both
        the same way.
    """
    logger = logging.getLogger(__name__)
    # Only authenticate when a token is configured; otherwise the header
    # would literally read "Bearer None".
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    logger.info("API request payload: %s", payload)
    try:
        response = requests.post(
            API_URL, json=payload, headers=headers, timeout=timeout
        )
    except requests.RequestException as err:
        logger.warning("API request failed: %s", err)
        return {"error": f"request to the inference API failed: {err}"}
    logger.info("API response: %s", response)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueError
        # subclasses, e.g. when a proxy/gateway returns an HTML error page.
        return {
            "error": "inference API returned a non-JSON response "
            f"(HTTP {response.status_code})"
        }
def _generation_parameters(max_length, sample_or_greedy, seed):
    """Return the ``parameters`` dict for one text-generation request.

    Both modes share every setting; "Sample" additionally turns on
    ``do_sample`` with ``top_p=0.9``, any other value requests greedy decoding.
    """
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        # NOTE(review): presumably disables stopping at an end-of-sequence
        # token so the full token budget is used — confirm with the API docs.
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        parameters["do_sample"] = True
        parameters["top_p"] = 0.9
    return parameters


def _continuation(generated_text, prompt):
    """Return the part of ``generated_text`` after the first occurrence of ``prompt``.

    Falls back to the whole ``generated_text`` when ``prompt`` is empty
    (``str.split``/``str.partition`` reject an empty separator) or does not
    occur in it, instead of raising.
    """
    if prompt:
        _, found, rest = generated_text.partition(prompt)
        if found:
            return rest
    return generated_text


def _error_result(message):
    """Return the ``(html, text, log)`` output triple that reports ``message``.

    The message is HTML-escaped because it is rendered as markup by the UI.
    """
    return (
        None,
        None,
        f"<span style='color:red'>ERROR: {html.escape(str(message))} </span>",
    )


def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Generate a continuation of ``input_sentence`` through the inference API.

    Args:
        input_sentence: Prompt text.
        max_length: Maximum number of new tokens to generate.
        sample_or_greedy: "Sample" enables sampling; any other value is greedy.
        seed: Seed sent with the request.

    Returns:
        A 3-tuple ``(html, full_text, log)``: an HTML snippet with the prompt
        and the generated continuation wrapped in the screenshot markup, the
        raw generated text, and an empty log string. On failure returns
        ``(None, None, <red error span>)``.
    """
    payload = {
        "inputs": input_sentence,
        "parameters": _generation_parameters(max_length, sample_or_greedy, seed),
        # Ask the API not to serve a cached result for an identical request.
        "options": {"use_cache": False},
    }
    data = query(payload)
    if isinstance(data, dict) and "error" in data:
        return _error_result(data["error"])
    try:
        generated_text = data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        # Empty list, dict without "error", or otherwise unexpected shape.
        return _error_result("unexpected response from the inference API")

    generation = _continuation(generated_text, input_sentence)
    # Prompt and model output are untrusted text inserted into HTML: escape them.
    return (
        before_prompt
        + html.escape(input_sentence)
        + prompt_to_generation
        + html.escape(generation)
        + after_generation,
        generated_text,
        "",
    )
if __name__ == "__main__":
    # Build the Gradio Blocks UI and start the server when run as a script.
    demo = gr.Blocks()
    with demo:
        # Header row: description text imported from spaces_info.
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            # Left column: prompt, generation settings and action buttons.
            with gr.Column():
                text = gr.Textbox(
                    label="Input",
                    value=" ",  # single-space default; original note: should be set to " " when plugged into a real API
                )
                tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
                # Passed to inference() as sample_or_greedy; "Sample" selects sampling.
                sampling = gr.Radio(
                    ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
                )
                # type="index" passes the selected option's position (0-4) instead
                # of its label; it is wired to inference()'s `seed` argument below,
                # so each "Sample N" choice requests a different seed.
                sampling2 = gr.Radio(
                    ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
                    value="Sample 1",
                    label="Sample other generations (only work in 'Sample' mode)",
                    type="index",
                )
                with gr.Row():
                    submit = gr.Button("Submit")
                    load_image = gr.Button("Generate Image")
            # Right column: error/log output, raw generated text, rendered HTML.
            with gr.Column():
                text_error = gr.Markdown(label="Log information")
                text_out = gr.Textbox(label="Output")
                display_out = gr.HTML(label="Image")
                # Attach js_load_script (from the screenshot module) to the
                # HTML component's "load" event, with no Python callback.
                # NOTE(review): set_event_trigger / no_target are Gradio
                # internals whose signature varies between versions — confirm
                # against the pinned Gradio version.
                display_out.set_event_trigger(
                    "load",
                    fn=None,
                    inputs=None,
                    outputs=None,
                    no_target=True,
                    js=js_load_script,
                )
        with gr.Row():
            gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])

        # Inputs map positionally onto inference(input_sentence, max_length,
        # sample_or_greedy, seed); its 3-tuple result fills the three outputs.
        submit.click(
            inference,
            inputs=[text, tokens, sampling, sampling2],
            outputs=[display_out, text_out, text_error],
        )

        # Client-side only: runs js_save in the browser, no Python function.
        # Presumably saves the rendered HTML as an image — TODO confirm in screenshot.py.
        load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)

    demo.launch()