# vllm https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
from .svg_validator_base import SVGValidator, register_validator
from starvector.data.util import rasterize_svg, clean_svg, use_placeholder
from svgpathtools import svgstr2paths
from vllm import LLM, SamplingParams
from datasets import load_dataset
from torch.utils.data import DataLoader

class StarVectorVLLMValidator(SVGValidator):
    def __init__(self, config):
        super().__init__(config)
        model_name = config.model.name
        if config.model.from_checkpoint:
            model_name = config.model.from_checkpoint
        self.llm = LLM(model=model_name, trust_remote_code=True, dtype=config.model.torch_dtype)
        self.get_dataloader(config)
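
    # A minimal sketch of the model config fields read above -- names are inferred
    # from the attribute accesses in this file; the values are purely illustrative
    # (the real config schema lives elsewhere in the repo):
    #
    #   config.model.name            = "starvector/starvector-1b-im2svg"  # hypothetical id
    #   config.model.from_checkpoint = None        # optional local checkpoint override
    #   config.model.torch_dtype     = "bfloat16"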

    def generate_svg(self, batch, generate_config):
        prompt_start = "<image-start>"
        model_inputs_vllm = []
        # Rasterize each ground-truth SVG and pass the resulting image as the
        # multimodal input, prompting the model to generate SVG code from it.
        for sample in batch['Svg']:
            image = rasterize_svg(sample, self.config.dataset.im_size)
            model_inputs_vllm.append({
                "prompt": prompt_start,
                "multi_modal_data": {"image": image}
            })
        sampling_params = SamplingParams(
            temperature=generate_config['temperature'],
            top_p=generate_config['top_p'],
            top_k=generate_config['top_k'],
            max_tokens=generate_config['max_length'],
            n=generate_config['num_generations'],
            frequency_penalty=generate_config['frequency_penalty'],
            repetition_penalty=generate_config['repetition_penalty'],
            presence_penalty=generate_config['presence_penalty'],
            min_p=generate_config['min_p'],
        )
        completions = self.llm.generate(model_inputs_vllm,
                                        sampling_params=sampling_params,
                                        use_tqdm=False)
        # Flatten the n generations per prompt into a single list of SVG strings.
        outputs = []
        for completion in completions:
            for generation in completion.outputs:
                outputs.append(generation.text)
        return outputs
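
    # Example of the generate_config dict this method expects -- the keys mirror the
    # lookups above; the values are illustrative, not defaults taken from the repo:
    #
    #   generate_config = {
    #       "temperature": 0.8, "top_p": 0.95, "top_k": 50,
    #       "max_length": 8192, "num_generations": 1,
    #       "frequency_penalty": 0.0, "repetition_penalty": 1.0,
    #       "presence_penalty": 0.0, "min_p": 0.0,
    #   }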

    def get_dataloader(self, config):
        data = load_dataset(config.dataset.dataset_name, config.dataset.config_name, split=config.dataset.split)
        if config.dataset.num_samples != -1:
            data = data.select(range(config.dataset.num_samples))
        self.dataloader = DataLoader(data, batch_size=self.config.dataset.batch_size, shuffle=False, num_workers=self.config.dataset.num_workers)
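
    # Dataset config fields read above (again inferred from this file only; values
    # are illustrative):
    #
    #   config.dataset.dataset_name = "starvector/svg-stack"  # hypothetical dataset id
    #   config.dataset.config_name  = None
    #   config.dataset.split        = "test"
    #   config.dataset.num_samples  = -1      # -1 keeps the full split
    #   config.dataset.batch_size   = 8
    #   config.dataset.num_workers  = 4
    #   config.dataset.im_size      = 224     # used by rasterize_svg in generate_svg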

    def release_memory(self):
        if self.llm is not None:
            # Delete the LLM instance
            del self.llm
            self.llm = None
            # Force garbage collection
            import gc
            gc.collect()
            # If using PyTorch, you can also explicitly clear CUDA cache
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _handle_stream_response(self, response):
        # Accumulate streamed chunks (OpenAI-style delta objects) onto the "<svg"
        # prefix, since the prompt ends where the SVG code begins.
        generated_text = "<svg"
        for chunk in response:
            new_text = chunk.choices[0].delta.content if chunk.choices[0].delta.content else ""
            generated_text += new_text
        return generated_text
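
# A rough usage sketch, assuming a config object shaped like the fields noted above
# and that SVGValidator requires nothing beyond what this file uses (both assumptions;
# the real entry point lives elsewhere in the repo):
#
#   validator = StarVectorVLLMValidator(config)
#   for batch in validator.dataloader:
#       svgs = validator.generate_svg(batch, generate_config)
#       # ... rasterize / score the generated SVGs here ...
#   validator.release_memory()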