Commit 29ea64b
Parent(s): c596b21

Update with h2oGPT hash c7453f1b1ab51fb1cd342a9867f23cd7b538e000

Files changed:
- src/client_test.py   +8 -17
- src/gen.py           +7 -44
- src/gpt_langchain.py +279 -337
- src/gradio_runner.py +10 -7
src/client_test.py
CHANGED

@@ -80,7 +80,7 @@ def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
              version=None,
              h2ogpt_key=None,
              visible_models=None,
-             system_prompt='',  # default of no system prompt
+             system_prompt='',  # default of no system prompt tiggered by empty string
              add_search_to_context=False,
              chat_conversation=None,
              text_context_list=None,
@@ -256,18 +256,13 @@ def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2o


 @pytest.mark.skip(reason="For manual use against some server, no server launched")
-def test_client_basic_api_lean(prompt_type='human_bot', version=None, h2ogpt_key=None,
-                               chat_conversation=None, system_prompt=''):
-    return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
-                                      version=version, h2ogpt_key=h2ogpt_key,
-                                      chat_conversation=chat_conversation,
-                                      system_prompt=system_prompt)
+def test_client_basic_api_lean(prompt_type='human_bot', version=None, h2ogpt_key=None):
+    return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50,
+                                      version=version, h2ogpt_key=h2ogpt_key)


-def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None,
-                               chat_conversation=None, system_prompt=''):
-    kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key, chat_conversation=chat_conversation,
-                  system_prompt=system_prompt)
+def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
+    kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key)

     api_name = '/submit_nochat_api'  # NOTE: like submit_nochat but stable API for string dict passing
     client = get_client(serialize=True)
@@ -367,9 +362,7 @@ def run_client_chat(prompt='',
                     langchain_agents=[],
                     prompt_type=None, prompt_dict=None,
                     version=None,
-                    h2ogpt_key=None,
-                    chat_conversation=None,
-                    system_prompt=''):
+                    h2ogpt_key=None):
     client = get_client(serialize=False)

     kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
@@ -379,9 +372,7 @@ def run_client_chat(prompt='',
                             langchain_agents=langchain_agents,
                             prompt_dict=prompt_dict,
                             version=version,
-                            h2ogpt_key=h2ogpt_key,
-                            chat_conversation=chat_conversation,
-                            system_prompt=system_prompt)
+                            h2ogpt_key=h2ogpt_key)
     return run_client(client, prompt, args, kwargs)
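The hunks above strip chat_conversation and system_prompt from the lean no-chat client helpers, so the lean path only sends instruction_nochat and h2ogpt_key. For orientation, a minimal standalone sketch of hitting the same /submit_nochat_api route with gradio_client; the server URL, the str()/ast.literal_eval round-trip, and the helper name are assumptions for illustration, not part of this commit:

# Hypothetical standalone caller mirroring run_client_nochat_api_lean() above.
# Assumes an h2oGPT Gradio server at `server`; endpoint name and kwargs follow the diff.
import ast

from gradio_client import Client


def ask_lean(prompt, h2ogpt_key=None, server='http://localhost:7860'):
    client = Client(server)
    kwargs = dict(instruction_nochat=prompt, h2ogpt_key=h2ogpt_key)
    # /submit_nochat_api takes a stringified kwargs dict and returns a stringified dict
    res = client.predict(str(kwargs), api_name='/submit_nochat_api')
    return ast.literal_eval(res)['response']


if __name__ == '__main__':
    print(ask_lean('Who are you?'))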
src/gen.py
CHANGED

@@ -335,7 +335,7 @@ def main(

          Or Address can be for vLLM:
               Use: "vllm:IP:port" for OpenAI-compliant vLLM endpoint
-
+              Note: vllm_chat not supported by vLLM project.

          Or Address can be replicate:
              Use:
@@ -2236,17 +2236,6 @@ def evaluate(
         instruction = instruction_nochat
         iinput = iinput_nochat

-    # avoid instruction in chat_conversation itself, since always used as additional context to prompt in what follows
-    if isinstance(chat_conversation, list) and \
-            len(chat_conversation) > 0 and \
-            len(chat_conversation[-1]) == 2 and \
-            chat_conversation[-1][0] == instruction:
-        chat_conversation = chat_conversation[:-1]
-    if not add_chat_history_to_context:
-        # make it easy to ignore without needing add_chat_history_to_context
-        # some langchain or unit test may need to then handle more general case
-        chat_conversation = []
-
     # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
     model_lower = base_model.lower()
     if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom':
@@ -2495,8 +2484,7 @@ def evaluate(
         prompt, \
             instruction, iinput, context, \
             num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
-            chat_index, \
-            top_k_docs_trial, one_doc_size = \
+            chat_index, top_k_docs_trial, one_doc_size = \
             get_limited_prompt(instruction,
                                iinput,
                                tokenizer,
@@ -2564,6 +2552,8 @@ def evaluate(
                                               sanitize_bot_response=sanitize_bot_response)
             yield dict(response=response, sources=sources, save_dict=dict())
         elif inf_type == 'vllm_chat' or inference_server == 'openai_chat':
+            if inf_type == 'vllm_chat':
+                raise NotImplementedError('%s not supported by vLLM' % inf_type)
             if system_prompt in [None, 'None', 'auto']:
                 openai_system_prompt = "You are a helpful assistant."
             else:
@@ -2571,16 +2561,7 @@ def evaluate(
             messages0 = []
             if openai_system_prompt:
                 messages0.append({"role": "system", "content": openai_system_prompt})
-
-            assert external_handle_chat_conversation, "Should be handling only externally"
-            # chat_index handles token counting issues
-            for message1 in chat_conversation[chat_index:]:
-                if len(message1) == 2:
-                    messages0.append(
-                        {'role': 'user', 'content': message1[0] if message1[0] is not None else ''})
-                    messages0.append(
-                        {'role': 'assistant', 'content': message1[1] if message1[1] is not None else ''})
-            messages0.append({'role': 'user', 'content': prompt if prompt is not None else ''})
+            messages0.append({'role': 'user', 'content': prompt})
             responses = openai.ChatCompletion.create(
                 model=base_model,
                 messages=messages0,
@@ -3628,27 +3609,13 @@ def get_limited_prompt(instruction,
     stream_output = prompter.stream_output
     system_prompt = prompter.system_prompt

-    generate_prompt_type = prompt_type
-    external_handle_chat_conversation = False
-    if any(inference_server.startswith(x) for x in ['openai_chat', 'openai_azure_chat', 'vllm_chat']):
-        # Chat APIs do not take prompting
-        # Replicate does not need prompting if no chat history, but in general can take prompting
-        # if using prompter, prompter.system_prompt will already be filled with automatic (e.g. from llama-2),
-        # so if replicate final prompt with system prompt still correct because only access prompter.system_prompt that was already set
-        # below already true for openai,
-        # but not vllm by default as that can be any model and handled by FastChat API inside vLLM itself
-        generate_prompt_type = 'plain'
-        # Chat APIs don't handle chat history via single prompt, but in messages, assumed to be handled outside this function
-        chat_conversation = []
-        external_handle_chat_conversation = True
-
     # merge handles if chat_conversation is None
     history = []
     history = merge_chat_conversation_history(chat_conversation, history)
     history_to_context_func = functools.partial(history_to_context,
                                                 langchain_mode=langchain_mode,
                                                 add_chat_history_to_context=add_chat_history_to_context,
-                                                prompt_type=generate_prompt_type,
+                                                prompt_type=prompt_type,
                                                 prompt_dict=prompt_dict,
                                                 chat=chat,
                                                 model_max_length=model_max_length,
@@ -3781,9 +3748,6 @@ def get_limited_prompt(instruction,
     stream_output = False  # doesn't matter
     prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output,
                         system_prompt=system_prompt)
-    if prompt_type != generate_prompt_type:
-        # override just this attribute, keep system_prompt etc. from original prompt_type
-        prompter.prompt_type = generate_prompt_type

     data_point = dict(context=context, instruction=instruction, input=iinput)
     # handle promptA/promptB addition if really from history.
@@ -3796,8 +3760,7 @@ def get_limited_prompt(instruction,
     return prompt, \
         instruction, iinput, context, \
         num_prompt_tokens, max_new_tokens, num_prompt_tokens0, num_prompt_tokens_actual, \
-        chat_index, \
-        top_k_docs, one_doc_size
+        chat_index, top_k_docs, one_doc_size


 def get_docs_tokens(tokenizer, text_context_list=[], max_input_tokens=None):
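With the removals above, the openai_chat/vllm_chat branch of evaluate() no longer folds chat history into the message list (history is handled elsewhere), and vllm_chat now raises NotImplementedError. A minimal sketch of the resulting message construction against the pre-1.0 openai client; the generation kwargs and return handling are illustrative, not the full evaluate() logic:

# Sketch of the simplified openai_chat message construction shown in the diff.
# Assumes openai<1.0 (openai.ChatCompletion.create), matching the call in gen.py.
import openai


def chat_once(base_model, prompt, system_prompt='auto', max_new_tokens=256, temperature=0.1):
    if system_prompt in [None, 'None', 'auto']:
        openai_system_prompt = "You are a helpful assistant."
    else:
        openai_system_prompt = system_prompt

    messages0 = []
    if openai_system_prompt:
        messages0.append({"role": "system", "content": openai_system_prompt})
    # chat history is no longer appended here; only the single user prompt is sent
    messages0.append({'role': 'user', 'content': prompt})

    responses = openai.ChatCompletion.create(
        model=base_model,
        messages=messages0,
        max_tokens=max_new_tokens,
        temperature=temperature,
    )
    return responses['choices'][0]['message']['content']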
src/gpt_langchain.py
CHANGED

@@ -29,11 +29,10 @@ import yaml

 from joblib import delayed
 from langchain.callbacks import streaming_stdout
-from langchain.callbacks.base import Callbacks
 from langchain.embeddings import HuggingFaceInstructEmbeddings
 from langchain.llms.huggingface_pipeline import VALID_TASKS
 from langchain.llms.utils import enforce_stop_tokens
-from langchain.schema import LLMResult, Generation
+from langchain.schema import LLMResult, Generation
 from langchain.tools import PythonREPLTool
 from langchain.tools.json.tool import JsonSpec
 from tqdm import tqdm
@@ -945,10 +944,7 @@ class H2OReplicate(Replicate):
         assert self.tokenizer is not None
         from h2oai_pipeline import H2OTextGenerationPipeline
         prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
-        # Note Replicate handles the prompting of the specific model
-        data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
-        prompt = self.prompter.generate_prompt(data_point)
-
+        # Note Replicate handles the prompting of the specific model
         return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)

     def get_token_ids(self, text: str) -> List[int]:
@@ -957,98 +953,21 @@ class H2OReplicate(Replicate):
         # return _get_token_ids_default_method(text)


-class ExtraChat:
-    def get_messages(self, prompts):
-        from langchain.schema import AIMessage, SystemMessage, HumanMessage
-        messages = []
-        if self.system_prompt:
-            messages.append(SystemMessage(content=self.system_prompt))
-        if self.chat_conversation:
-            for messages1 in self.chat_conversation:
-                messages.append(HumanMessage(content=messages1[0] if messages1[0] is not None else ''))
-                messages.append(AIMessage(content=messages1[1] if messages1[1] is not None else ''))
-        assert len(prompts) == 1, "Not implemented"
-        messages.append(HumanMessage(content=prompts[0].text if prompts[0].text is not None else ''))
-        return [messages]
-
-
-class H2OChatOpenAI(ChatOpenAI, ExtraChat):
-    tokenizer: Any = None  # for vllm_chat
-    system_prompt: Any = None
-    chat_conversation: Any = []
-
+class H2OChatOpenAI(ChatOpenAI):
     @classmethod
     def _all_required_field_names(cls) -> Set:
         _all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
         _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
         return _all_required_field_names

-    def get_token_ids(self, text: str) -> List[int]:
-        if self.tokenizer is not None:
-            return self.tokenizer.encode(text)
-        else:
-            # OpenAI uses tiktoken
-            return super().get_token_ids(text)
-
-    def generate_prompt(
-            self,
-            prompts: List[PromptValue],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        prompt_messages = self.get_messages(prompts)
-        # prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
-
-    async def agenerate_prompt(
-            self,
-            prompts: List[PromptValue],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        prompt_messages = self.get_messages(prompts)
-        # prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(
-            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
-        )
-
-
-class H2OAzureChatOpenAI(AzureChatOpenAI, ExtraChat):
-    system_prompt: Any = None
-    chat_conversation: Any = []
+class H2OAzureChatOpenAI(AzureChatOpenAI):

     @classmethod
     def _all_required_field_names(cls) -> Set:
         _all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
         _all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
         return _all_required_field_names

-    def generate_prompt(
-            self,
-            prompts: List[PromptValue],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        prompt_messages = self.get_messages(prompts)
-        # prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
-
-    async def agenerate_prompt(
-            self,
-            prompts: List[PromptValue],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        prompt_messages = self.get_messages(prompts)
-        # prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(
-            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
-        )
-

 class H2OAzureOpenAI(AzureOpenAI):
     @classmethod
@@ -1133,7 +1052,7 @@ def get_llm(use_openai_model=False,
             if 'meta/llama' in model_string:
                 temperature = max(0.01, temperature if do_sample else 0)
             else:
-                temperature = ...
+                temperature = temperature if do_sample else 0
             gen_kwargs = dict(temperature=temperature,
                               seed=1234,
                               max_length=max_new_tokens,  # langchain
@@ -1149,7 +1068,8 @@ def get_llm(use_openai_model=False,
         if system_prompt:
             gen_kwargs.update(dict(system_prompt=system_prompt))

-        # replicate handles prompting
+        # replicate handles prompting, so avoid get_response() filter
+        prompter.prompt_type = 'plain'
         if stream_output:
             callbacks = [StreamingGradioCallbackHandler()]
             streamer = callbacks[0] if stream_output else None
@@ -1188,8 +1108,8 @@ def get_llm(use_openai_model=False,
         if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
             cls = H2OChatOpenAI
             # FIXME: Support context, iinput
-            if inf_type == 'vllm_chat':
-                kwargs_extra.update(dict(tokenizer=tokenizer))
+            # if inf_type == 'vllm_chat':
+            #     kwargs_extra.update(dict(tokenizer=tokenizer))
             openai_api_key = openai.api_key
         elif inf_type == 'openai_azure_chat':
             cls = H2OAzureChatOpenAI
@@ -1248,8 +1168,6 @@ def get_llm(use_openai_model=False,
                       logit_bias=None if inf_type == 'vllm' else {},
                       max_retries=6,
                       streaming=stream_output,
-                      system_prompt=system_prompt,
-                      # chat_conversation=chat_conversation,  # don't do here, not token aware
                       **kwargs_extra
                       )
         streamer = callbacks[0] if stream_output else None
@@ -3582,6 +3500,7 @@ Respond to prompt of Final Answer with your final high-quality bullet list answer
     prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
                         system_prompt=system_prompt)

+    use_docs_planned = False
     scores = []
     chain = None

@@ -3598,8 +3517,8 @@ Respond to prompt of Final Answer with your final high-quality bullet list answer
     missing_kwargs = [x for x in func_names if x not in sim_kwargs]
     assert not missing_kwargs, "Missing: %s" % missing_kwargs
     docs, chain, scores, \
-        num_docs_before_cut, \
-        use_llm_if_no_docs, top_k_docs_max_show = \
+        use_docs_planned, num_docs_before_cut, \
+        use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
         get_chain(**sim_kwargs)
     if document_subset in non_query_commands:
         formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
@@ -3620,21 +3539,23 @@ Respond to prompt of Final Answer with your final high-quality bullet list answer
         ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
         yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
         return
-    if ...
-        if not docs ...
-            ...
+    if not use_llm_if_no_docs:
+        if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
+                                             LangChainAction.SUMMARIZE_ALL.value,
+                                             LangChainAction.SUMMARIZE_REFINE.value]:
+            ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
+            extra = ''
+            yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
+            return
+        if not docs and not llm_mode:
+            ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
             extra = ''
             yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
             return

-    ...
+    if chain is None and not langchain_only_model:
+        # here if no docs at all and not HF type
+        # can only return if HF type
         return

     # context stuff similar to used in evaluate()
@@ -3735,8 +3656,7 @@ Respond to prompt of Final Answer with your final high-quality bullet list answer
         prompt = prompt_basic
     num_prompt_tokens = get_token_count(prompt, tokenizer)

-    if ...
-        # if no docs, then no sources to cite
+    if not use_docs_planned:
         ret = answer
         extra = ''
         yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
@@ -3895,7 +3815,8 @@ def get_chain(query=None,
     if text_context_list is None:
         text_context_list = []

-    # ...
+    # default value:
+    llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
     query_action = langchain_action == LangChainAction.QUERY.value
     summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
                                             LangChainAction.SUMMARIZE_ALL.value,
@@ -3927,6 +3848,8 @@ def get_chain(query=None,
     add_search_to_context &= len(docs_search) > 0
     top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))

+    if len(text_context_list) > 0:
+        llm_mode = False
     use_llm_if_no_docs = True

     from src.output_parser import H2OMRKLOutputParser
@@ -3954,9 +3877,10 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show

     if LangChainAgent.COLLECTION.value in langchain_agents:
         output_parser = H2OMRKLOutputParser()
@@ -3975,9 +3899,10 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show

     if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
         chain = create_python_agent(
@@ -3993,9 +3918,10 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show

     if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
         # FIXME: DATA
@@ -4012,9 +3938,10 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show

     if isinstance(document_choice, str):
         document_choice = [document_choice]
@@ -4044,9 +3971,10 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show

     if isinstance(document_choice, str):
         document_choice = [document_choice]
@@ -4057,7 +3985,7 @@ def get_chain(query=None,
     document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
     if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
         0].endswith(
-        ...
+        '.csv'):
         data_file = document_choice[0]
         if inference_server.startswith('openai_chat'):
             chain = create_csv_agent(
@@ -4078,9 +4006,19 @@ def get_chain(query=None,

         docs = []
         scores = []
+        use_docs_planned = False
         num_docs_before_cut = 0
         use_llm_if_no_docs = True
-        return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
+        return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
+
+    # determine whether use of context out of docs is planned
+    if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
+        if llm_mode:
+            use_docs_planned = False
+        else:
+            use_docs_planned = True
+    else:
+        use_docs_planned = True

     # https://github.com/hwchase17/langchain/issues/1946
     # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
@@ -4152,7 +4090,8 @@ def get_chain(query=None,
                      pre_prompt_query, prompt_query,
                      pre_prompt_summary, prompt_summary,
                      langchain_action,
-                     ...
+                     llm_mode,
+                     use_docs_planned,
                      auto_reduce_chunks,
                      got_db_docs,
                      add_search_to_context)
@@ -4160,242 +4099,239 @@ def get_chain(query=None,
|
|
| 4160 |
max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
| 4161 |
model_name=model_name, max_new_tokens=max_new_tokens)
|
| 4162 |
|
| 4163 |
-
if
|
| 4164 |
-
|
| 4165 |
-
|
| 4166 |
-
|
| 4167 |
-
|
| 4168 |
-
|
| 4169 |
-
|
| 4170 |
-
|
| 4171 |
-
|
| 4172 |
-
|
| 4173 |
-
|
| 4174 |
-
|
| 4175 |
-
|
| 4176 |
-
|
| 4177 |
-
|
| 4178 |
-
|
| 4179 |
-
|
| 4180 |
-
|
| 4181 |
-
|
| 4182 |
-
|
| 4183 |
-
|
| 4184 |
-
|
| 4185 |
-
{"filter": {"chunk_id": {"$
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4186 |
else:
|
| 4187 |
-
|
| 4188 |
-
|
| 4189 |
-
|
| 4190 |
-
filter_kwargs = {}
|
| 4191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4192 |
or_filter = [
|
| 4193 |
-
{"
|
| 4194 |
-
|
|
|
|
| 4195 |
for x in document_choice]
|
| 4196 |
filter_kwargs = dict(filter={"$or": or_filter})
|
| 4197 |
-
|
| 4198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4199 |
one_filter = \
|
| 4200 |
-
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
|
| 4201 |
-
|
| 4202 |
-
|
| 4203 |
-
"$eq": -1}}
|
| 4204 |
for x in document_choice][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4205 |
|
| 4206 |
-
|
| 4207 |
-
|
| 4208 |
-
|
| 4209 |
-
|
| 4210 |
-
|
| 4211 |
-
0] == DocumentChoice.ALL.value:
|
| 4212 |
-
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
| 4213 |
-
{"filter": {"chunk_id": {"$eq": -1}}}
|
| 4214 |
-
filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
|
| 4215 |
-
elif len(document_choice) >= 2:
|
| 4216 |
-
if document_choice[0] == DocumentChoice.ALL.value:
|
| 4217 |
-
document_choice = document_choice[1:]
|
| 4218 |
-
or_filter = [
|
| 4219 |
-
{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
| 4220 |
-
"chunk_id": {
|
| 4221 |
-
"$eq": -1}}
|
| 4222 |
-
for x in document_choice]
|
| 4223 |
-
filter_kwargs = dict(filter={"$or": or_filter})
|
| 4224 |
-
or_filter_backup = [
|
| 4225 |
-
{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
| 4226 |
-
for x in document_choice]
|
| 4227 |
-
filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
|
| 4228 |
-
elif len(document_choice) == 1:
|
| 4229 |
-
# degenerate UX bug in chroma
|
| 4230 |
-
one_filter = \
|
| 4231 |
-
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
| 4232 |
-
"chunk_id": {
|
| 4233 |
-
"$eq": -1}}
|
| 4234 |
-
for x in document_choice][0]
|
| 4235 |
-
filter_kwargs = dict(filter=one_filter)
|
| 4236 |
-
one_filter_backup = \
|
| 4237 |
-
[{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
| 4238 |
-
for x in document_choice][0]
|
| 4239 |
-
filter_kwargs_backup = dict(filter=one_filter_backup)
|
| 4240 |
-
else:
|
| 4241 |
-
# shouldn't reach
|
| 4242 |
-
filter_kwargs = {}
|
| 4243 |
-
filter_kwargs_backup = {}
|
| 4244 |
-
|
| 4245 |
-
if document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
|
| 4246 |
-
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
|
| 4247 |
-
text_context_list=text_context_list)
|
| 4248 |
-
if len(db_documents) == 0 and filter_kwargs_backup:
|
| 4249 |
-
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
|
| 4250 |
text_context_list=text_context_list)
|
| 4251 |
-
|
| 4252 |
-
|
| 4253 |
-
|
| 4254 |
-
|
| 4255 |
-
|
| 4256 |
-
|
| 4257 |
-
|
| 4258 |
-
|
| 4259 |
-
|
| 4260 |
-
|
| 4261 |
-
|
| 4262 |
-
|
| 4263 |
-
|
| 4264 |
-
|
| 4265 |
-
|
| 4266 |
-
|
| 4267 |
-
|
| 4268 |
-
|
| 4269 |
-
|
| 4270 |
-
|
| 4271 |
-
|
| 4272 |
-
|
| 4273 |
-
]
|
| 4274 |
-
if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
|
| 4275 |
-
# old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
|
| 4276 |
-
# just do again and relax filter, let summarize operate on actual chunks if nothing else
|
| 4277 |
docs_with_score2 = [x for hx, cx, x in
|
| 4278 |
-
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
|
| 4279 |
-
|
| 4280 |
]
|
| 4281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4282 |
|
| 4283 |
-
|
| 4284 |
-
|
| 4285 |
-
|
| 4286 |
-
|
| 4287 |
-
|
| 4288 |
-
|
| 4289 |
-
|
| 4290 |
-
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
|
| 4291 |
-
text_context_list=text_context_list,
|
| 4292 |
-
verbose=verbose)
|
| 4293 |
-
if len(docs_with_score) == 0 and filter_kwargs_backup:
|
| 4294 |
-
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
|
| 4295 |
-
db_type,
|
| 4296 |
text_context_list=text_context_list,
|
| 4297 |
verbose=verbose)
|
| 4298 |
-
|
| 4299 |
-
|
| 4300 |
-
|
| 4301 |
-
|
| 4302 |
-
|
| 4303 |
-
|
| 4304 |
-
|
| 4305 |
-
|
| 4306 |
-
|
| 4307 |
-
|
| 4308 |
-
|
| 4309 |
-
|
| 4310 |
-
|
| 4311 |
-
|
| 4312 |
-
|
| 4313 |
-
|
| 4314 |
-
|
| 4315 |
-
|
| 4316 |
-
|
| 4317 |
-
|
| 4318 |
-
|
| 4319 |
-
|
| 4320 |
-
|
| 4321 |
-
|
| 4322 |
-
|
| 4323 |
-
|
| 4324 |
-
|
| 4325 |
-
|
| 4326 |
-
|
| 4327 |
-
|
| 4328 |
-
|
| 4329 |
-
|
| 4330 |
-
|
| 4331 |
-
|
| 4332 |
-
|
| 4333 |
-
|
| 4334 |
-
|
| 4335 |
-
|
| 4336 |
-
|
| 4337 |
-
|
| 4338 |
-
|
| 4339 |
-
|
| 4340 |
-
|
| 4341 |
-
|
| 4342 |
-
|
| 4343 |
-
assert external_handle_chat_conversation, "Should be handling only externally"
|
| 4344 |
-
llm.chat_conversation = chat_conversation[chat_index:]
|
| 4345 |
-
if hasattr(llm, 'context'):
|
| 4346 |
-
llm.context = context
|
| 4347 |
-
if hasattr(llm, 'iinput'):
|
| 4348 |
-
llm.iinput = iinput
|
| 4349 |
-
# avoid craziness
|
| 4350 |
-
if 0 < top_k_docs_trial < max_chunks:
|
| 4351 |
# avoid craziness
|
| 4352 |
-
if
|
| 4353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4354 |
else:
|
| 4355 |
-
|
| 4356 |
-
elif top_k_docs_trial >= max_chunks:
|
| 4357 |
-
top_k_docs = max_chunks
|
| 4358 |
-
if top_k_docs > 0:
|
| 4359 |
-
docs_with_score = docs_with_score[:top_k_docs]
|
| 4360 |
-
elif one_doc_size is not None:
|
| 4361 |
-
docs_with_score = [docs_with_score[0][:one_doc_size]]
|
| 4362 |
else:
|
| 4363 |
-
|
| 4364 |
-
|
| 4365 |
-
|
| 4366 |
-
|
| 4367 |
-
|
| 4368 |
-
|
| 4369 |
-
text_context_list=[x[0].page_content for x in docs_with_score],
|
| 4370 |
-
max_input_tokens=total_tokens_for_docs)
|
| 4371 |
|
| 4372 |
-
|
| 4373 |
-
|
| 4374 |
-
# put most relevant chunks closest to question,
|
| 4375 |
-
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
| 4376 |
-
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
| 4377 |
-
if docs_ordering_type in ['best_first']:
|
| 4378 |
-
pass
|
| 4379 |
-
elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
|
| 4380 |
-
docs_with_score.reverse()
|
| 4381 |
-
elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
|
| 4382 |
-
docs_with_score = reverse_ucurve_list(docs_with_score)
|
| 4383 |
-
else:
|
| 4384 |
-
raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
|
| 4385 |
|
| 4386 |
-
|
| 4387 |
-
|
| 4388 |
-
|
| 4389 |
-
|
| 4390 |
-
|
| 4391 |
-
|
| 4392 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4393 |
|
| 4394 |
-
|
|
|
|
|
|
|
| 4395 |
|
| 4396 |
if document_subset in non_query_commands:
|
| 4397 |
-
# no LLM use
|
| 4398 |
-
return docs, None, [], num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
| 4399 |
|
| 4400 |
# FIXME: WIP
|
| 4401 |
common_words_file = "data/NGSL_1.2_stats.csv.zip"
|
|
@@ -4413,6 +4349,7 @@ def get_chain(query=None,
|
|
| 4413 |
|
| 4414 |
if len(docs) == 0:
|
| 4415 |
# avoid context == in prompt then
|
|
|
|
| 4416 |
template = template_if_no_docs
|
| 4417 |
|
| 4418 |
got_db_docs = got_db_docs and len(text_context_list) < len(docs)
|
|
@@ -4424,7 +4361,8 @@ def get_chain(query=None,
|
|
| 4424 |
pre_prompt_query, prompt_query,
|
| 4425 |
pre_prompt_summary, prompt_summary,
|
| 4426 |
langchain_action,
|
| 4427 |
-
|
|
|
|
| 4428 |
auto_reduce_chunks,
|
| 4429 |
got_db_docs,
|
| 4430 |
add_search_to_context)
|
|
@@ -4442,7 +4380,10 @@ def get_chain(query=None,
|
|
| 4442 |
else:
|
| 4443 |
# only if use_openai_model = True, unused normally except in testing
|
| 4444 |
chain = load_qa_with_sources_chain(llm)
|
| 4445 |
-
|
|
|
|
|
|
|
|
|
|
| 4446 |
target = wrapped_partial(chain, chain_kwargs)
|
| 4447 |
elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
| 4448 |
LangChainAction.SUMMARIZE_REFINE,
|
|
@@ -4486,7 +4427,7 @@ def get_chain(query=None,
|
|
| 4486 |
else:
|
| 4487 |
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
| 4488 |
|
| 4489 |
-
return docs, target, scores, num_docs_before_cut, use_llm_if_no_docs, top_k_docs_max_show
|
| 4490 |
|
| 4491 |
|
| 4492 |
def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
|
|
@@ -4532,11 +4473,11 @@ def get_tokenizer(db=None, llm=None, tokenizer=None, inference_server=None, use_
|
|
| 4532 |
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
| 4533 |
# more accurate
|
| 4534 |
return llm.pipeline.tokenizer
|
| 4535 |
-
elif hasattr(llm, 'tokenizer')
|
| 4536 |
# e.g. TGI client mode etc.
|
| 4537 |
return llm.tokenizer
|
| 4538 |
elif inference_server in ['openai', 'openai_chat', 'openai_azure',
|
| 4539 |
-
'openai_azure_chat']
|
| 4540 |
return tokenizer
|
| 4541 |
elif isinstance(tokenizer, FakeTokenizer):
|
| 4542 |
return tokenizer
|
|
@@ -4559,7 +4500,8 @@ def get_template(query, iinput,
|
|
| 4559 |
pre_prompt_query, prompt_query,
|
| 4560 |
pre_prompt_summary, prompt_summary,
|
| 4561 |
langchain_action,
|
| 4562 |
-
|
|
|
|
| 4563 |
auto_reduce_chunks,
|
| 4564 |
got_db_docs,
|
| 4565 |
add_search_to_context):
|
|
@@ -4581,7 +4523,7 @@ def get_template(query, iinput,
|
|
| 4581 |
if langchain_action == LangChainAction.QUERY.value:
|
| 4582 |
if iinput:
|
| 4583 |
query = "%s\n%s" % (query, iinput)
|
| 4584 |
-
if not
|
| 4585 |
template_if_no_docs = template = """{context}{question}"""
|
| 4586 |
else:
|
| 4587 |
template = """%s
|
|
|
|
| 29 |
|
| 30 |
from joblib import delayed
|
| 31 |
from langchain.callbacks import streaming_stdout
|
|
|
|
| 32 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 33 |
from langchain.llms.huggingface_pipeline import VALID_TASKS
|
| 34 |
from langchain.llms.utils import enforce_stop_tokens
|
| 35 |
+
from langchain.schema import LLMResult, Generation
|
| 36 |
from langchain.tools import PythonREPLTool
|
| 37 |
from langchain.tools.json.tool import JsonSpec
|
| 38 |
from tqdm import tqdm
|
|
|
|
| 944 |
assert self.tokenizer is not None
|
| 945 |
from h2oai_pipeline import H2OTextGenerationPipeline
|
| 946 |
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
| 947 |
+
# Note Replicate handles the prompting of the specific model
|
|
|
|
|
|
|
|
|
|
| 948 |
return super()._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
|
| 949 |
|
| 950 |
def get_token_ids(self, text: str) -> List[int]:
|
|
|
|
| 953 |
# return _get_token_ids_default_method(text)
|
| 954 |
|
| 955 |
|
| 956 |
+
class H2OChatOpenAI(ChatOpenAI):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 957 |
@classmethod
|
| 958 |
def _all_required_field_names(cls) -> Set:
|
| 959 |
_all_required_field_names = super(ChatOpenAI, cls)._all_required_field_names()
|
| 960 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
| 961 |
return _all_required_field_names
|
| 962 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 963 |
|
| 964 |
+
class H2OAzureChatOpenAI(AzureChatOpenAI):
|
| 965 |
@classmethod
|
| 966 |
def _all_required_field_names(cls) -> Set:
|
| 967 |
_all_required_field_names = super(AzureChatOpenAI, cls)._all_required_field_names()
|
| 968 |
_all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty', 'logit_bias'})
|
| 969 |
return _all_required_field_names
|
| 970 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 971 |
|
| 972 |
class H2OAzureOpenAI(AzureOpenAI):
|
| 973 |
@classmethod
|
|
|
|
| 1052 |
if 'meta/llama' in model_string:
|
| 1053 |
temperature = max(0.01, temperature if do_sample else 0)
|
| 1054 |
else:
|
| 1055 |
+
temperature =temperature if do_sample else 0
|
| 1056 |
gen_kwargs = dict(temperature=temperature,
|
| 1057 |
seed=1234,
|
| 1058 |
max_length=max_new_tokens, # langchain
|
|
|
|
| 1068 |
if system_prompt:
|
| 1069 |
gen_kwargs.update(dict(system_prompt=system_prompt))
|
| 1070 |
|
| 1071 |
+
# replicate handles prompting, so avoid get_response() filter
|
| 1072 |
+
prompter.prompt_type = 'plain'
|
| 1073 |
if stream_output:
|
| 1074 |
callbacks = [StreamingGradioCallbackHandler()]
|
| 1075 |
streamer = callbacks[0] if stream_output else None
|
|
|
|
| 1108 |
if inf_type == 'openai_chat' or inf_type == 'vllm_chat':
|
| 1109 |
cls = H2OChatOpenAI
|
| 1110 |
# FIXME: Support context, iinput
|
| 1111 |
+
# if inf_type == 'vllm_chat':
|
| 1112 |
+
# kwargs_extra.update(dict(tokenizer=tokenizer))
|
| 1113 |
openai_api_key = openai.api_key
|
| 1114 |
elif inf_type == 'openai_azure_chat':
|
| 1115 |
cls = H2OAzureChatOpenAI
|
|
|
|
| 1168 |
logit_bias=None if inf_type == 'vllm' else {},
|
| 1169 |
max_retries=6,
|
| 1170 |
streaming=stream_output,
|
|
|
|
|
|
|
| 1171 |
**kwargs_extra
|
| 1172 |
)
|
| 1173 |
streamer = callbacks[0] if stream_output else None
|
|
|
|
| 3500 |
prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=chat, stream_output=stream_output,
|
| 3501 |
system_prompt=system_prompt)
|
| 3502 |
|
| 3503 |
+
use_docs_planned = False
|
| 3504 |
scores = []
|
| 3505 |
chain = None
|
| 3506 |
|
|
|
|
| 3517 |
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
|
| 3518 |
assert not missing_kwargs, "Missing: %s" % missing_kwargs
|
| 3519 |
docs, chain, scores, \
|
| 3520 |
+
use_docs_planned, num_docs_before_cut, \
|
| 3521 |
+
use_llm_if_no_docs, llm_mode, top_k_docs_max_show = \
|
| 3522 |
get_chain(**sim_kwargs)
|
| 3523 |
if document_subset in non_query_commands:
|
| 3524 |
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
|
|
|
|
| 3539 |
ret, extra = get_sources_answer(*get_answer_args, **get_answer_kwargs)
|
| 3540 |
yield dict(prompt=prompt_basic, response=formatted_doc_chunks, sources=extra, num_prompt_tokens=0)
|
| 3541 |
return
|
| 3542 |
+
if not use_llm_if_no_docs:
|
| 3543 |
+
if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
| 3544 |
+
LangChainAction.SUMMARIZE_ALL.value,
|
| 3545 |
+
LangChainAction.SUMMARIZE_REFINE.value]:
|
| 3546 |
+
ret = 'No relevant documents to summarize.' if num_docs_before_cut else 'No documents to summarize.'
|
| 3547 |
+
extra = ''
|
| 3548 |
+
yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
|
| 3549 |
+
return
|
| 3550 |
+
if not docs and not llm_mode:
|
| 3551 |
+
ret = 'No relevant documents to query (for chatting with LLM, pick Resources->Collections->LLM).' if num_docs_before_cut else 'No documents to query (for chatting with LLM, pick Resources->Collections->LLM).'
|
| 3552 |
extra = ''
|
| 3553 |
yield dict(prompt=prompt_basic, response=ret, sources=extra, num_prompt_tokens=0)
|
| 3554 |
return
|
| 3555 |
|
| 3556 |
+
if chain is None and not langchain_only_model:
|
| 3557 |
+
# here if no docs at all and not HF type
|
| 3558 |
+
# can only return if HF type
|
| 3559 |
return
|
| 3560 |
|
| 3561 |
# context stuff similar to used in evaluate()
|
|
|
|
| 3656 |
prompt = prompt_basic
|
| 3657 |
num_prompt_tokens = get_token_count(prompt, tokenizer)
|
| 3658 |
|
| 3659 |
+
if not use_docs_planned:
|
|
|
|
| 3660 |
ret = answer
|
| 3661 |
extra = ''
|
| 3662 |
yield dict(prompt=prompt, response=ret, sources=extra, num_prompt_tokens=num_prompt_tokens)
|
|
|
|
| 3815 |
if text_context_list is None:
|
| 3816 |
text_context_list = []
|
| 3817 |
|
| 3818 |
+
# default value:
|
| 3819 |
+
llm_mode = langchain_mode in ['Disabled', 'LLM'] and len(text_context_list) == 0
|
| 3820 |
query_action = langchain_action == LangChainAction.QUERY.value
|
| 3821 |
summarize_action = langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
|
| 3822 |
LangChainAction.SUMMARIZE_ALL.value,
|
|
|
|
| 3848 |
add_search_to_context &= len(docs_search) > 0
|
| 3849 |
top_k_docs_max_show = max(top_k_docs_max_show, len(docs_search))
|
| 3850 |
|
| 3851 |
+
if len(text_context_list) > 0:
|
| 3852 |
+
llm_mode = False
|
| 3853 |
use_llm_if_no_docs = True
|
| 3854 |
|
| 3855 |
from src.output_parser import H2OMRKLOutputParser
|
|
|
|
| 3877 |
|
| 3878 |
docs = []
|
| 3879 |
scores = []
|
| 3880 |
+
use_docs_planned = False
|
| 3881 |
num_docs_before_cut = 0
|
| 3882 |
use_llm_if_no_docs = True
|
| 3883 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 3884 |
|
| 3885 |
if LangChainAgent.COLLECTION.value in langchain_agents:
|
| 3886 |
output_parser = H2OMRKLOutputParser()
|
|
|
|
| 3899 |
|
| 3900 |
docs = []
|
| 3901 |
scores = []
|
| 3902 |
+
use_docs_planned = False
|
| 3903 |
num_docs_before_cut = 0
|
| 3904 |
use_llm_if_no_docs = True
|
| 3905 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 3906 |
|
| 3907 |
if LangChainAgent.PYTHON.value in langchain_agents and inference_server.startswith('openai'):
|
| 3908 |
chain = create_python_agent(
|
|
|
|
| 3918 |
|
| 3919 |
docs = []
|
| 3920 |
scores = []
|
| 3921 |
+
use_docs_planned = False
|
| 3922 |
num_docs_before_cut = 0
|
| 3923 |
use_llm_if_no_docs = True
|
| 3924 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 3925 |
|
| 3926 |
if LangChainAgent.PANDAS.value in langchain_agents and inference_server.startswith('openai_chat'):
|
| 3927 |
# FIXME: DATA
|
|
|
|
| 3938 |
|
| 3939 |
docs = []
|
| 3940 |
scores = []
|
| 3941 |
+
use_docs_planned = False
|
| 3942 |
num_docs_before_cut = 0
|
| 3943 |
use_llm_if_no_docs = True
|
| 3944 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 3945 |
|
| 3946 |
if isinstance(document_choice, str):
|
| 3947 |
document_choice = [document_choice]
|
|
|
|
| 3971 |
|
| 3972 |
docs = []
|
| 3973 |
scores = []
|
| 3974 |
+
use_docs_planned = False
|
| 3975 |
num_docs_before_cut = 0
|
| 3976 |
use_llm_if_no_docs = True
|
| 3977 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 3978 |
|
| 3979 |
if isinstance(document_choice, str):
|
| 3980 |
document_choice = [document_choice]
|
|
|
|
| 3985 |
document_choice_agent = [x for x in document_choice_agent if x.endswith('.csv')]
|
| 3986 |
if LangChainAgent.CSV.value in langchain_agents and len(document_choice_agent) == 1 and document_choice_agent[
|
| 3987 |
0].endswith(
|
| 3988 |
+
'.csv'):
|
| 3989 |
data_file = document_choice[0]
|
| 3990 |
if inference_server.startswith('openai_chat'):
|
| 3991 |
chain = create_csv_agent(
|
|
|
|
| 4006 |
|
| 4007 |
docs = []
|
| 4008 |
scores = []
|
| 4009 |
+
use_docs_planned = False
|
| 4010 |
num_docs_before_cut = 0
|
| 4011 |
use_llm_if_no_docs = True
|
| 4012 |
+
return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
|
| 4013 |
+
|
| 4014 |
+
# determine whether use of context out of docs is planned
|
| 4015 |
+
if not use_openai_model and prompt_type not in ['plain'] or langchain_only_model:
|
| 4016 |
+
if llm_mode:
|
| 4017 |
+
use_docs_planned = False
|
| 4018 |
+
else:
|
| 4019 |
+
use_docs_planned = True
|
| 4020 |
+
else:
|
| 4021 |
+
use_docs_planned = True
|
| 4022 |
|
| 4023 |
# https://github.com/hwchase17/langchain/issues/1946
|
| 4024 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
|
|
|
| 4090 |
pre_prompt_query, prompt_query,
|
| 4091 |
pre_prompt_summary, prompt_summary,
|
| 4092 |
langchain_action,
|
| 4093 |
+
llm_mode,
|
| 4094 |
+
use_docs_planned,
|
| 4095 |
auto_reduce_chunks,
|
| 4096 |
got_db_docs,
|
| 4097 |
add_search_to_context)
|
|
|
|
| 4099 |
max_input_tokens = get_max_input_tokens(llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
| 4100 |
model_name=model_name, max_new_tokens=max_new_tokens)
|
| 4101 |
|
| 4102 |
+
if (db or text_context_list) and use_docs_planned:
|
| 4103 |
+
if hasattr(db, '_persist_directory'):
|
| 4104 |
+
lock_file = get_db_lock_file(db, lock_type='sim')
|
| 4105 |
+
else:
|
| 4106 |
+
base_path = 'locks'
|
| 4107 |
+
base_path = makedirs(base_path, exist_ok=True, tmp_ok=True, use_base=True)
|
| 4108 |
+
name_path = "sim.lock"
|
| 4109 |
+
lock_file = os.path.join(base_path, name_path)
|
| 4110 |
+
|
| 4111 |
+
if not (isinstance(db, Chroma) or isinstance(db, ChromaMig) or ChromaMig.__name__ in str(db)):
|
| 4112 |
+
# only chroma supports filtering
|
| 4113 |
+
filter_kwargs = {}
|
| 4114 |
+
filter_kwargs_backup = {}
|
| 4115 |
+
else:
|
| 4116 |
+
import logging
|
| 4117 |
+
logging.getLogger("chromadb").setLevel(logging.ERROR)
|
| 4118 |
+
assert document_choice is not None, "Document choice was None"
|
| 4119 |
+
if isinstance(db, Chroma):
|
| 4120 |
+
filter_kwargs_backup = {} # shouldn't ever need backup
|
| 4121 |
+
# chroma >= 0.4
|
| 4122 |
+
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
|
| 4123 |
+
0] == DocumentChoice.ALL.value:
|
| 4124 |
+
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
| 4125 |
+
{"filter": {"chunk_id": {"$eq": -1}}}
|
| 4126 |
+
else:
|
| 4127 |
+
if document_choice[0] == DocumentChoice.ALL.value:
|
| 4128 |
+
document_choice = document_choice[1:]
|
| 4129 |
+
if len(document_choice) == 0:
|
| 4130 |
+
filter_kwargs = {}
|
| 4131 |
+
elif len(document_choice) > 1:
|
| 4132 |
+
or_filter = [
|
| 4133 |
+
{"$and": [dict(source={"$eq": x}), dict(chunk_id={"$gte": 0})]} if query_action else {
|
| 4134 |
+
"$and": [dict(source={"$eq": x}), dict(chunk_id={"$eq": -1})]}
|
| 4135 |
+
for x in document_choice]
|
| 4136 |
+
filter_kwargs = dict(filter={"$or": or_filter})
|
| 4137 |
+
else:
|
| 4138 |
+
# still chromadb UX bug, have to do different thing for 1 vs. 2+ docs when doing filter
|
| 4139 |
+
one_filter = \
|
| 4140 |
+
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {
|
| 4141 |
+
"source": {"$eq": x},
|
| 4142 |
+
"chunk_id": {
|
| 4143 |
+
"$eq": -1}}
|
| 4144 |
+
for x in document_choice][0]
|
| 4145 |
+
|
| 4146 |
+
filter_kwargs = dict(filter={"$and": [dict(source=one_filter['source']),
|
| 4147 |
+
dict(chunk_id=one_filter['chunk_id'])]})
|
| 4148 |
else:
|
| 4149 |
+
# migration for chroma < 0.4
|
| 4150 |
+
if len(document_choice) == 0 or len(document_choice) >= 1 and document_choice[
|
| 4151 |
+
0] == DocumentChoice.ALL.value:
|
| 4152 |
+
filter_kwargs = {"filter": {"chunk_id": {"$gte": 0}}} if query_action else \
|
| 4153 |
+
{"filter": {"chunk_id": {"$eq": -1}}}
|
| 4154 |
+
filter_kwargs_backup = {"filter": {"chunk_id": {"$gte": 0}}}
|
| 4155 |
+
elif len(document_choice) >= 2:
|
| 4156 |
+
if document_choice[0] == DocumentChoice.ALL.value:
|
| 4157 |
+
document_choice = document_choice[1:]
|
| 4158 |
or_filter = [
|
| 4159 |
+
{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
| 4160 |
+
"chunk_id": {
|
| 4161 |
+
"$eq": -1}}
|
| 4162 |
for x in document_choice]
|
| 4163 |
filter_kwargs = dict(filter={"$or": or_filter})
|
| 4164 |
+
or_filter_backup = [
|
| 4165 |
+
{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
| 4166 |
+
for x in document_choice]
|
| 4167 |
+
filter_kwargs_backup = dict(filter={"$or": or_filter_backup})
|
| 4168 |
+
elif len(document_choice) == 1:
|
| 4169 |
+
# degenerate UX bug in chroma
|
| 4170 |
one_filter = \
|
| 4171 |
+
[{"source": {"$eq": x}, "chunk_id": {"$gte": 0}} if query_action else {"source": {"$eq": x},
|
| 4172 |
+
"chunk_id": {
|
| 4173 |
+
"$eq": -1}}
|
|
|
|
| 4174 |
for x in document_choice][0]
|
| 4175 |
+
filter_kwargs = dict(filter=one_filter)
|
| 4176 |
+
one_filter_backup = \
|
| 4177 |
+
[{"source": {"$eq": x}} if query_action else {"source": {"$eq": x}}
|
| 4178 |
+
for x in document_choice][0]
|
| 4179 |
+
filter_kwargs_backup = dict(filter=one_filter_backup)
|
| 4180 |
+
else:
|
| 4181 |
+
# shouldn't reach
|
| 4182 |
+
filter_kwargs = {}
|
| 4183 |
+
filter_kwargs_backup = {}
|
| 4184 |
|
| 4185 |
+
if llm_mode:
|
| 4186 |
+
docs = []
|
| 4187 |
+
scores = []
|
| 4188 |
+
elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
|
| 4189 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4190 |
text_context_list=text_context_list)
|
| 4191 |
+
if len(db_documents) == 0 and filter_kwargs_backup:
|
| 4192 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs_backup,
|
| 4193 |
+
text_context_list=text_context_list)
|
| 4194 |
+
|
| 4195 |
+
if top_k_docs == -1:
|
| 4196 |
+
top_k_docs = len(db_documents)
|
| 4197 |
+
# similar to langchain's chroma's _results_to_docs_and_scores
|
| 4198 |
+
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
| 4199 |
+
for result in zip(db_documents, db_metadatas)]
|
| 4200 |
+
# set in metadata original order of docs
|
| 4201 |
+
[x[0].metadata.update(orig_index=ii) for ii, x in enumerate(docs_with_score)]
|
| 4202 |
+
|
| 4203 |
+
# order documents
|
| 4204 |
+
doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas]
|
| 4205 |
+
if query_action:
|
| 4206 |
+
doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas]
|
| 4207 |
+
docs_with_score2 = [x for hx, cx, x in
|
| 4208 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
| 4209 |
+
if cx >= 0]
|
| 4210 |
+
else:
|
| 4211 |
+
assert summarize_action
|
| 4212 |
+
doc_chunk_ids = [x.get('chunk_id', -1) for x in db_metadatas]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4213 |
docs_with_score2 = [x for hx, cx, x in
|
| 4214 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1]))
|
| 4215 |
+
if cx == -1
|
| 4216 |
]
|
| 4217 |
+
if len(docs_with_score2) == 0 and len(docs_with_score) > 0:
|
| 4218 |
+
# old database without chunk_id, migration added 0 but didn't make -1 as that would be expensive
|
| 4219 |
+
# just do again and relax filter, let summarize operate on actual chunks if nothing else
|
| 4220 |
+
docs_with_score2 = [x for hx, cx, x in
|
| 4221 |
+
sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score),
|
| 4222 |
+
key=lambda x: (x[0], x[1]))
|
| 4223 |
+
]
|
| 4224 |
+
docs_with_score = docs_with_score2
|
| 4225 |
|
| 4226 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
| 4227 |
+
docs = [x[0] for x in docs_with_score]
|
| 4228 |
+
scores = [x[1] for x in docs_with_score]
|
| 4229 |
+
num_docs_before_cut = len(docs)
|
| 4230 |
+
else:
|
| 4231 |
+
with filelock.FileLock(lock_file):
|
| 4232 |
+
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs, db, db_type,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4233 |
text_context_list=text_context_list,
|
| 4234 |
verbose=verbose)
|
| 4235 |
+
if len(docs_with_score) == 0 and filter_kwargs_backup:
|
| 4236 |
+
docs_with_score, got_db_docs = get_docs_with_score(query, k_db, filter_kwargs_backup, db,
|
| 4237 |
+
db_type,
|
| 4238 |
+
text_context_list=text_context_list,
|
| 4239 |
+
verbose=verbose)
|
| 4240 |
+
|
| 4241 |
+
tokenizer = get_tokenizer(db=db, llm=llm, tokenizer=tokenizer, inference_server=inference_server,
|
| 4242 |
+
use_openai_model=use_openai_model,
|
| 4243 |
+
db_type=db_type)
|
| 4244 |
+
# NOTE: if map_reduce, then no need to auto reduce chunks
|
| 4245 |
+
if query_action and (top_k_docs == -1 or auto_reduce_chunks):
|
| 4246 |
+
top_k_docs_tokenize = 100
|
| 4247 |
+
docs_with_score = docs_with_score[:top_k_docs_tokenize]
|
| 4248 |
+
|
| 4249 |
+
prompt_no_docs = template.format(context='', question=query)
|
| 4250 |
+
|
| 4251 |
+
model_max_length = tokenizer.model_max_length
|
| 4252 |
+
chat = True # FIXME?
|
| 4253 |
+
|
| 4254 |
+
# first docs_with_score are most important with highest score
|
| 4255 |
+
full_prompt, \
|
| 4256 |
+
instruction, iinput, context, \
|
| 4257 |
+
num_prompt_tokens, max_new_tokens, \
|
| 4258 |
+
num_prompt_tokens0, num_prompt_tokens_actual, \
|
| 4259 |
+
chat_index, top_k_docs_trial, one_doc_size = \
|
| 4260 |
+
get_limited_prompt(prompt_no_docs,
|
| 4261 |
+
iinput,
|
| 4262 |
+
tokenizer,
|
| 4263 |
+
prompter=prompter,
|
| 4264 |
+
inference_server=inference_server,
|
| 4265 |
+
prompt_type=prompt_type,
|
| 4266 |
+
prompt_dict=prompt_dict,
|
| 4267 |
+
chat=chat,
|
| 4268 |
+
max_new_tokens=max_new_tokens,
|
| 4269 |
+
system_prompt=system_prompt,
|
| 4270 |
+
context=context,
|
| 4271 |
+
chat_conversation=chat_conversation,
|
| 4272 |
+
text_context_list=[x[0].page_content for x in docs_with_score],
|
| 4273 |
+
keep_sources_in_context=keep_sources_in_context,
|
| 4274 |
+
model_max_length=model_max_length,
|
| 4275 |
+
memory_restriction_level=memory_restriction_level,
|
| 4276 |
+
langchain_mode=langchain_mode,
|
| 4277 |
+
add_chat_history_to_context=add_chat_history_to_context,
|
| 4278 |
+
min_max_new_tokens=min_max_new_tokens,
|
| 4279 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4280 |               # avoid craziness
| 4281 | +             if 0 < top_k_docs_trial < max_chunks:
| 4282 | +                 # avoid craziness
| 4283 | +                 if top_k_docs == -1:
| 4284 | +                     top_k_docs = top_k_docs_trial
| 4285 | +                 else:
| 4286 | +                     top_k_docs = min(top_k_docs, top_k_docs_trial)
| 4287 | +             elif top_k_docs_trial >= max_chunks:
| 4288 | +                 top_k_docs = max_chunks
| 4289 | +             if top_k_docs > 0:
| 4290 | +                 docs_with_score = docs_with_score[:top_k_docs]
| 4291 | +             elif one_doc_size is not None:
| 4292 | +                 docs_with_score = [docs_with_score[0][:one_doc_size]]
| 4293 |               else:
| 4294 | +                 docs_with_score = []
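
The get_limited_prompt call above is what produces top_k_docs_trial and one_doc_size: it counts tokens for the prompt scaffolding plus each retrieved chunk and reports how many chunks actually fit the context window. The real helper also budgets chat history, system prompt and generation length; the following is only a rough sketch of the core idea, with count_tokens standing in for a real tokenizer call:

def fit_docs_to_budget(doc_texts, count_tokens, model_max_length,
                       prompt_overhead_tokens, max_new_tokens):
    # budget left for retrieved chunks after reserving room for the answer
    budget = model_max_length - max_new_tokens - prompt_overhead_tokens
    kept, used, one_doc_size = 0, 0, None
    for text in doc_texts:
        n = count_tokens(text)
        if used + n > budget:
            if kept == 0 and budget > 0:
                # even the first chunk is too large: keep only a truncated slice of it
                one_doc_size = budget
            break
        kept += 1
        used += n
    return kept, one_doc_size, used


# e.g. with a crude 4-chars-per-token estimate:
# fit_docs_to_budget(texts, lambda t: len(t) // 4, 2048, 500, 256)

The same kind of (top_k_docs, one_doc_size, num_doc_tokens) triple is what get_docs_tokens returns in the summarization branch that follows.
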
| 4295 |           else:
| 4296 | +             if total_tokens_for_docs is not None:
| 4297 | +                 # used to limit tokens for summarization, e.g. public instance
| 4298 | +                 top_k_docs, one_doc_size, num_doc_tokens = \
| 4299 | +                     get_docs_tokens(tokenizer,
| 4300 | +                                     text_context_list=[x[0].page_content for x in docs_with_score],
| 4301 | +                                     max_input_tokens=total_tokens_for_docs)
| 4302 |
| 4303 | +             docs_with_score = docs_with_score[:top_k_docs]
| 4304 |
| 4305 | +         # put most relevant chunks closest to question,
| 4306 | +         # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
| 4307 | +         # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
| 4308 | +         if docs_ordering_type in ['best_first']:
| 4309 | +             pass
| 4310 | +         elif docs_ordering_type in ['best_near_prompt', 'reverse_sort']:
| 4311 | +             docs_with_score.reverse()
| 4312 | +         elif docs_ordering_type in ['', None, 'reverse_ucurve_sort']:
| 4313 | +             docs_with_score = reverse_ucurve_list(docs_with_score)
| 4314 | +         else:
| 4315 | +             raise ValueError("No such docs_ordering_type=%s" % docs_ordering_type)
| 4316 | +
| 4317 | +         # cut off so no high distance docs/sources considered
| 4318 | +         num_docs_before_cut = len(docs_with_score)
| 4319 | +         docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
| 4320 | +         scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
| 4321 | +         if len(scores) > 0 and verbose:
| 4322 | +             print("Distance: min: %s max: %s mean: %s median: %s" %
| 4323 | +                   (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
| 4324 | +     else:
| 4325 | +         docs = []
| 4326 | +         scores = []
| 4327 |
| 4328 | +     if not docs and use_docs_planned and not langchain_only_model:
| 4329 | +         # if HF type and have no docs, can bail out
| 4330 | +         return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
| 4331 |
| 4332 |       if document_subset in non_query_commands:
| 4333 | +         # no LLM use
| 4334 | +         return docs, None, [], False, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
| 4335 |
| 4336 |       # FIXME: WIP
| 4337 |       common_words_file = "data/NGSL_1.2_stats.csv.zip"
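
reverse_ucurve_list itself is defined elsewhere in this file, so only its call site is visible here. The intent, per the comments above, is a "lost in the middle" mitigation: with docs_with_score sorted best-first, the reordering pushes the strongest chunks toward the two ends of the context and the weakest into the middle. A plausible sketch of such a helper (not necessarily the exact implementation):

def ucurve_order(items):
    # items are assumed sorted best-first; alternate them between the two ends
    # so the best chunk lands right next to the question and the weakest
    # chunks end up buried in the middle of the context
    front, back = [], []
    for i, item in enumerate(items):
        (back if i % 2 == 0 else front).append(item)
    return front + back[::-1]


print(ucurve_order(['A', 'B', 'C', 'D']))  # ['B', 'D', 'C', 'A'] -> 'A' (best) sits closest to the question

The cut_distance filter right after the reordering then simply drops any chunk whose vector-store distance is too large before the sources are shown or sent to the LLM.
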
| 4349 |
| 4350 |       if len(docs) == 0:
| 4351 |           # avoid context == in prompt then
| 4352 | +         use_docs_planned = False
| 4353 |           template = template_if_no_docs
| 4354 |
| 4355 |       got_db_docs = got_db_docs and len(text_context_list) < len(docs)

| 4361 |                        pre_prompt_query, prompt_query,
| 4362 |                        pre_prompt_summary, prompt_summary,
| 4363 |                        langchain_action,
| 4364 | +                      llm_mode,
| 4365 | +                      use_docs_planned,
| 4366 |                        auto_reduce_chunks,
| 4367 |                        got_db_docs,
| 4368 |                        add_search_to_context)
| 4380 |           else:
| 4381 |               # only if use_openai_model = True, unused normally except in testing
| 4382 |               chain = load_qa_with_sources_chain(llm)
| 4383 | +             if not use_docs_planned:
| 4384 | +                 chain_kwargs = dict(input_documents=[], question=query)
| 4385 | +             else:
| 4386 | +                 chain_kwargs = dict(input_documents=docs, question=query)
| 4387 |               target = wrapped_partial(chain, chain_kwargs)
| 4388 |       elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value,
| 4389 |                                 LangChainAction.SUMMARIZE_REFINE,

| 4427 |       else:
| 4428 |           raise RuntimeError("No such langchain_action=%s" % langchain_action)
| 4429 |
| 4430 | +     return docs, target, scores, use_docs_planned, num_docs_before_cut, use_llm_if_no_docs, llm_mode, top_k_docs_max_show
| 4431 |
| 4432 |
| 4433 |   def get_max_model_length(llm=None, tokenizer=None, inference_server=None, model_name=None):
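
For context on the use_docs_planned branch a few rows up (lines 4383-4387): load_qa_with_sources_chain is the plain LangChain fallback used only when use_openai_model=True, and the chain is invoked with a dict of input_documents and question. A rough usage sketch against the classic LangChain API of this era (import paths shifted across releases, so treat the exact modules as an assumption); FakeListLLM is just a stand-in LLM for illustration:

from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.llms.fake import FakeListLLM

llm = FakeListLLM(responses=["h2oGPT is an open-source LLM project.\nSOURCES: doc1"])
chain = load_qa_with_sources_chain(llm)  # default "stuff" chain type

docs = [Document(page_content="h2oGPT is an open-source LLM project.",
                 metadata={"source": "doc1"})]
# the not-use_docs_planned branch passes input_documents=[] instead
result = chain(dict(input_documents=docs, question="What is h2oGPT?"))
print(result["output_text"])
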
| 4473 |       if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
| 4474 |           # more accurate
| 4475 |           return llm.pipeline.tokenizer
| 4476 | +     elif hasattr(llm, 'tokenizer'):
| 4477 |           # e.g. TGI client mode etc.
| 4478 |           return llm.tokenizer
| 4479 |       elif inference_server in ['openai', 'openai_chat', 'openai_azure',
| 4480 | +                               'openai_azure_chat']:
| 4481 |           return tokenizer
| 4482 |       elif isinstance(tokenizer, FakeTokenizer):
| 4483 |           return tokenizer
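
The branch above resolves a tokenizer in priority order: the HF pipeline's own tokenizer, then an attribute on the LLM wrapper (e.g. TGI client mode), then whatever was passed in for OpenAI-style servers, and finally a FakeTokenizer. FakeTokenizer's internals are not shown in this diff; a crude stand-in of the same shape, useful only for rough budgeting when no real tokenizer is available (the 4-characters-per-token ratio is an assumption, not necessarily h2oGPT's heuristic):

class ApproxTokenizer:
    def __init__(self, model_max_length=2048, chars_per_token=4):
        self.model_max_length = model_max_length
        self.chars_per_token = chars_per_token

    def encode(self, text):
        # pseudo token ids, one per ~chars_per_token characters
        return list(range(max(1, len(text) // self.chars_per_token)))

    def num_tokens(self, text):
        return len(self.encode(text))


tok = ApproxTokenizer()
print(tok.num_tokens("How many tokens might this sentence cost?"))  # 10 with the default ratio
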
| 4500 |                    pre_prompt_query, prompt_query,
| 4501 |                    pre_prompt_summary, prompt_summary,
| 4502 |                    langchain_action,
| 4503 | +                  llm_mode,
| 4504 | +                  use_docs_planned,
| 4505 |                    auto_reduce_chunks,
| 4506 |                    got_db_docs,
| 4507 |                    add_search_to_context):
| 4523 |       if langchain_action == LangChainAction.QUERY.value:
| 4524 |           if iinput:
| 4525 |               query = "%s\n%s" % (query, iinput)
| 4526 | +         if llm_mode or not use_docs_planned:
| 4527 |               template_if_no_docs = template = """{context}{question}"""
| 4528 |           else:
| 4529 |               template = """%s
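
The hunk above ends mid-template: when llm_mode is set or no documents are planned, both template and template_if_no_docs collapse to the bare "{context}{question}" form, which is also what produced prompt_no_docs = template.format(context='', question=query) earlier. A small illustration (the docs-bearing template below is a simplified stand-in, not h2oGPT's literal prompt text):

template_if_no_docs = """{context}{question}"""
# simplified stand-in for the real docs template built from pre_prompt_query/prompt_query
template_with_docs = """{context}

According to only the sources above, {question}"""

query = "What does h2oGPT use for retrieval?"
context = '"""\nh2oGPT retrieves chunks from a vector database.\n"""\n'

print(template_if_no_docs.format(context='', question=query))      # plain LLM prompt
print(template_with_docs.format(context=context, question=query))  # grounded prompt
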
src/gradio_runner.py
CHANGED
@@ -737,7 +737,8 @@ def go_gradio(**kwargs):
| 737  |                                                   visible=True,
| 738  |                                                   elem_id="langchain_agents",
| 739  |                                                   filterable=False)
| 740  | -           visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
| 741  |             row_doc_track = gr.Row(visible=visible_doc_track)
| 742  |             with row_doc_track:
| 743  |                 if kwargs['langchain_mode'] in langchain_modes_non_db:

@@ -784,6 +785,9 @@ def go_gradio(**kwargs):
| 784  |                 text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
| 785  |                                                     show_copy_button=True)
| 786  |
| 787  |             # CHAT
| 788  |             col_chat = gr.Column(visible=kwargs['chat'])
| 789  |             with col_chat:

@@ -806,7 +810,8 @@ def go_gradio(**kwargs):
| 806  |                                          size="sm",
| 807  |                                          min_width=24,
| 808  |                                          file_types=['.' + x for x in file_types],
| 809  | -                                        file_count="multiple"
| 810  |
| 811  |             submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
| 812  |             with submit_buttons:

@@ -886,11 +891,9 @@ def go_gradio(**kwargs):
| 886  |                                  visible=sources_visible and allow_upload_to_user_data)
| 887  |             with gr.Column(scale=4):
| 888  |                 pass
| 889  |         with gr.Row():
| 890  |             with gr.Column(scale=1):
| 891  | -               visible_add_remove_collection = (allow_upload_to_user_data or
| 892  | -                                                allow_upload_to_my_data) and \
| 893  | -                                               kwargs['langchain_mode'] != 'Disabled'
| 894  |                 add_placeholder = "e.g. UserData2, shared, user_path2" \
| 895  |                     if not is_public else "e.g. MyData2, personal (optional)"
| 896  |                 remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"

@@ -1143,7 +1146,8 @@ def go_gradio(**kwargs):
| 1143 |                     )
| 1144 |                     min_max_new_tokens = gr.Slider(
| 1145 |                         minimum=1, maximum=max_max_new_tokens, step=1,
| 1146 | -                       value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
| 1147 |                     )
| 1148 |                     early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
| 1149 |                                                  value=kwargs['early_stopping'], visible=max_beams > 1)

@@ -2881,7 +2885,6 @@ def go_gradio(**kwargs):
| 2881 |         history = args_list[-1]
| 2882 |         if not history:
| 2883 |             history = []
| 2884 | -       # NOTE: For these, could check if None, then automatically use CLI values, but too complex behavior
| 2885 |         prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
| 2886 |         prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
| 2887 |         langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
| 737  |                                                   visible=True,
| 738  |                                                   elem_id="langchain_agents",
| 739  |                                                   filterable=False)
| 740  | +           visible_doc_track = upload_visible and kwargs['visible_doc_track'] and not kwargs[
| 741  | +               'large_file_count_mode']
| 742  |             row_doc_track = gr.Row(visible=visible_doc_track)
| 743  |             with row_doc_track:
| 744  |                 if kwargs['langchain_mode'] in langchain_modes_non_db:

| 785  |                 text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False,
| 786  |                                                     show_copy_button=True)
| 787  |
| 788  | +           visible_upload = (allow_upload_to_user_data or
| 789  | +                             allow_upload_to_my_data) and \
| 790  | +                            kwargs['langchain_mode'] != 'Disabled'
| 791  |             # CHAT
| 792  |             col_chat = gr.Column(visible=kwargs['chat'])
| 793  |             with col_chat:

| 810  |                                          size="sm",
| 811  |                                          min_width=24,
| 812  |                                          file_types=['.' + x for x in file_types],
| 813  | +                                        file_count="multiple",
| 814  | +                                        visible=visible_upload)
| 815  |
| 816  |             submit_buttons = gr.Row(equal_height=False, visible=kwargs['visible_submit_buttons'])
| 817  |             with submit_buttons:

| 891  |                                  visible=sources_visible and allow_upload_to_user_data)
| 892  |             with gr.Column(scale=4):
| 893  |                 pass
| 894  | +       visible_add_remove_collection = visible_upload
| 895  |         with gr.Row():
| 896  |             with gr.Column(scale=1):
| 897  |                 add_placeholder = "e.g. UserData2, shared, user_path2" \
| 898  |                     if not is_public else "e.g. MyData2, personal (optional)"
| 899  |                 remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"

| 1146 |                     )
| 1147 |                     min_max_new_tokens = gr.Slider(
| 1148 |                         minimum=1, maximum=max_max_new_tokens, step=1,
| 1149 | +                       value=min(max_max_new_tokens, kwargs['min_max_new_tokens']),
| 1150 | +                       label="Min. of Max output length",
| 1151 |                     )
| 1152 |                     early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
| 1153 |                                                  value=kwargs['early_stopping'], visible=max_beams > 1)

| 2885 |         history = args_list[-1]
| 2886 |         if not history:
| 2887 |             history = []
| 2888 |         prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
| 2889 |         prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
| 2890 |         langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
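
The net effect of the gradio_runner.py changes: a single visible_upload flag, computed once from the upload permissions and langchain_mode, now gates the attach button (file_count="multiple", visible=visible_upload) and replaces the inline expression behind visible_add_remove_collection. A minimal sketch of that pattern with illustrative component names (not the actual h2oGPT layout):

import gradio as gr

# stand-ins for the CLI/kwargs values the real app reads
allow_upload_to_user_data = True
allow_upload_to_my_data = True
langchain_mode = 'UserData'

visible_upload = (allow_upload_to_user_data or
                  allow_upload_to_my_data) and \
                 langchain_mode != 'Disabled'

with gr.Blocks() as demo:
    # one flag gates every upload-related control, so disabling langchain
    # (or both upload targets) hides them all consistently
    attach_btn = gr.UploadButton("Attach", file_count="multiple", visible=visible_upload)
    new_collection = gr.Textbox(label="New collection", visible=visible_upload)
    remove_collection = gr.Textbox(label="Remove collection", visible=visible_upload)

# demo.launch()  # layout-only sketch; not launched here
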