naourpally committed
Commit d435c86 · 1 Parent(s): b7cebb8

Simple Gradio app that calls the fair endpoint

Files changed (1)
  1. app.py +22 -139
app.py CHANGED
@@ -1,141 +1,24 @@
- import os
- from threading import Thread
- from typing import Iterator
-
  import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- DESCRIPTION = """\
- # Llama-2 7B Chat
- This Space demonstrates model [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta, a Llama 2 model with 7B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
- 🔎 For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
- 🔨 Looking for an even more powerful model? Check out the [13B version](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat) or the large [70B model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
- """
-
- LICENSE = """
- <p/>
- ---
- As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
- """
-
- if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-
- if torch.cuda.is_available():
-     model_id = "meta-llama/Llama-2-7b-chat-hf"
-     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-     tokenizer = AutoTokenizer.from_pretrained(model_id)
-     tokenizer.use_default_system_prompt = False
-
-
- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-     conversation = []
-     if system_prompt:
-         conversation.append({"role": "system", "content": system_prompt})
-     for user, assistant in chat_history:
-         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Textbox(label="System prompt", lines=6),
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
- )
-
- with gr.Blocks(css="style.css") as demo:
-     gr.Markdown(DESCRIPTION)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
-     gr.Markdown(LICENSE)

- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()
+ import requests
  import gradio as gr
+
+ def get_text_response(prompt):
+     api_url = "http://35.233.231.20:5002/api/generate"
+     data_payload = {
+         "model": "llama2",
+         "prompt": prompt,
+         "stream": False
+     }
+     response = requests.post(api_url, json=data_payload)
+     response_json = response.json()
+     text_response = response_json.get("response", "No response received.")
+     return text_response
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         prompt = gr.Textbox(label="Enter your prompt")
+         submit_button = gr.Button("Submit")
+     output = gr.Textbox(label="Response")
+
+     submit_button.click(fn=get_text_response, inputs=prompt, outputs=output)
+
+ demo.launch()
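
For reference, the new get_text_response helper can be exercised without the Gradio UI. The sketch below is illustrative only: it assumes the server at http://35.233.231.20:5002/api/generate accepts the same JSON payload the app sends (model, prompt, stream) and returns JSON with a "response" field, which resembles an Ollama-style generate API; the timeout, error handling, and quick_test name are additions for this example and are not part of the commit.

# Standalone check of the remote generation endpoint (sketch, assumptions noted above).
import requests

API_URL = "http://35.233.231.20:5002/api/generate"

def quick_test(prompt: str) -> str:
    # Same payload shape as get_text_response in app.py.
    payload = {"model": "llama2", "prompt": prompt, "stream": False}
    try:
        resp = requests.post(API_URL, json=payload, timeout=60)
        resp.raise_for_status()
    except requests.RequestException as exc:
        return f"Request failed: {exc}"
    return resp.json().get("response", "No response received.")

if __name__ == "__main__":
    print(quick_test("Say hello in one sentence."))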