archit11 commited on
Commit
e1e8c19
·
verified ·
1 Parent(s): 1050d06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -90,9 +90,10 @@ def compare_models(
90
  max_new_tokens: int,
91
  temperature: float,
92
  top_p: float,
93
- ) -> Tuple[str, str, List[Tuple[str, str]], List[Tuple[str, str]]]:
94
  if model1_name == model2_name:
95
- return "Error: Please select two different models.", "Error: Please select two different models.", chat_history1, chat_history2
 
96
 
97
  output1 = "".join(list(generate(model1_name, message, chat_history1, max_new_tokens, temperature, top_p)))
98
  output2 = "".join(list(generate(model2_name, message, chat_history2, max_new_tokens, temperature, top_p)))
@@ -102,7 +103,7 @@ def compare_models(
102
 
103
  log_results(model1_name, model2_name, message, output1, output2)
104
 
105
- return output1, output2, chat_history1, chat_history2
106
 
107
  def log_results(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
108
  log_data = {
@@ -165,4 +166,4 @@ with gr.Blocks(css="style.css") as demo:
165
  )
166
 
167
  if __name__ == "__main__":
168
- demo.queue(max_size=10).launch()
 
90
  max_new_tokens: int,
91
  temperature: float,
92
  top_p: float,
93
+ ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]], List[Tuple[str, str]]]:
94
  if model1_name == model2_name:
95
+ error_message = [("System", "Error: Please select two different models.")]
96
+ return error_message, error_message, chat_history1, chat_history2
97
 
98
  output1 = "".join(list(generate(model1_name, message, chat_history1, max_new_tokens, temperature, top_p)))
99
  output2 = "".join(list(generate(model2_name, message, chat_history2, max_new_tokens, temperature, top_p)))
 
103
 
104
  log_results(model1_name, model2_name, message, output1, output2)
105
 
106
+ return chat_history1, chat_history2, chat_history1, chat_history2
107
 
108
  def log_results(model1_name: str, model2_name: str, question: str, answer1: str, answer2: str, winner: str = None):
109
  log_data = {
 
166
  )
167
 
168
  if __name__ == "__main__":
169
+ demo.queue(max_size=10).launch()