Spaces: Runtime error
chore: add bad words
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import (
     AutoModelForPreTraining,
     AutoProcessor,
     AutoConfig,
+    PreTrainedTokenizerFast
 )
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
@@ -18,10 +19,16 @@ assert MODEL_NAME is not None
 MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
 DEVICE = torch.device("cuda")
 
+BAD_WORD_KEYWORDS = ["(medium)"]
 
 def fix_compiled_state_dict(state_dict: dict):
     return {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}
 
+def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
+    ids = [
+        [id] for token, id in tokenizer.vocab.items() if any(word in token for word in BAD_WORD_KEYWORDS)
+    ]
+    return ids
 
 def prepare_models():
     config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -44,6 +51,7 @@ def prepare_models():
 
 def demo():
     model, processor = prepare_models()
+    ban_ids = get_bad_words_ids(processor.decoder_tokenizer)
 
     @spaces.GPU(duration=5)
     @torch.inference_mode()
@@ -83,6 +91,7 @@ def demo():
             top_p=top_p,
             eos_token_id=processor.decoder_tokenizer.eos_token_id,
             pad_token_id=processor.decoder_tokenizer.pad_token_id,
+            bad_words_ids=ban_ids,
         )
         elapsed = time.time() - start_time
 
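For context on how this list is consumed: `bad_words_ids` in transformers expects a list of token-id sequences that `generate()` must never emit, which is why `get_bad_words_ids` wraps each matching vocabulary id in its own one-element list. The following is a minimal, self-contained sketch of the same pattern, not the Space's actual setup: `gpt2` is used purely as a placeholder checkpoint (this Space loads its own model and processor via `prepare_models`), and the variable names mirror the commit for readability.

# Hedged sketch of the ban-list pattern this commit introduces.
# Assumption: "gpt2" stands in for the Space's real model/processor.
from transformers import AutoModelForCausalLM, AutoTokenizer

BAD_WORD_KEYWORDS = ["(medium)"]  # same keyword list the commit adds

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Each banned entry is a sequence of token ids, so single-token bans are
# wrapped as one-element lists, exactly as get_bad_words_ids does.
ban_ids = [
    [token_id]
    for token, token_id in tokenizer.get_vocab().items()
    if any(keyword in token for keyword in BAD_WORD_KEYWORDS)
]

inputs = tokenizer("a photo of", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=16,
    bad_words_ids=ban_ids or None,  # generate() rejects an empty ban list
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

If no vocabulary entry contains a banned keyword, the comprehension yields an empty list, which `generate()` does not accept for `bad_words_ids`; the `or None` guard above keeps the sketch runnable in that case.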