import gradio as gr
from peft import PeftModel, PeftConfig
from transformers import (
MistralForCausalLM,
TextIteratorStreamer,
AutoTokenizer,
BitsAndBytesConfig,
GenerationConfig,
)
from time import sleep
from threading import Thread
from queue import Queue
from os import getenv
from datetime import datetime
from torch import float16

import spaces
import huggingface_hub

from data_logger import log_data
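

# Inputs submitted to the demo are queued and periodically pushed to a private
# Hugging Face dataset (OUTPUT_DATASET) by the background thread below.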
def check_thread(logging_queue: Queue):
    # Create the callback that pushes batches of logged inputs to the HF dataset.
    logging_callback = log_data(
        hf_token=getenv("HF_API_TOKEN"),
        dataset_name=getenv("OUTPUT_DATASET"),
        private=True,
    )

    print("Logging thread started.")
    print(f"Logging to '{getenv('OUTPUT_DATASET')}'")

    while True:
        print("Checking for logs...")
        sleep(60)
        batch = []
        while not logging_queue.empty():
            batch.append(logging_queue.get())
        if len(batch) > 0:
            try:
                logging_callback(batch)
            except Exception:
                print(
                    "Error happened while pushing data to HF. Putting items back in queue..."
                )
                for item in batch:
                    logging_queue.put(item)


if getenv("HF_API_TOKEN") is not None:
    print("Starting logging thread...")
    log_queue = Queue()
    t = Thread(target=check_thread, args=(log_queue,))
    t.start()
else:
    print("No HF_API_TOKEN found. Logging is disabled.")
config = PeftConfig.from_pretrained("lang-uk/dragoman")

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=float16,
    bnb_4bit_use_double_quant=False,
)

model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    quantization_config=quant_config,
    # device_map="auto",
)
model = PeftModel.from_pretrained(model, "lang-uk/dragoman").to("cuda")

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", use_fast=False, add_bos_token=False
)
@spaces.GPU(duration=30)
def translate(input_text):
    input_text = input_text.strip()
    print(f"{datetime.utcnow()} | Translating: {input_text}")

    # Queue the raw input for background logging when a token is configured.
    if getenv("HF_API_TOKEN") is not None:
        log_queue.put([input_text])

    # Wrap the input in the instruction format expected by the fine-tuned model.
    input_text = f"[INST] {input_text} [/INST]"
    inputs = tokenizer([input_text], return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        inputs,
        max_new_tokens=200,
        num_beams=10,
        temperature=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Streaming variant, left disabled:
    # streamer = TextIteratorStreamer(
    #     tokenizer, skip_prompt=True, skip_special_tokens=True
    # )
    # generation_kwargs["streamer"] = streamer
    # thread = Thread(target=model.generate, kwargs=generation_kwargs)
    # thread.start()
    # generated_text = ""
    # for new_text in streamer:
    #     generated_text += new_text
    #     yield generated_text
    # generated_text += "\n"
    # yield generated_text

    output = model.generate(**generation_kwargs)
    output = (
        tokenizer.decode(output[0], skip_special_tokens=True)
        .split("[/INST] ")[-1]
        .strip()
    )
    return output
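
# A minimal local usage sketch (hypothetical, not executed by the Space):
#   print(translate("The quick brown fox jumps over the lazy dog."))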


# Download the model card from the Hub; it is rendered as the article below the demo.
desc_file = huggingface_hub.hf_hub_download("lang-uk/dragoman", "README.md")
with open(desc_file, "r") as f:
    model_description = f.read()

# Strip the YAML front matter (everything up to and including the closing "---").
model_description = model_description[model_description.find("---", 1) + 5 :]
model_description = (
    """### By using this service, users are required to agree to the following terms: you agree that user input will be collected for future research and model improvements. \n\n"""
    + model_description
)
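
# Gradio UI: one source-text box, one translated-text box, with the model card
# shown underneath as the article.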
iface = gr.Interface(
    fn=translate,
    inputs=gr.Textbox(
        value='This demo contains a model from the paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to the UNLP 2024 workshop at the LREC-COLING 2024 conference.',
        label="Source sentence",
    ),
    outputs=gr.Textbox(
        value='Ця демо-версія містить модель із статті "Налаштування принтера даних із покращеним машинним перекладом з англійської на українську", яка була прийнята до семінару UNLP 2024 на конференції LREC-COLING 2024.',
        label="Translated sentence",
    ),
    examples=[
        [
            "How many leaves would it drop in a month of February in a non-leap year?",
        ],
        [
            "ChatGPT (Chat Generative Pre-trained Transformer) is a chatbot developed by OpenAI and launched on November 30, 2022. Based on a large language model, it enables users to refine and steer a conversation towards a desired length, format, style, level of detail, and language. Successive prompts and replies, known as prompt engineering, are considered at each conversation stage as a context.",
        ],
        [
            "who holds this neighborhood?",
        ],
    ],
    title="Dragoman: SOTA English-Ukrainian translation model",
    description='This demo contains a model from the paper "Setting up the Data Printer with Improved English to Ukrainian Machine Translation", accepted to the UNLP 2024 workshop at the LREC-COLING 2024 conference.',
    article=model_description,
    # thumbnail: str | None = None,
    # css: str | None = None,
    # batch: bool = False,
    # max_batch_size: int = 4,
    # api_name: str | Literal[False] | None = "predict",
    submit_btn="Translate",
)
iface.launch()