skylersterling committed
Commit db9d4db · verified · 1 Parent(s): 8acccfc

Update app.py

Files changed (1)
  1. app.py +21 -94
app.py CHANGED
@@ -1,95 +1,22 @@
- from huggingface_hub import InferenceClient
+ # Import the libraries
  import gradio as gr
- import random
-
- API_URL = "https://api-inference.huggingface.co/models/"
-
- client = InferenceClient(
-     "skylersterling/TopicGPT"
- )
-
- def format_prompt(message):
-     prompt = f"<s>[INST] {message} [/INST] #TOPIC# "
-     return prompt
-
- def generate(message, temperature=0.05, max_new_tokens=512, top_p=0.2, repetition_penalty=1.0):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)
-
-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=repetition_penalty,
-         do_sample=True,
-         seed=random.randint(0, 10**7),
-     )
-
-     formatted_prompt = format_prompt(message)
-
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-
-     for response in stream:
-         output += response.token.text
-         if "#" in output:
-             output = output.replace("#", "")
-             break
-         yield output
-     return output
-
- additional_inputs=[
-     gr.Slider(
-         label="Temperature",
-         value=0.05,
-         minimum=0.0,
-         maximum=1.0,
-         step=0.05,
-         interactive=True,
-         info="Higher values produce more diverse outputs",
-     ),
-     gr.Slider(
-         label="Max new tokens",
-         value=512,
-         minimum=64,
-         maximum=1024,
-         step=64,
-         interactive=True,
-         info="The maximum numbers of new tokens",
-     ),
-     gr.Slider(
-         label="Top-p (nucleus sampling)",
-         value=0.2,
-         minimum=0.0,
-         maximum=1,
-         step=0.05,
-         interactive=True,
-         info="Higher values sample more low-probability tokens",
-     ),
-     gr.Slider(
-         label="Repetition penalty",
-         value=1.2,
-         minimum=1.0,
-         maximum=2.0,
-         step=0.05,
-         interactive=True,
-         info="Penalize repeated tokens",
-     )
- ]
-
- customCSS = """
- #component-7 { # this is the default element ID of the chat component
-     height: 800px; # adjust the height as needed
-     flex-grow: 1;
- }
- """
-
- with gr.Blocks(css=customCSS) as demo:
-     gr.ChatInterface(
-         fn=generate,
-         additional_inputs=additional_inputs,
-     )
-
- demo.queue().launch(debug=True, share=True)
+ import transformers
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ import os
+
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # Load the tokenizer and model
+ tokenizer = GPT2Tokenizer.from_pretrained('skylersterling/TopicGPT', use_auth_token=HF_TOKEN)
+ model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', use_auth_token=HF_TOKEN)
+
+ # Define the function that generates text from a prompt
+ def generate_text(prompt):
+     input_ids = tokenizer.encode(prompt, return_tensors='pt')
+     output = model.generate(input_ids, max_new_tokens=80, do_sample=True)
+     text = tokenizer.decode(output[0], skip_special_tokens=True)
+     return text
+
+ # Create a gradio interface with a text input and a text output
+ interface = gr.Interface(fn=generate_text, inputs='text', outputs='text')
+ interface.launch()
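
The commit replaces the streaming InferenceClient app with local GPT-2 inference via transformers. For anyone trying the updated code outside the Space, a minimal sanity check of the new generate_text path is sketched below. It is not part of the commit: the prompt string is invented for illustration, HF_TOKEN is assumed to be exported in the environment with read access to skylersterling/TopicGPT, and it uses the token= keyword that recent transformers releases prefer over the deprecated use_auth_token= seen in the diff.

# Local sanity check of the new inference path (a sketch; the prompt text and
# token handling are assumptions, not part of commit db9d4db).
import os
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = GPT2Tokenizer.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model.eval()  # inference only, no gradients needed

prompt = "The central bank left interest rates unchanged this quarter."  # illustrative only
input_ids = tokenizer.encode(prompt, return_tensors='pt')
with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=80,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token; silences a generate() warning
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))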