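# CLI chat demo: loads a base model plus a LoRA adapter fine-tuned with LLaMA-Factory
# and streams assistant responses in an interactive loop.
# Assumed usage (the script name below is illustrative):
#   python chat_cli_demo.py [path/to/lora/checkpoint]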
import os
import sys

from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from dotenv import find_dotenv, load_dotenv

found_dotenv = find_dotenv(".env")
if len(found_dotenv) == 0:
    found_dotenv = find_dotenv(".env.example")
print(f"loading env vars from: {found_dotenv}")
load_dotenv(found_dotenv, override=False)

path = os.path.dirname(found_dotenv)
print(f"Adding {path} to sys.path")
sys.path.append(path)
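# Project-local helpers; importable only because the directory containing the .env
# file was appended to sys.path above.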
from llm_toolkit.translation_engine import *
from llm_toolkit.translation_utils import *
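# Runtime configuration is read from environment variables (loaded from .env above):
#   MODEL_NAME       - base model id or local path
#   LOAD_IN_4BIT     - "true" to load the model in 4-bit precision
#   EVAL_BASE_MODEL / EVAL_FINE_TUNED / SAVE_FINED_TUNED flags control evaluation/saving
#   NUM_TRAIN_EPOCHS, DATA_PATH, RESULTS_PATH - training/eval settings (only logged here)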
model_name = os.getenv("MODEL_NAME")
load_in_4bit = os.getenv("LOAD_IN_4BIT") == "true"
eval_base_model = os.getenv("EVAL_BASE_MODEL") == "true"
eval_fine_tuned = os.getenv("EVAL_FINE_TUNED") == "true"
save_fine_tuned_model = os.getenv("SAVE_FINE_TUNED") == "true"
num_train_epochs = int(os.getenv("NUM_TRAIN_EPOCHS") or 0)
data_path = os.getenv("DATA_PATH")
results_path = os.getenv("RESULTS_PATH")
max_seq_length = 2048  # choose any; RoPE scaling is supported automatically
dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
print(
    model_name,
    load_in_4bit,
    max_seq_length,
    num_train_epochs,
    dtype,
    data_path,
    results_path,
    eval_base_model,
    eval_fine_tuned,
    save_fine_tuned_model,
)
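# LoRA adapter to load; the first CLI argument, if provided, overrides the default checkpoint path.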
adapter_name_or_path = (
    sys.argv[1]
    if len(sys.argv) > 1
    else "llama-factory/saves/qwen2-0.5b/lora/sft/checkpoint-560"
)
args = dict(
    model_name_or_path=model_name,  # base model, taken from the MODEL_NAME env var
    adapter_name_or_path=adapter_name_or_path,  # load the saved LoRA adapters
    template="chatml",  # must match the template used in training
    finetuning_type="lora",  # must match the fine-tuning type used in training
    quantization_bit=4,  # load the model 4-bit quantized
)
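# Build LLaMA-Factory's ChatModel from the inference arguments above.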
chat_model = ChatModel(args)
messages = []
print(
    "Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application."
)
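# Interactive loop: `exit` quits, `clear` resets the conversation history and frees
# GPU memory; anything else is sent to the model and the reply is streamed back.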
while True:
    query = input("\nUser: ")

    if query.strip() == "exit":
        break

    if query.strip() == "clear":
        messages = []
        torch_gc()
        print("History has been removed.")
        continue

    messages.append({"role": "user", "content": query})
    print("Assistant: ", end="", flush=True)

    response = ""
    for new_text in chat_model.stream_chat(messages):
        print(new_text, end="", flush=True)
        response += new_text
    print()
    messages.append({"role": "assistant", "content": response})

torch_gc()