Commit 9bcca78 · Parent: 63ce3fa
Update with h2oGPT hash 1b295baace42908075b47f31a84b359d8c6b1e52
Files changed:
- client_test.py +5 -1
- finetune.py +1 -1
- generate.py +6 -3
- gpt_langchain.py +21 -4
- gradio_runner.py +144 -87
- prompter.py +2 -0
- utils.py +1 -1
client_test.py
CHANGED
@@ -3,7 +3,7 @@ Client test.

 Run server:

-python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
+python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b

 NOTE: For private models, add --use-auth_token=True

@@ -39,6 +39,7 @@ Loaded as API: https://gpt.h2o.ai ✔
 import time
 import os
 import markdown  # pip install markdown
+import pytest
 from bs4 import BeautifulSoup  # pip install beautifulsoup4

 debug = False
@@ -79,6 +80,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
                   instruction_nochat=prompt if not chat else '',
                   iinput_nochat='',  # only for chat=False
                   langchain_mode='Disabled',
+                  document_choice=['All'],
                   )
     if chat:
         # add chatbot output on end. Assumes serialize=False
@@ -87,6 +89,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token
     return kwargs, list(kwargs.values())


+@pytest.mark.skip(reason="For manual use against some server, no server launched")
 def test_client_basic():
     return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)

@@ -106,6 +109,7 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
     return res_dict


+@pytest.mark.skip(reason="For manual use against some server, no server launched")
 def test_client_chat():
     return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
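Note: both client tests are now skipped under plain pytest because they expect an already-running server. A minimal sketch of the intended manual use, assuming the server command from the docstring above is running:

    # Plain function calls bypass the pytest.mark.skip marker, which only
    # applies when pytest collects the tests.
    from client_test import test_client_basic, test_client_chat

    print(test_client_basic())  # res_dict from run_client_nochat()
    print(test_client_chat())   # res_dict from run_client_chat()
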
finetune.py
CHANGED
@@ -26,7 +26,7 @@ def train(
         save_code: bool = False,
         run_id: int = None,

-        base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
+        base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
         # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
         # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
         # base_model: str = 'EleutherAI/gpt-neox-20b',
generate.py
CHANGED
@@ -297,7 +297,7 @@ def main(
         if psutil.virtual_memory().available < 94 * 1024 ** 3:
             # 12B uses ~94GB
             # 6.9B uses ~47GB
-            base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
+            base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model

     # get defaults
     model_lower = base_model.lower()
@@ -864,6 +864,7 @@ eval_func_param_names = ['instruction',
                          'instruction_nochat',
                          'iinput_nochat',
                          'langchain_mode',
+                         'document_choice',
                          ]


@@ -891,6 +892,7 @@ def evaluate(
         instruction_nochat,
         iinput_nochat,
         langchain_mode,
+        document_choice,
         # END NOTE: Examples must have same order of parameters
         src_lang=None,
         tgt_lang=None,
@@ -1010,6 +1012,7 @@ def evaluate(
                              chunk=chunk,
                              chunk_size=chunk_size,
                              langchain_mode=langchain_mode,
+                             document_choice=document_choice,
                              db_type=db_type,
                              k=k,
                              temperature=temperature,
@@ -1446,7 +1449,7 @@ y = np.random.randint(0, 1, 100)

     # move to correct position
     for example in examples:
-        example += [chat, '', '', 'Disabled']
+        example += [chat, '', '', 'Disabled', ['All']]
         # adjust examples if non-chat mode
         if not chat:
             example[eval_func_param_names.index('instruction_nochat')] = example[
@@ -1546,6 +1549,6 @@ if __name__ == "__main__":
    can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
    python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'

-    python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
+    python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
    """
    fire.Fire(main)
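The three generate.py edits above are coupled: 'document_choice' is appended to eval_func_param_names, every example row gains a matching ['All'], and evaluate() gains the parameter in the same position, because all lookups are positional. A small illustration with hypothetical, shortened lists:

    eval_func_param_names = ['instruction', 'langchain_mode', 'document_choice']
    example = ['Who are you?', 'Disabled', ['All']]  # same order as the names
    assert example[eval_func_param_names.index('document_choice')] == ['All']
    # Appending a name without appending a value to each example row would
    # make this positional lookup run past the end of the row.
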
gpt_langchain.py
CHANGED
@@ -150,7 +150,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
         assert model_name is None
         assert tokenizer is None
         model_name = 'h2oai/h2ogpt-oasst1-512-12b'
-        # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
+        # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
         # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         device, torch_dtype, context_class = get_device_dtype()
@@ -593,7 +593,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
                  ):
     globs_image_types = []
     globs_non_image_types = []
-    if path_or_paths
+    if not path_or_paths and not url and not text:
         return []
     elif url:
         globs_non_image_types = [url]
@@ -846,6 +846,7 @@ def _run_qa_db(query=None,
               top_k=40,
               top_p=0.7,
               langchain_mode=None,
+              document_choice=['All'],
               n_jobs=-1):
    """
@@ -917,7 +918,23 @@ def _run_qa_db(query=None,
    k_db = 1000 if db_type == 'chroma' else k  # k=100 works ok too for

    if db and use_context:
-        docs_with_score = db.similarity_search_with_score(query, k=k_db)[:k]
+        if isinstance(document_choice, str):
+            # support string as well
+            document_choice = [document_choice]
+        if not isinstance(db, Chroma) or len(document_choice) <= 1 and document_choice[0].lower() == 'all':
+            # treat empty list as All for now, not 'None'
+            filter_kwargs = {}
+        else:
+            if len(document_choice) >= 2:
+                or_filter = [{"source": {"$eq": x}} for x in document_choice]
+                filter_kwargs = dict(filter={"$or": or_filter})
+            else:
+                one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
+                filter_kwargs = dict(filter=one_filter)
+        if len(document_choice) == 1 and document_choice[0].lower() == 'none':
+            k_db = 1
+            k = 0
+        docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
        # cut off so no high distance docs/sources considered
        docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
        scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
@@ -939,7 +956,7 @@ def _run_qa_db(query=None,
            reduced_query_words = reduced_query.split(' ')
            set_common = set(df['Lemma'].values.tolist())
            num_common = len([x.lower() in set_common for x in reduced_query_words])
-            frac_common = num_common / len(reduced_query)
+            frac_common = num_common / len(reduced_query) if reduced_query else 0
            # FIXME: report to user bad query that uses too many common words
            print("frac_common: %s" % frac_common, flush=True)
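The new branch in _run_qa_db builds Chroma metadata filters from the selected sources: an empty dict means no filter, one selection becomes a single $eq clause, and several selections are OR-ed together. A standalone sketch of just the filter construction (not h2oGPT's exact code; empty list treated as 'All' per the comment in the diff):

    def make_filter_kwargs(document_choice):
        # normalize a bare string to a one-element list, as the diff does
        if isinstance(document_choice, str):
            document_choice = [document_choice]
        # empty list or ['All'] (any case) means: search every document
        if not document_choice or (len(document_choice) == 1 and document_choice[0].lower() == 'all'):
            return {}
        if len(document_choice) >= 2:
            # one $eq clause per selected source, OR-ed together
            return dict(filter={"$or": [{"source": {"$eq": x}} for x in document_choice]})
        return dict(filter={"source": {"$eq": document_choice[0]}})

    # make_filter_kwargs(['a.pdf', 'b.pdf'])
    # -> {'filter': {'$or': [{'source': {'$eq': 'a.pdf'}}, {'source': {'$eq': 'b.pdf'}}]}}
    # and the result is passed through as
    # db.similarity_search_with_score(query, k=k_db, **filter_kwargs)
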
gradio_runner.py
CHANGED
@@ -96,7 +96,13 @@ def go_gradio(**kwargs):
     css_code = """footer {visibility: hidden}"""
     css_code += """
 body.dark{#warning {background-color: #555555};}
-"""
+#small_btn {
+        margin: 0.6em 0em 0.55em 0;
+        max-width: 20em;
+        min-width: 5em !important;
+        height: 5em;
+        font-size: 14px !important
+}"""

     if kwargs['gradio_avoid_processing_markdown']:
         from gradio_client import utils as client_utils
@@ -167,6 +173,7 @@ body.dark{#warning {background-color: #555555};}
             lora_options_state = gr.State([lora_options])
             my_db_state = gr.State([None, None])
             chat_state = gr.State({})
+            docs_state = gr.State(['All'])
             gr.Markdown(f"""
                 {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
@@ -175,7 +182,7 @@ body.dark{#warning {background-color: #555555};}
                 """)
             if is_hf:
                 gr.HTML(
-
+                )

             # go button visible if
             base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
@@ -220,7 +227,7 @@ body.dark{#warning {background-color: #555555};}
                     submit = gr.Button(value='Submit').style(full_width=False, size='sm')
                     stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
                 with gr.Row():
-                    clear = gr.Button("Save
+                    clear = gr.Button("Save Chat / New Chat")
                     flag_btn = gr.Button("Flag")
                 if not kwargs['auto_score']:  # FIXME: For checkbox model2
                     with gr.Column(visible=kwargs['score_model']):
@@ -251,19 +258,16 @@ body.dark{#warning {background-color: #555555};}
                         radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
                                                type='value')
                         with gr.Row():
-                            remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
                             clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
-
-
-
-
-
-                        with chats_row2:
+                            export_chats_btn = gr.Button(value="Export Chats to Download")
+                            remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
+                            add_to_chats_btn = gr.Button("Import Chats from Upload")
+                        with gr.Row():
+                            chats_file = gr.File(interactive=False, label="Download Exported Chats")
                             chatsup_output = gr.File(label="Upload Chat File(s)",
                                                      file_types=['.json'],
                                                      file_count='multiple',
                                                      elem_id="warning", elem_classes="feedback")
-                            add_to_chats_btn = gr.Button("Add File(s) to Chats")
                     with gr.TabItem("Data Source"):
                         langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
                                                    from_str=True)
@@ -275,8 +279,8 @@ body.dark{#warning {background-color: #555555};}
                         <p>
                         For more options see: {langchain_readme}""",
                                             visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
-
-                        with
+                        data_row1 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
+                        with data_row1:
                             if is_hf:
                                 # don't show 'wiki' since only usually useful for internal testing at moment
                                 no_show_modes = ['Disabled', 'wiki']
@@ -292,77 +296,92 @@ body.dark{#warning {background-color: #555555};}
                             langchain_mode = gr.Radio(
                                 [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
                                 value=kwargs['langchain_mode'],
-                                label="Data
+                                label="Data Collection of Sources",
                                 visible=kwargs['langchain_mode'] != 'Disabled')
+                        data_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
+                        with data_row2:
+                            with gr.Column(scale=50):
+                                document_choice = gr.Dropdown(docs_state.value,
+                                                              label="Choose Subset of Doc(s) in Collection [click get to update]",
+                                                              value=docs_state.value[0],
+                                                              interactive=True,
+                                                              multiselect=True,
+                                                              )
+                            with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
+                                get_sources_btn = gr.Button(value="Get Sources",
+                                                            ).style(full_width=False, size='sm')
+                                show_sources_btn = gr.Button(value="Show Sources",
+                                                             ).style(full_width=False, size='sm')

-                        def upload_file(files, x):
-                            file_paths = [file.name for file in files]
-                            return files, file_paths
-
-                        upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
-                            equal_height=False)
                         # import control
                         if kwargs['langchain_mode'] != 'Disabled':
                             from gpt_langchain import file_types, have_arxiv
                         else:
                             have_arxiv = False
                             file_types = []
-
-
-                            fileup_output = gr.File(label=f'Upload {file_types_str}',
-                                                    file_types=file_types,
-                                                    file_count="multiple",
-                                                    elem_id="warning", elem_classes="feedback")
-                        with gr.Row():
-                            upload_button = gr.UploadButton("Upload %s" % file_types_str,
-                                                            file_types=file_types,
-                                                            file_count="multiple",
-                                                            visible=False,
-                                                            )
-                            # add not visible until upload something
-                            with gr.Column():
-                                add_to_shared_db_btn = gr.Button("Add File(s) to Shared UserData DB",
-                                                                 visible=allow_upload_to_user_data)  # and False)
-                                add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData DB",
-                                                             visible=allow_upload_to_my_data)  # and False)
-                        url_row = gr.Row(
-                            visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload).style(
+
+                        upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
                             equal_height=False)
-                        with url_row:
-                            url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
-                            url_text = gr.Textbox(label=url_label, interactive=True)
+                        with upload_row:
                             with gr.Column():
-
-
-
-
-
-
+                                file_types_str = '[' + ' '.join(file_types) + ']'
+                                fileup_output = gr.File(label=f'Upload {file_types_str}',
+                                                        file_types=file_types,
+                                                        file_count="multiple",
+                                                        elem_id="warning", elem_classes="feedback")
+                                with gr.Row():
+                                    add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
+                                                                     visible=allow_upload_to_user_data, elem_id='small_btn')
+                                    add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
+                                                                 visible=allow_upload_to_my_data,
+                                                                 elem_id='small_btn' if allow_upload_to_user_data else None,
+                                                                 ).style(
+                                        size='sm' if not allow_upload_to_user_data else None)
+                            with gr.Column(
+                                    visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
+                                url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
+                                url_text = gr.Textbox(label=url_label, interactive=True)
+                                with gr.Row():
+                                    url_user_btn = gr.Button(value='Add URL content to Shared UserData',
+                                                             visible=allow_upload_to_user_data, elem_id='small_btn')
+                                    url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
+                                                           visible=allow_upload_to_my_data,
+                                                           elem_id='small_btn' if allow_upload_to_user_data else None,
+                                                           ).style(size='sm' if not allow_upload_to_user_data else None)
+                            with gr.Column(
+                                    visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
+                                user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True)
+                                with gr.Row():
+                                    user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
+                                                                   visible=allow_upload_to_user_data,
+                                                                   elem_id='small_btn')
+                                    user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
+                                                                 visible=allow_upload_to_my_data,
+                                                                 elem_id='small_btn' if allow_upload_to_user_data else None,
+                                                                 ).style(
+                                        size='sm' if not allow_upload_to_user_data else None)
+                            with gr.Column(visible=False):
+                                # WIP:
+                                with gr.Row(visible=False).style(equal_height=False):
+                                    github_textbox = gr.Textbox(label="Github URL")
+                                    with gr.Row(visible=True):
+                                        github_shared_btn = gr.Button(value="Add Github to Shared UserData",
+                                                                      visible=allow_upload_to_user_data,
+                                                                      elem_id='small_btn')
+                                        github_my_btn = gr.Button(value="Add Github to Scratch MyData",
+                                                                  visible=allow_upload_to_my_data, elem_id='small_btn')
+                        sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
                             equal_height=False)
-                        with
-
-
-
-
-
-                                          visible=allow_upload_to_my_data)
-                        # WIP:
-                        with gr.Row(visible=False).style(equal_height=False):
-                            github_textbox = gr.Textbox(label="Github URL")
-                            with gr.Row(visible=True):
-                                github_shared_btn = gr.Button(value="Add Github to Shared UserData DB",
-                                                              visible=allow_upload_to_user_data)
-                                github_my_btn = gr.Button(value="Add Github to Scratch MyData DB",
-                                                          visible=allow_upload_to_my_data)
+                        with sources_row3:
+                            with gr.Column(scale=1):
+                                file_source = gr.File(interactive=False,
+                                                      label="Download File with Sources [click get to make file]")
+                            with gr.Column(scale=2):
+                                pass
                         sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
                             equal_height=False)
                         with sources_row:
                             sources_text = gr.HTML(label='Sources Added', interactive=False)
-                        sources_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
-                            equal_height=False)
-                        with sources_row2:
-                            get_sources_btn = gr.Button(value="Get Sources List for Selected DB")
-                            file_source = gr.File(interactive=False, label="Download File with list of Sources")

                     with gr.TabItem("Expert"):
                         with gr.Row():
@@ -545,14 +564,6 @@ body.dark{#warning {background-color: #555555};}
                 def make_visible():
                     return gr.update(visible=True)

-                # add itself to output to ensure shows working and can't click again
-                upload_button.upload(upload_file, inputs=[upload_button, fileup_output],
-                                     outputs=[upload_button, fileup_output], queue=queue,
-                                     api_name='upload_file' if allow_api else None) \
-                    .then(make_add_visible, fileup_output, add_to_shared_db_btn, queue=queue) \
-                    .then(make_add_visible, fileup_output, add_to_my_db_btn, queue=queue) \
-                    .then(make_invisible, outputs=upload_button, queue=queue)
-
                 # Add to UserData
                 update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
                                                         use_openai_embedding=use_openai_embedding,
@@ -623,8 +634,23 @@ body.dark{#warning {background-color: #555555};}
                     .then(clear_textbox, outputs=user_text_text, queue=queue)

                 get_sources1 = functools.partial(get_sources, dbs=dbs)
-
-
+
+                # if change collection source, must clear doc selections from it to avoid inconsistency
+                def clear_doc_choice():
+                    return gr.Dropdown.update(choices=['All'], value=['All'])
+
+                langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice)
+
+                def update_dropdown(x):
+                    return gr.Dropdown.update(choices=x, value='All')
+
+                show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs)
+                get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=[file_source, docs_state],
+                                      queue=queue,
+                                      api_name='get_sources' if allow_api else None) \
+                    .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice)
+                # show button, else only show when add. Could add to above get_sources for download/dropdown, but bit much maybe
+                show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text)

                 def check_admin_pass(x):
                     return gr.update(visible=x == admin_pass)
@@ -818,6 +844,11 @@ body.dark{#warning {background-color: #555555};}
             my_db_state1 = args_list[-2]
             history = args_list[-1]

+            if model_state1[0] is None or model_state1[0] == no_model_str:
+                history = []
+                yield history, ''
+                return
+
             args_list = args_list[:-3]  # only keep rest needed for evaluate()
             langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
             if retry and history:
@@ -827,13 +858,19 @@ body.dark{#warning {background-color: #555555};}
                 args_list[eval_func_param_names.index('do_sample')] = True
             if not history:
                 print("No history", flush=True)
-                history = [['', None]]
+                history = []
                 yield history, ''
                 return
             # ensure output will be unique to models
             _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
             history = copy.deepcopy(history)
             instruction1 = history[-1][0]
+            if not instruction1:
+                # reject empty query, can sometimes go nuts
+                history = []
+                yield history, ''
+                return
+
             context1 = ''
             if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
                 prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
@@ -867,10 +904,6 @@ body.dark{#warning {background-color: #555555};}
             context1 += chat_sep  # ensure if terminates abruptly, then human continues on next line
             args_list[0] = instruction1  # override original instruction with history from user
             args_list[2] = context1
-            if model_state1[0] is None or model_state1[0] == no_model_str:
-                history = [['', None]]
-                yield history, ''
-                return
             fun1 = partial(evaluate,
                            model_state1,
                            my_db_state1,
@@ -1086,10 +1119,14 @@ body.dark{#warning {background-color: #555555};}
                               api_name='export_chats' if allow_api else None)

         def add_chats_from_file(file, chat_state1, add_btn):
+            if not file:
+                return chat_state1, add_btn
             if isinstance(file, str):
                 files = [file]
             else:
                 files = file
+            if not files:
+                return chat_state1, add_btn
             for file1 in files:
                 try:
                     if hasattr(file1, 'name'):
@@ -1350,22 +1387,28 @@ def get_inputs_list(inputs_dict, model_lower):
 def get_sources(db1, langchain_mode, dbs=None):
     if langchain_mode in ['ChatLLM', 'LLM']:
         source_files_added = "NA"
+        source_list = []
     elif langchain_mode in ['wiki_full']:
         source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
                              " Ask [email protected] for file if required."
+        source_list = []
     elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
         db_get = db1[0].get()
-
+        source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
+        source_files_added = '\n'.join(source_list)
     elif langchain_mode in dbs and dbs[langchain_mode] is not None:
         db1 = dbs[langchain_mode]
         db_get = db1.get()
-
+        source_list = sorted(set([x['source'] for x in db_get['metadatas']]))
+        source_files_added = '\n'.join(source_list)
     else:
+        source_list = []
         source_files_added = "None"
     sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
     with open(sources_file, "wt") as f:
         f.write(source_files_added)
-
+    source_list = ['All'] + source_list
+    return sources_file, source_list


 def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
@@ -1465,6 +1508,20 @@ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='Use
     return x, y, source_files_added


+def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
+    with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
+        if langchain_mode in ['wiki_full']:
+            # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
+            db = None
+        elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
+            db = db1[0]
+        elif langchain_mode in dbs and dbs[langchain_mode] is not None:
+            db = dbs[langchain_mode]
+        else:
+            db = None
+        return get_source_files(db, exceptions=None)
+
+
 def get_source_files(db, exceptions=None):
     if exceptions is None:
         exceptions = []
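The wiring around get_sources_btn uses the event chaining gradio applies throughout this file: .click() runs a handler, and .then() runs a follow-up on its output, here refreshing the document dropdown from the freshly fetched source list. A self-contained sketch of the same pattern with illustrative names (not h2oGPT itself):

    import gradio as gr

    def get_choices():
        # stand-in for get_sources(): pretend these came from the DB
        return ['All', 'a.pdf', 'b.pdf']

    def update_dropdown(x):
        return gr.Dropdown.update(choices=x, value='All')

    with gr.Blocks() as demo:
        docs_state = gr.State(['All'])
        document_choice = gr.Dropdown(['All'], value='All', multiselect=True)
        get_sources_btn = gr.Button("Get Sources")
        # first fill the state, then use it to refresh the dropdown choices
        get_sources_btn.click(get_choices, outputs=docs_state) \
            .then(update_dropdown, inputs=docs_state, outputs=document_choice)

    # demo.launch()
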
prompter.py
CHANGED
@@ -56,6 +56,8 @@ prompt_type_to_model_name = {
         'h2oai/h2ogpt-oasst1-512-20b',
         'h2oai/h2ogpt-oig-oasst1-256-6_9b',
         'h2oai/h2ogpt-oig-oasst1-512-6_9b',
+        'h2oai/h2ogpt-oig-oasst1-256-6.9b',  # legacy
+        'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy
         'h2oai/h2ogpt-research-oasst1-512-30b',  # private
     ],
     'dai_faq': [],
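The two legacy entries keep checkpoints published under the old 6.9b spelling resolvable after the rename to 6_9b. A sketch of how such a table is typically inverted for model-name lookup; the 'human_bot' key and the inversion are assumptions, not shown in this hunk:

    prompt_type_to_model_name = {
        'human_bot': [
            'h2oai/h2ogpt-oig-oasst1-512-6_9b',
            'h2oai/h2ogpt-oig-oasst1-512-6.9b',  # legacy spelling, same weights
        ],
    }
    # invert the list-valued dict: model name -> prompt type
    model_name_to_prompt_type = {m: p
                                 for p, names in prompt_type_to_model_name.items()
                                 for m in names}
    assert model_name_to_prompt_type['h2oai/h2ogpt-oig-oasst1-512-6.9b'] == 'human_bot'
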
utils.py
CHANGED
@@ -148,7 +148,7 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
         host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
         zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
     assert root_dirs is not None
-    if not os.path.isdir(os.path.dirname(zip_file)):
+    if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file):
         os.makedirs(os.path.dirname(zip_file), exist_ok=True)
     with zipfile.ZipFile(zip_file, "w") as expt_zip:
         for root_dir in root_dirs:
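The extra "and os.path.dirname(zip_file)" guard matters when zip_file is a bare filename: os.path.dirname() then returns the empty string, and os.makedirs('') raises FileNotFoundError. A quick demonstration (illustrative filename):

    import os

    zip_file = "data_20230519_myhost.zip"   # bare filename, no directory part
    print(repr(os.path.dirname(zip_file)))  # ''
    # os.makedirs('', exist_ok=True)        # would raise FileNotFoundError
    dirname = os.path.dirname(zip_file)
    if not os.path.isdir(dirname) and dirname:
        os.makedirs(dirname, exist_ok=True)  # safely skipped for bare filenames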