Update app.py
app.py CHANGED
@@ -30,25 +30,6 @@ MODEL_OPTIONS = [
 models = {}
 tokenizers = {}
 
-# Custom chat templates
-MISTRAL_TEMPLATE = """<s>[INST] {instruction} [/INST]
-{response}
-</s>
-<s>[INST] {instruction} [/INST]
-"""
-
-LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
-You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
-<</SYS>>
-
-{instruction} [/INST]
-{response}
-</s>
-<s>[INST] {instruction} [/INST]
-"""
-
 for model_id in MODEL_OPTIONS:
     tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
     models[model_id] = AutoModelForCausalLM.from_pretrained(
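Note: the removed MISTRAL_TEMPLATE / LLAMA_TEMPLATE strings are str.format()-style patterns, whereas anything assigned to tokenizer.chat_template has to be a Jinja template over a `messages` list. A minimal sketch of the Jinja form, in case a custom template is ever reintroduced (the checkpoint name is only illustrative):

from transformers import AutoTokenizer

# Illustrative checkpoint; any causal-LM tokenizer behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# chat_template is a Jinja template over `messages`, not a str.format() pattern.
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}<s>[INST] {{ message['content'] }} [/INST]"
    "{% else %} {{ message['content'] }}</s>{% endif %}"
    "{% endfor %}"
)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    tokenize=False,
    add_generation_prompt=False,
)
print(prompt)  # <s>[INST] Hello [/INST]

With tokenize=False the call returns the rendered prompt string, so a template can be checked without loading the model.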
@@ -58,11 +39,9 @@ for model_id in MODEL_OPTIONS:
     )
     models[model_id].eval()
 
-    # Set the chat template for each model
-    if "Navarna" in model_id:
-        tokenizers[model_id].chat_template = MISTRAL_TEMPLATE
-    elif "OpenHathi" in model_id:
-        tokenizers[model_id].chat_template = LLAMA_TEMPLATE
+    # Set pad_token_id to eos_token_id if it's not set
+    if tokenizers[model_id].pad_token_id is None:
+        tokenizers[model_id].pad_token_id = tokenizers[model_id].eos_token_id
 
 # Initialize Flask app
 app = Flask(__name__)
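Note: Llama- and Mistral-style tokenizers usually ship without a dedicated pad token, so the fallback added above keeps batched tokenization and generate() from complaining about a missing pad_token_id. A standalone sketch of the same idiom (gpt2 is just a small, ungated stand-in):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative: any checkpoint without a pad token

# Many causal-LM tokenizers define no pad token; fall back to EOS so that
# batched encoding and generation have a valid padding id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

batch = tokenizer(["short prompt", "a somewhat longer prompt"],
                  padding=True, return_tensors="pt")
print(batch.input_ids.shape, batch.attention_mask.shape)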
@@ -74,6 +53,25 @@ def log_results():
     print("Logged:", json.dumps(data, indent=2))
     return jsonify({"status": "success"}), 200
 
+def prepare_input(model_id: str, message: str, chat_history: List[Tuple[str, str]]):
+    if "OpenHathi" in model_id:
+        # OpenHathi model doesn't use a specific chat template
+        full_prompt = message
+        for history_message in chat_history:
+            full_prompt = f"{history_message[0]}\n{history_message[1]}\n{full_prompt}"
+        return tokenizers[model_id](full_prompt, return_tensors="pt")
+    elif "Navarna" in model_id:
+        # Navarna model uses a chat template
+        conversation = []
+        for user, assistant in chat_history:
+            conversation.extend([
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ])
+        conversation.append({"role": "user", "content": message})
+        prompt = tokenizers[model_id].apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+        return tokenizers[model_id](prompt, return_tensors="pt")
+
 @spaces.GPU(duration=90)
 def generate(
     model_id: str,
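Note: both branches of prepare_input return a BatchEncoding, so the caller can read .input_ids and .attention_mask uniformly regardless of model. The Navarna branch relies on the tokenizer's built-in chat template; roughly how that renders, using an arbitrary chat-tuned checkpoint as a stand-in:

from transformers import AutoTokenizer

# Illustrative chat checkpoint; any tokenizer that defines a chat template works.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

chat_history = [("Hi", "Hello! How can I help?")]
message = "Tell me a joke."

conversation = []
for user, assistant in chat_history:
    conversation.extend([
        {"role": "user", "content": user},
        {"role": "assistant", "content": assistant},
    ])
conversation.append({"role": "user", "content": message})

# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the assistant-turn header so the model continues as the assistant.
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt")
print(inputs.input_ids.shape, inputs.attention_mask.shape)

One wrinkle with re-tokenizing the rendered string is that tokenizers which insert a BOS token automatically may end up with a duplicate; passing add_special_tokens=False to the second call is the usual guard, depending on the tokenizer.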
@@ -86,29 +84,28 @@ def generate(
     model = models[model_id]
     tokenizer = tokenizers[model_id]
 
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend([
-            {"role": "user", "content": user},
-            {"role": "assistant", "content": assistant},
-        ])
-    conversation.append({"role": "user", "content": message})
+    inputs = prepare_input(model_id, message, chat_history)
+    input_ids = inputs.input_ids
+    attention_mask = inputs.attention_mask
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
+    attention_mask = attention_mask.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         input_ids=input_ids,
+        attention_mask=attention_mask,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         temperature=temperature,
         num_beams=1,
+        pad_token_id=tokenizer.eos_token_id,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
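Note: TextIteratorStreamer is what lets the blocking model.generate() call run on a worker thread while the handler iterates partial text as it arrives. A self-contained sketch of the pattern (gpt2 stands in for the checkpoints in MODEL_OPTIONS):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # illustrative; the Space uses the checkpoints in MODEL_OPTIONS
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Once upon a time", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while we consume the streamer.
thread = Thread(target=model.generate, kwargs=dict(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    streamer=streamer,
    max_new_tokens=32,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
))
thread.start()

partial = ""
for new_text in streamer:   # yields decoded text chunks as they are generated
    partial += new_text
    print(partial)          # in the Space this would be `yield partial` back to Gradio
thread.join()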
@@ -215,5 +212,5 @@ if __name__ == "__main__":
     flask_thread = Thread(target=app.run, kwargs={"host": "0.0.0.0", "port": 5000})
     flask_thread.start()
 
-    # Start Gradio app
-    demo.queue(max_size=10).launch()
+    # Start Gradio app with public link
+    demo.queue(max_size=10).launch(share=True)
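Note: Flask's app.run() blocks, so it stays on its own thread while Gradio keeps the main thread; share=True only adds a temporary public link in front of the local server. A stripped-down sketch of the same layout (the /log route and the echo handler are placeholders, not the Space's real endpoints):

from threading import Thread
from flask import Flask, jsonify, request
import gradio as gr

app = Flask(__name__)

@app.route("/log", methods=["POST"])          # route path is illustrative
def log_results():
    print("Logged:", request.get_json())
    return jsonify({"status": "success"}), 200

def echo(message, history):                   # placeholder chat handler
    return f"You said: {message}"

demo = gr.ChatInterface(echo)

if __name__ == "__main__":
    # Flask serves the logging endpoint in the background (daemon flag is a sketch choice)...
    Thread(target=app.run, kwargs={"host": "0.0.0.0", "port": 5000}, daemon=True).start()
    # ...while Gradio blocks the main thread; share=True requests a public URL.
    demo.queue(max_size=10).launch(share=True)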