:zap: [Enhance] Auto calculate max_tokens if not set
- apis/chat_api.py +3 -3
- networks/message_streamer.py +28 -1
- requirements.txt +1 -0
apis/chat_api.py
CHANGED
@@ -56,7 +56,7 @@ class ChatAPIApp:
             if api_key.startswith("hf_"):
                 return api_key
             else:
-                logger.warn(f"Invalid HF Token")
+                logger.warn(f"Invalid HF Token!")
         else:
             logger.warn("Not provide HF Token!")
             return None
@@ -71,11 +71,11 @@ class ChatAPIApp:
             description="(list) Messages",
         )
         temperature: float = Field(
-            default=0
+            default=0,
             description="(float) Temperature",
         )
         max_tokens: int = Field(
-            default
+            default=-1,
             description="(int) Max tokens",
         )
         stream: bool = Field(
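For context, a minimal sketch of how the new defaults behave (the class name here is hypothetical; only the fields and defaults are taken from the diff): a request that leaves max_tokens out now parses to -1, which downstream code treats as "auto-calculate".

# Illustrative only, not the repo's actual request model.
from pydantic import BaseModel, Field


class ChatPostItem(BaseModel):  # hypothetical name for illustration
    temperature: float = Field(default=0, description="(float) Temperature")
    max_tokens: int = Field(default=-1, description="(int) Max tokens")


item = ChatPostItem(temperature=0.5)
print(item.max_tokens)  # -1 -> signals the streamer to compute its own budget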
networks/message_streamer.py
CHANGED
@@ -1,6 +1,7 @@
 import json
 import re
 import requests
+from tiktoken import get_encoding as tiktoken_get_encoding
 from messagers.message_outputer import OpenaiStreamOutputer
 from utils.logger import logger
 from utils.enver import enver
@@ -22,6 +23,12 @@ class MessageStreamer:
         "mistral-7b": "</s>",
         "openchat-3.5": "<|end_of_turn|>",
     }
+    TOKEN_LIMIT_MAP = {
+        "mixtral-8x7b": 32768,
+        "mistral-7b": 32768,
+        "openchat-3.5": 8192,
+    }
+    TOKEN_RESERVED = 32
 
     def __init__(self, model: str):
         if model in self.MODEL_MAP.keys():
@@ -30,6 +37,7 @@ class MessageStreamer:
             self.model = "default"
         self.model_fullname = self.MODEL_MAP[self.model]
         self.message_outputer = OpenaiStreamOutputer()
+        self.tokenizer = tiktoken_get_encoding("cl100k_base")
 
     def parse_line(self, line):
         line = line.decode("utf-8")
@@ -38,11 +46,17 @@ class MessageStreamer:
         content = data["token"]["text"]
         return content
 
+    def count_tokens(self, text):
+        tokens = self.tokenizer.encode(text)
+        token_count = len(tokens)
+        logger.note(f"Prompt Token Count: {token_count}")
+        return token_count
+
     def chat_response(
         self,
         prompt: str = None,
         temperature: float = 0.01,
-        max_new_tokens: int =
+        max_new_tokens: int = None,
         api_key: str = None,
     ):
         # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
@@ -60,6 +74,19 @@ class MessageStreamer:
         )
         self.request_headers["Authorization"] = f"Bearer {api_key}"
 
+        token_limit = (
+            self.TOKEN_LIMIT_MAP[self.model]
+            - self.TOKEN_RESERVED
+            - self.count_tokens(prompt)
+        )
+        if token_limit <= 0:
+            raise ValueError("Prompt exceeded token limit!")
+
+        if max_new_tokens is None or max_new_tokens <= 0:
+            max_new_tokens = token_limit
+        else:
+            max_new_tokens = min(max_new_tokens, token_limit)
+
         # References:
         #   huggingface_hub/inference/_client.py:
         #     class InferenceClient > def text_generation()
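Pulled out of the class for readability, the budgeting logic above can be read as a standalone function. The constants and arithmetic mirror the diff; the wrapper function itself is only a sketch, not the repo's API. Note that cl100k_base is OpenAI's encoding, so the counts for these Mistral/OpenChat models are approximations, which is presumably part of why TOKEN_RESERVED leaves some headroom.

# Standalone sketch of the auto max_tokens calculation; illustrative only.
from tiktoken import get_encoding

TOKEN_LIMIT_MAP = {"mixtral-8x7b": 32768, "mistral-7b": 32768, "openchat-3.5": 8192}
TOKEN_RESERVED = 32
tokenizer = get_encoding("cl100k_base")


def resolve_max_new_tokens(model: str, prompt: str, max_new_tokens: int = None) -> int:
    # Budget = context window - reserved headroom - prompt tokens.
    token_limit = TOKEN_LIMIT_MAP[model] - TOKEN_RESERVED - len(tokenizer.encode(prompt))
    if token_limit <= 0:
        raise ValueError("Prompt exceeded token limit!")
    # Unset or non-positive request -> use the full remaining budget;
    # otherwise clamp the requested value to the budget.
    if max_new_tokens is None or max_new_tokens <= 0:
        return token_limit
    return min(max_new_tokens, token_limit)


print(resolve_max_new_tokens("mistral-7b", "Hello!"))         # near-full context budget
print(resolve_max_new_tokens("openchat-3.5", "Hello!", 256))  # clamped user value: 256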
requirements.txt
CHANGED
@@ -6,5 +6,6 @@ pydantic
 requests
 sse_starlette
 termcolor
+tiktoken
 uvicorn
 websockets
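A hypothetical client call showing the effect end to end; the URL, route, and any payload fields beyond those visible in the diff (messages, temperature, max_tokens, stream) are assumptions, not taken from this commit.

import requests

# Hypothetical local endpoint, for illustration only.
resp = requests.post(
    "http://127.0.0.1:8000/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0,
        "stream": False,
        # "max_tokens" omitted -> parsed as -1 -> auto-calculated server-side
    },
)
print(resp.json())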