from typing import List, Tuple

import json
import os
import time
from pathlib import Path

import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from google.cloud import storage

from llama.generation import LLaMA
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

# GCS bucket holding the weights, and the local directories the blobs are copied into.
bucket_name = os.environ.get("GCS_BUCKET")
llama_weight_path = "weights/llama"
tokenizer_weight_path = "weights/tokenizer"

def setup_model_parallel() -> Tuple[int, int]:
    # LOCAL_RANK and WORLD_SIZE are set by the torchrun / torch.distributed launcher.
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # Seed must be the same in all processes.
    torch.manual_seed(1)
    return local_rank, world_size

def download_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
):
    # exist_ok=True so a restarted process doesn't crash on directories left
    # behind by a previous run.
    os.makedirs(llama_weight_path, exist_ok=True)
    os.makedirs(tokenizer_weight_path, exist_ok=True)

    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    # Copy every blob under each prefix into the matching local directory,
    # skipping "directory" placeholder blobs whose names end with a slash.
    for prefix, local_dir in (
        (ckpt_path, llama_weight_path),
        (tokenizer_path, tokenizer_weight_path),
    ):
        for blob in bucket.list_blobs(prefix=f"{prefix}/"):
            if blob.name.endswith("/"):
                continue
            filename = blob.name.split("/")[-1]
            blob.download_to_filename(f"{local_dir}/{filename}")

def get_pretrained_models(
    ckpt_path: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
) -> LLaMA:
    download_pretrained_models(ckpt_path, tokenizer_path)

    start_time = time.time()
    checkpoints = sorted(Path(llama_weight_path).glob("*.pth"))
    assert world_size == len(checkpoints), (
        f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    )
    # Each model-parallel rank loads its own shard of the checkpoint.
    llama_ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(llama_ckpt_path, map_location="cpu")
    with open(Path(llama_weight_path) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(max_seq_len=512, max_batch_size=1, **params)
    tokenizer = Tokenizer(model_path=f"{tokenizer_weight_path}/tokenizer.model")
    model_args.vocab_size = tokenizer.n_words

    # Build the model with half-precision parameters on the GPU, then restore
    # the default so tensors created afterwards are FP32 on the CPU.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args).cuda().half()
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator

def get_output(
    generator: LLaMA,
    prompt: str,
    max_gen_len: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.95,
) -> List[str]:
    # generate() operates on a batch of prompts, so wrap the single prompt in a list.
    prompts = [prompt]
    results = generator.generate(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    return results
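
# Usage sketch (an assumption, not part of the original script): the pieces
# above are typically wired together like this when the script is launched
# with `torchrun --nproc_per_node=<num_checkpoint_shards>`. The GCS prefixes
# "llama/7B" and "llama/tokenizer" and the prompt are hypothetical placeholders.
if __name__ == "__main__":
    local_rank, world_size = setup_model_parallel()
    generator = get_pretrained_models(
        "llama/7B", "llama/tokenizer", local_rank, world_size
    )
    results = get_output(generator, "The capital of France is")
    # Print from rank 0 only, so the model-parallel shards don't interleave output.
    if local_rank == 0:
        for result in results:
            print(result)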