Omnibus commited on
Commit
a2f95d7
1 Parent(s): ed72db5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -15
app.py CHANGED
@@ -33,6 +33,7 @@ def format_prompt(message, history):
33
  agents =[
34
  "COMMENTER",
35
  "BLOG_POSTER",
 
36
  "COMPRESS_HISTORY_PROMPT"
37
  ]
38
 
@@ -102,6 +103,38 @@ def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temp
102
  #history.append((output,history))
103
 
104
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def create_valid_filename(invalid_filename: str) -> str:
106
  """Converts invalid characters in a string to be suitable for a filename."""
107
  invalid_filename.replace(" ","-")
@@ -114,7 +147,7 @@ def create_valid_filename(invalid_filename: str) -> str:
114
  return ''.join(char for char in valid_chars if char in allowed_chars)
115
 
116
  def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
117
- main_point[0]=prompt
118
  #print(datetime.datetime.now())
119
  uid=uuid.uuid4()
120
  current_time = str(datetime.datetime.now())
@@ -133,22 +166,41 @@ def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0
133
  json_hist={}
134
  json_obj={}
135
  filename=create_valid_filename(f'{prompt}---{current_time}')
 
136
  while True:
137
  seed = random.randint(1,1111111111111111)
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- generate_kwargs = dict(
140
- temperature=temperature,
141
- max_new_tokens=max_new_tokens2,
142
- top_p=top_p,
143
- repetition_penalty=repetition_penalty,
144
- do_sample=True,
145
- seed=seed,
146
- )
147
- if prompt.startswith(' \"'):
148
- prompt=prompt.strip(' \"')
149
- formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
 
 
 
 
 
150
  if len(formatted_prompt) < (50000):
151
  print(len(formatted_prompt))
 
 
152
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
153
  output = ""
154
  #if history:
@@ -196,15 +248,18 @@ def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0
196
  f.close()
197
  upload_file(
198
  path_or_fileobj =f"{uid}-sum.json",
199
- path_in_repo = f"summary/{filename}-summary.json",
200
  repo_id =f"{username}/{dataset_name}",
201
  repo_type = "dataset",
202
  token=token,
203
  )
204
 
205
-
206
  prompt = question_generate(output, history)
207
-
 
 
 
208
  return prompt, history, summary[0],json_obj,json_hist
209
 
210
  def load_html():
 
33
  agents =[
34
  "COMMENTER",
35
  "BLOG_POSTER",
36
+ "REPLY_TO_COMMENTER",
37
  "COMPRESS_HISTORY_PROMPT"
38
  ]
39
 
 
103
  #history.append((output,history))
104
 
105
  return output
106
+
107
+ def blog_poster_reply(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
108
+ #def question_generate(prompt, history):
109
+ seed = random.randint(1,1111111111111111)
110
+ agent=prompts.REPLY_TO_COMMENTER.format(focus=main_point[0])
111
+ system_prompt=agent
112
+ temperature = float(temperature)
113
+ if temperature < 1e-2:
114
+ temperature = 1e-2
115
+ top_p = float(top_p)
116
+
117
+ generate_kwargs = dict(
118
+ temperature=temperature,
119
+ max_new_tokens=max_new_tokens,
120
+ top_p=top_p,
121
+ repetition_penalty=repetition_penalty,
122
+ do_sample=True,
123
+ seed=seed,
124
+ )
125
+ #history.append((prompt,""))
126
+ formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
127
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
128
+ output = ""
129
+
130
+ for response in stream:
131
+ output += response.token.text
132
+ #history.append((output,history))
133
+
134
+ return output
135
+
136
+
137
+
138
  def create_valid_filename(invalid_filename: str) -> str:
139
  """Converts invalid characters in a string to be suitable for a filename."""
140
  invalid_filename.replace(" ","-")
 
147
  return ''.join(char for char in valid_chars if char in allowed_chars)
148
 
149
  def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
150
+ #main_point[0]=prompt
151
  #print(datetime.datetime.now())
152
  uid=uuid.uuid4()
153
  current_time = str(datetime.datetime.now())
 
166
  json_hist={}
167
  json_obj={}
168
  filename=create_valid_filename(f'{prompt}---{current_time}')
169
+ post_cnt=1
170
  while True:
171
  seed = random.randint(1,1111111111111111)
172
+ if post_cnt==1:
173
+ generate_kwargs = dict(
174
+ temperature=temperature,
175
+ max_new_tokens=max_new_tokens2,
176
+ top_p=top_p,
177
+ repetition_penalty=repetition_penalty,
178
+ do_sample=True,
179
+ seed=seed,
180
+ )
181
+ if prompt.startswith(' \"'):
182
+ prompt=prompt.strip(' \"')
183
 
184
+ formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
185
+ else:
186
+ system_prompt=prompts.REPLY_TO_COMMENTER.format(focus=main_points[0])
187
+
188
+ generate_kwargs = dict(
189
+ temperature=temperature,
190
+ max_new_tokens=max_new_tokens2,
191
+ top_p=top_p,
192
+ repetition_penalty=repetition_penalty,
193
+ do_sample=True,
194
+ seed=seed,
195
+ )
196
+ if prompt.startswith(' \"'):
197
+ prompt=prompt.strip(' \"')
198
+
199
+ formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
200
  if len(formatted_prompt) < (50000):
201
  print(len(formatted_prompt))
202
+
203
+
204
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
205
  output = ""
206
  #if history:
 
248
  f.close()
249
  upload_file(
250
  path_or_fileobj =f"{uid}-sum.json",
251
+ path_in_repo = f"book1/{filename}-summary.json",
252
  repo_id =f"{username}/{dataset_name}",
253
  repo_type = "dataset",
254
  token=token,
255
  )
256
 
257
+ else:
258
  prompt = question_generate(output, history)
259
+
260
+ prompt = question_generate(output, history)
261
+ main_point[0]=prompt
262
+
263
  return prompt, history, summary[0],json_obj,json_hist
264
 
265
  def load_html():