Omnibus committed on
Commit
e72d7b7
1 Parent(s): 6657883

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py CHANGED
@@ -1,4 +1,211 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  def load_html():
4
  with open('index.html','r') as h:
 
1
+ from huggingface_hub import InferenceClient, HfApi, upload_file
2
+ import datetime
3
  import gradio as gr
4
+ import random
5
+ import prompts
6
+ import json
7
+ import uuid
8
+ import os
9
+
10
+
11
+
12
# --- Hub / model configuration --------------------------------------------
token = os.environ.get("HF_TOKEN")
username = "omnibus"
dataset_name = "tmp2"
# BUG FIX: HfApi was constructed with token="" even though the real token is
# read from the environment one line above — pass it so API calls authenticate.
api = HfApi(token=token)
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# --- Mutable module-level state shared by the generation helpers ----------
history = []       # (prompt, response) pairs for the running conversation
hist_out = []      # JSON records accumulated per generation run
summary = [""]     # summary[0]: rolling compressed-history text
main_point = [""]  # main_point[0]: current focus/topic for the agents
24
+
25
def format_prompt(message, history):
    """Build a Mixtral-style instruction prompt from the chat history.

    Each (user, bot) turn is rendered as "[INST] user [/INST] bot</s> ",
    and the new message is appended as a final open [INST] block.
    """
    pieces = ["<s>"]
    for user_prompt, bot_response in history:
        pieces.append(f"[INST] {user_prompt} [/INST]")
        pieces.append(f" {bot_response}</s> ")
    pieces.append(f"[INST] {message} [/INST]")
    return "".join(pieces)
32
+
33
# Names of the prompt templates (attributes of the `prompts` module) the app uses.
agents = [
    "COMMENTER",
    "BLOG_POSTER",
    "COMPRESS_HISTORY_PROMPT",
]

# Default sampling parameters for text generation.
temperature = 0.9
max_new_tokens = 256
max_new_tokens2 = 1048  # larger token budget used by generate()
top_p = 0.95
# BUG FIX: the original line ended with a trailing comma, which made this a
# 1-tuple (1.0,) instead of the float 1.0.
repetition_penalty = 1.0
44
+
45
def compress_history(formatted_prompt):
    """Compress the conversation history with a single streamed LLM call.

    The caller supplies an already-formatted prompt (built from the
    COMPRESS_HISTORY_PROMPT template); this streams a completion from the
    module-level `client` and returns the concatenated text.

    Cleanups vs. the original: removed the no-op `formatted_prompt =
    formatted_prompt` assignment, the dead `temperature < 1e-2` clamp on a
    literal 0.9, and the unused `agent`/`system_prompt` locals.
    """
    seed = random.randint(1, 1111111111111111)
    generate_kwargs = dict(
        temperature=0.9,
        max_new_tokens=30480,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        seed=seed,
    )
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        output += response.token.text
    # Debug traces kept from the original implementation.
    print(output)
    print(main_point[0])
    return output
75
+
76
+
77
def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Generate a follow-up comment/question about the current main point.

    Streams a completion using the COMMENTER template (formatted with
    main_point[0]) and returns the full generated text.  Note: `agent_name`
    and `sys_prompt` are accepted for signature compatibility but are not
    read — the system prompt always comes from prompts.COMMENTER.
    """
    system_prompt = prompts.COMMENTER.format(focus=main_point[0])
    # Clamp sampling parameters to sane float values.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(1, 1111111111111111),
    )
    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    return "".join(response.token.text for response in stream)
105
def create_valid_filename(invalid_filename: str) -> str:
    """Convert a string into a filesystem-safe filename.

    Whitespace runs collapse to single hyphens, then every character that is
    not an ASCII letter, digit, underscore, or hyphen is dropped.

    BUG FIX: the original called `invalid_filename.replace(" ","-")` and
    discarded the result (strings are immutable) — that dead statement is
    removed; the split/join below already handles spaces.
    """
    hyphenated = '-'.join(invalid_filename.split())
    # Equivalent to membership in the original a-zA-Z0-9_- tuple, without
    # scanning a 64-element tuple per character.
    return ''.join(
        ch for ch in hyphenated
        if ch in '_-' or (ch.isascii() and ch.isalnum())
    )
115
+
116
def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
    """Endlessly extend a blog-style conversation and persist it to the Hub.

    Generator: on every streamed token it yields a 5-tuple
    ('', [(prompt, output)], summary[0], json_obj, json_hist) for the Gradio
    UI.  Each full response is appended to `history`, dumped to a local
    `{uid}.json` file, and uploaded to the `{username}/{dataset_name}`
    dataset repo.  When the formatted prompt grows past 50,000 characters,
    the history is compressed via compress_history() and the summary is
    saved/uploaded instead.  The loop never terminates on its own.

    NOTE(review): `agent_name` and `sys_prompt` are accepted but never read;
    the system prompt is always prompts.BLOG_POSTER.
    """
    main_point[0]=prompt
    #print(datetime.datetime.now())
    uid=uuid.uuid4()
    current_time = str(datetime.datetime.now())

    # Make the timestamp filename-safe (':' and '.' are stripped below).
    current_time=current_time.replace(":","-")
    current_time=current_time.replace(".","-")
    print (current_time)
    agent=prompts.BLOG_POSTER
    system_prompt=agent
    # Clamp sampling parameters to sane float values.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    hist_out=[]   # per-run list of {"prompt","output"} records (shadows the module global)
    sum_out=[]    # per-run list of {"summary"} records
    json_hist={}  # last serialized hist_out, also yielded to the UI
    json_obj={}   # last serialized sum_out, also yielded to the UI
    filename=create_valid_filename(f'{prompt}---{current_time}')
    while True:
        seed = random.randint(1,1111111111111111)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens2,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=seed,
        )
        # Strip the leading space+quote wrapper that question_generate output
        # sometimes carries.
        if prompt.startswith(' \"'):
            prompt=prompt.strip(' \"')
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
        if len(formatted_prompt) < (50000):
            print(len(formatted_prompt))
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = ""
            #if history:
            #    yield history

            # Stream tokens to the UI as they arrive.
            for response in stream:
                output += response.token.text
                yield '', [(prompt,output)],summary[0],json_obj, json_hist
            out_json = {"prompt":prompt,"output":output}

            # Ask the COMMENTER agent for the next prompt before recording
            # this turn.
            prompt = question_generate(output, history)
            #output += prompt
            history.append((prompt,output))
            print ( f'Prompt:: {len(prompt)}')
            #print ( f'output:: {output}')
            print ( f'history:: {len(formatted_prompt)}')
            hist_out.append(out_json)
            #try:
            #    for ea in
            with open(f'{uid}.json', 'w') as f:
                json_hist=json.dumps(hist_out, indent=4)
                f.write(json_hist)
                f.close()  # redundant: the `with` block already closes f

            upload_file(
                path_or_fileobj =f"{uid}.json",
                # NOTE(review): "(unknown)" looks like a lost f-string
                # substitution — presumably this was f"test/{filename}.json";
                # confirm against the original source.
                path_in_repo = f"test/(unknown).json",
                repo_id =f"{username}/{dataset_name}",
                repo_type = "dataset",
                token=token,
            )
        else:
            # Prompt too long: compress the accumulated history into a summary.
            formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)

            #current_time = str(datetime.datetime.now().timestamp()).split(".",1)[0]
            #filename=f'(unknown)-{current_time}'
            history = []
            output = compress_history(formatted_prompt)
            summary[0]=output
            sum_json = {"summary":summary[0]}
            sum_out.append(sum_json)
            with open(f'{uid}-sum.json', 'w') as f:
                json_obj=json.dumps(sum_out, indent=4)
                f.write(json_obj)
                f.close()  # redundant: the `with` block already closes f
            upload_file(
                path_or_fileobj =f"{uid}-sum.json",
                # NOTE(review): same "(unknown)" placeholder as above —
                # presumably f"summary/{filename}-summary.json".
                path_in_repo = f"summary/(unknown)-summary.json",
                repo_id =f"{username}/{dataset_name}",
                repo_type = "dataset",
                token=token,
            )


        prompt = question_generate(output, history)

    # NOTE(review): unreachable — the `while True` above has no break, and in
    # a generator a `return` value only surfaces via StopIteration anyway.
    return prompt, history, summary[0],json_obj,json_hist
209
 
210
  def load_html():
211
  with open('index.html','r') as h: