|
import os
import subprocess
import time
from collections import OrderedDict

import torch
from cog import BasePredictor, ConcatenateIterator, Input, Path
from tensorizer import TensorDeserializer
from tensorizer.utils import no_init_or_tensor
from transformers import AutoConfig, AutoTokenizer

from subclass import YieldingReplitCode
|
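# Pre-serialized weights produced by tensorizer. The gs:// path is assumed to be
# readable by the `gcloud` CLI configured in this environment.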
TENSORIZER_WEIGHTS_PATH = "gs://replicate-weights/replit-code-v1-3b/model.tensors" |
|
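# Local copies of the model config and tokenizer shipped alongside this predictor.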
DEFAULT_CONFIG_PATH = "model/" |
|
TOKENIZER_PATH = "model/" |
|
|
|
def maybe_download(path): |
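    """Copy gs:// weights to local disk with `gcloud storage cp`; return local paths unchanged."""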
|
if path.startswith("gs://"): |
|
st = time.time() |
|
output_path = "/tmp/weights.tensors" |
|
subprocess.check_call(["gcloud", "storage", "cp", path, output_path]) |
|
        print(f"weights downloaded in {time.time() - st:.2f}s")
|
return output_path |
|
return path |
|
|
|
|
|
class Predictor(BasePredictor): |
|
def setup(self): |
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
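        # Disable fast-tokenizer parallelism to avoid fork-related warnings during generation.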
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
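        # Deserialize the weights with tensorizer instead of loading a regular
        # Hugging Face checkpoint; this is meant to reduce cold-start time.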
self.model = self.load_tensorizer( |
|
            weights=maybe_download(TENSORIZER_WEIGHTS_PATH),
            plaid_mode=True,
            cls=YieldingReplitCode,
            config_path=DEFAULT_CONFIG_PATH,
|
) |
|
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True) |
|
|
|
def load_tensorizer(self, weights, plaid_mode, cls, config_path): |
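        """Build the model skeleton without weights, then stream tensors into it from `weights`."""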
|
st = time.time() |
|
print(f"deserializing weights from {weights}") |
|
|
|
config = AutoConfig.from_pretrained(config_path, trust_remote_code=True) |
|
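        # Use the triton attention implementation (assumes the triton kernels are installed).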
config.attn_config['attn_impl'] = 'triton' |
|
|
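        # Instantiate the model without allocating or initializing any parameters;
        # the real tensors are loaded into the module by the deserializer below.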
model = no_init_or_tensor( |
|
lambda: cls.from_pretrained( |
|
None, config=config, state_dict=OrderedDict(), trust_remote_code=True, |
|
) |
|
) |
|
|
|
|
|
        # plaid_mode streams tensors into a reused GPU buffer for faster loading.
        deserialized = TensorDeserializer(weights, plaid_mode=plaid_mode)
|
deserialized.load_into_module(model) |
|
try: |
|
model = model.to(dtype=torch.bfloat16) |
|
        except Exception:
            # bfloat16 may be unsupported on this device; keep the deserialized dtype
            pass
|
|
|
        print(f"weights loaded in {time.time() - st:.2f}s")
|
return model |
|
|
|
def predict( |
|
self, |
|
        prompt: str = Input(description="Text prompt"),
|
max_length: int = Input( |
|
description="Maximum number of tokens to generate. A word is generally 2-3 tokens", |
|
ge=1, |
|
default=500, |
|
), |
|
temperature: float = Input( |
|
            description="Adjusts randomness of outputs; values greater than 1 are more random, values near 0 are nearly deterministic. 0.75 is a good starting value.",
|
ge=0.01, |
|
le=5, |
|
default=0.75, |
|
), |
|
top_p: float = Input( |
|
            description="When decoding text, samples from the smallest set of most likely tokens whose cumulative probability reaches top_p; lower it to ignore less likely tokens.",
|
ge=0.01, |
|
le=1.0, |
|
default=1.0, |
|
), |
|
repetition_penalty: float = Input( |
|
description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.", |
|
ge=0.01, |
|
le=5, |
|
default=1, |
|
), |
|
length_penalty: float = Input( |
|
description="Increasing the length_penalty parameter above 1.0 will cause the model to favor longer sequences, while decreasing it below 1.0 will cause the model to favor shorter sequences.", |
|
ge=0.01, |
|
le=5, |
|
default=1, |
|
), |
|
no_repeat_ngram_size: int = Input( |
|
description="If set to int > 0, all ngrams of size no_repeat_ngram_size can only occur once.", |
|
ge=0, |
|
default=0, |
|
), |
|
stop_sequence: str = Input( |
|
            description="Generation will halt if this token is produced. Currently, only single-token stop sequences are supported; it is recommended to use `###` as the stop sequence if you want to control generation termination.",
|
default=None, |
|
), |
|
seed: int = Input( |
|
description="Set seed for reproducible outputs. Set to -1 for random seed.", |
|
ge=-1, |
|
default=-1, |
|
), |
|
debug: bool = Input( |
|
            description="Provide debugging output in logs", default=False
|
), |
|
) -> ConcatenateIterator[str]: |
|
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
|
|
|
|
|
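        # Seed the RNGs for reproducibility; -1 keeps a fresh random seed.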
if seed == -1: |
|
torch.seed() |
|
|
|
else: |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
|
|
with torch.inference_mode(): |
|
first_token_yielded = False |
|
prev_ids = [] |
|
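            # YieldingReplitCode.generate is assumed to yield one token id at a time,
            # letting us stream decoded text back as it is produced.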
for output in self.model.generate( |
|
                input_ids,
|
max_length=max_length, |
|
do_sample=True, |
|
temperature=temperature, |
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
length_penalty=length_penalty, |
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
): |
|
cur_id = output.item() |
|
cur_token = self.tokenizer.convert_ids_to_tokens(cur_id) |
|
|
|
|
|
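                # Skip the leading newline the model tends to emit first
                # (token id 187 is assumed to be "\n" for this tokenizer).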
if not first_token_yielded and not prev_ids and cur_id == 187: |
|
continue |
|
|
|
|
|
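                # "Ġ" marks the start of a new word in byte-level BPE vocabularies, so a
                # complete token has been buffered and can be decoded and yielded.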
if cur_token.startswith("Ġ"): |
|
|
|
if not prev_ids: |
|
prev_ids = [cur_id] |
|
continue |
|
|
|
|
|
else: |
|
token = self.tokenizer.decode(prev_ids, clean_up_tokenization_spaces=False) |
|
prev_ids = [cur_id] |
|
|
|
if not first_token_yielded: |
|
|
|
token = token.strip() |
|
first_token_yielded = True |
|
yield token |
|
|
|
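                # Stop on the model's end-of-text token or the user-supplied stop sequence.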
elif cur_token == "<|endoftext|>": |
|
break |
|
|
|
elif stop_sequence and cur_token == stop_sequence: |
|
break |
|
|
|
else: |
|
prev_ids.append(cur_id) |
|
continue |
|
|
|
|
|
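            # Flush any ids still buffered once generation stops.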
token = self.tokenizer.decode(prev_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) |
|
if not first_token_yielded: |
|
|
|
token = token.strip() |
|
first_token_yielded = True |
|
yield token |
|
|
|
        if debug:
            print(f"cur memory: {torch.cuda.memory_allocated()}")
            print(f"max memory allocated: {torch.cuda.max_memory_allocated()}")
            print(f"max memory reserved: {torch.cuda.max_memory_reserved()}")
|
|