# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

"""Generate edited SVGs for SVGEditBench with a LoRA-tuned Llama 3.1 model.

Only standard-library modules are imported at module level; the model stack
(torch, transformers, peft, ...) is imported inside ``main_infer`` so the SVG
helper functions can be imported and tested without it.
"""

import json
import os
import re
from collections import defaultdict
from typing import List, Optional

# NOTE(review): the markup literals of this file (everything between "<" and
# ">") appear to have been stripped from the source; these minimal, valid
# values are reconstructions — confirm against the intended placeholder/header.
_SVG_OPEN_TAG = '<svg xmlns="http://www.w3.org/2000/svg">'
_EMPTY_SVG = _SVG_OPEN_TAG + "</svg>"
_XML_DECLARATION = '<?xml version="1.0" encoding="UTF-8"?>'

# First complete <svg ...>...</svg> element, across newlines, any letter case.
_SVG_BLOCK_RE = re.compile(r"<svg[^>]*>.*?</svg>", re.DOTALL | re.IGNORECASE)
# Lines that belong to the document wrapper (XML declaration, <svg ...>, </svg>).
_SVG_WRAPPER_LINE_RE = re.compile(r"<\s*/?\s*svg\b|<\?xml", re.IGNORECASE)
# A line consisting of exactly one tag such as <path .../> or </g>.
_SINGLE_TAG_LINE_RE = re.compile(r"\s*<[^>]+/?>\s*")

# Top-level entries of the benchmark root that are not task directories.
_NON_TASK_ENTRIES = frozenset(
    {"LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py"}
)


def extract_svg_from_text(text: str) -> Optional[str]:
    """Return the first complete ``<svg ...>...</svg>`` element found in *text*.

    Matching is case-insensitive and may span lines. If *text* contains no
    such element, a minimal empty SVG document is returned instead, so the
    result is never ``None`` despite the ``Optional`` annotation.
    """
    match = _SVG_BLOCK_RE.search(text)
    return match.group(0) if match else _EMPTY_SVG


def code_style_prompt(desc: str) -> str:
    """Build a CodeLlama-style prompt asking for an SVG contour drawing of *desc*.

    The description is interpolated twice: in the header and in the final
    "GENERATION START" line that the model is expected to continue from.
    """
    # NOTE(review): the original prompt carried an example SVG right after the
    # "SVG Example(DESCRIPTION=wheelchair)" line; it appears to have been
    # stripped from this file and could not be recovered.
    return f"""\
// SVG CODE GENERATION TASK FOR CODELLAMA
// OBJECTIVE: Create simple yet accurate SVG contour drawing
// DESCRIPTION: {desc}
// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
// CODE GENERATION INSTRUCTIONS:
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
2. Fill path data for main-outline using basic commands
3. Position eye element at logical position
4. Keep all coordinates within viewBox
5. Use 2 decimal precision for coordinates
6. Close all path elements properly
// {desc} GENERATION START FROM HERE:
"""


def post_process(code: str) -> str:
    """Post-process raw code-model output into a standalone SVG document.

    Steps:
      1. If the text contains a closed ``<svg...</svg>`` block, keep only the
         first one (any surrounding prose is dropped).
      2. Prepend an XML declaration if the text has none.
      3. For every entry of the required-element table that occurs fewer
         times than required, insert an empty self-closing placeholder just
         before the last ``</svg>`` (nothing is inserted if there is none).

    Returns the result with surrounding whitespace stripped.
    """
    svg_match = re.search(r"<svg.*?</svg>", code, re.DOTALL)
    if svg_match:
        code = svg_match.group(0)

    if "<?xml" not in code:
        code = _XML_DECLARATION + "\n" + code

    # NOTE(review): the original required-element table did not survive in
    # this file; a single <path> (the main outline) is assumed — confirm.
    required_elements = {"path": 1}
    for elem, min_count in required_elements.items():
        if len(re.findall(rf"<{elem}\b", code)) < min_count:
            head, closing, tail = code.rpartition("</svg>")
            if closing:
                code = f"{head}\n<{elem} />\n{closing}{tail}"
    return code.strip()


def strict_svg_postprocess(raw_code: str, max_repeats: int = 1) -> str:
    """Rebuild a clean SVG document from raw model output, line by line.

    Processing:
      1. Split the text into lines and strip each one.
      2. Drop wrapper lines (``<?xml ...>``, ``<svg ...>``, ``</svg>``) and
         every line that is not exactly one complete tag ``<...>`` — this
         removes chatty prose and half-written tags.
      3. Keep each distinct line at most *max_repeats* times (default 1),
         presumably to suppress lines the sampler repeated.
      4. Wrap the surviving lines in a standard ``<svg>`` header and footer.

    Caveat: with the default of 1, a second identical closing tag such as a
    repeated ``</g>`` is dropped too, which can unbalance nested groups.
    """
    lines = [line.strip() for line in raw_code.strip().split("\n")]

    kept: List[str] = []
    seen = defaultdict(int)
    for line in lines:
        # The wrapper is re-added below; an inner XML declaration is invalid.
        if _SVG_WRAPPER_LINE_RE.match(line):
            continue
        if not _SINGLE_TAG_LINE_RE.fullmatch(line):
            continue
        if seen[line] >= max_repeats:
            continue
        kept.append(line)
        seen[line] += 1

    core_content = "\n".join(kept)
    return f"{_SVG_OPEN_TAG}\n{core_content}\n</svg>"


def load_label_names(json_path: str) -> dict:
    """Load the class-label names table from a dataset-info JSON file.

    Returns ``data["dataset_info"]["features"][0]["dtype"]["class_label"]["names"]``
    unchanged. The file layout presumably follows a Hugging Face datasets
    ``dataset_info`` export — verify against the actual file.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return data["dataset_info"]["features"][0]["dtype"]["class_label"]["names"]


def main_infer(
    root: str = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b",
    base_model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
    lora_name: str = "steve329/llama3-8B-edit-lora-12k",
    max_content_chars: int = 4383,
    max_length: int = 4096,
) -> None:
    """Run the LoRA-tuned model over every task directory of the benchmark.

    For each sub-directory ``<root>/<task>`` every file in ``<task>/query`` is
    read as a prompt and sent to the model. The first ``<svg>...</svg>`` block
    of the completion is written to ``<task>/generated_svg/<stem>.svg``.
    Prompts longer than *max_content_chars* characters are not sent to the
    model: an empty placeholder SVG is written instead and their stems are
    listed in ``<task>/skipped_file.txt``.

    Args:
        root: benchmark root containing one directory per edit task.
        base_model_name: Hugging Face id of the base causal LM.
        lora_name: Hugging Face id of the LoRA adapter applied on top of it.
        max_content_chars: prompts longer than this (characters) are skipped.
        max_length: forwarded to ``generate(max_length=...)``; it bounds the
            prompt and completion tokens together, not the completion alone.
    """
    # Heavy model-stack dependencies are imported here instead of at module
    # level so the SVG helpers above stay importable without them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Unused legacy dependencies of the original script (the old Llama.build /
    # load_dataset code path); kept so the runtime requirements are unchanged.
    from datasets import load_dataset  # noqa: F401
    from llama import Llama  # noqa: F401

    # The Meta checkpoints are gated: take the access token from the
    # environment instead of hard-coding it (unset -> library default lookup).
    hf_token = os.environ.get("HF_TOKEN")
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, token=hf_token)
    # Apply the LoRA adapter weights on top of the base model.
    model = PeftModel.from_pretrained(base_model, lora_name)
    model.eval()
    # The tokenizer must match the base model and needs the same token.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for task in sorted(os.listdir(root)):
        print(task)
        task_dir = os.path.join(root, task)
        # Skip the benchmark's license/readme/script files and any other
        # non-directory entry (os.listdir on a file would raise).
        if task in _NON_TASK_ENTRIES or not os.path.isdir(task_dir):
            continue

        output_dir = os.path.join(task_dir, "generated_svg")
        os.makedirs(output_dir, exist_ok=True)
        query_dir = os.path.join(task_dir, "query")
        skipped: List[str] = []

        for query_name in sorted(os.listdir(query_dir)):
            stem = os.path.splitext(query_name)[0]
            with open(os.path.join(query_dir, query_name), "r", encoding="utf-8") as query_file:
                content = query_file.read()
            print(content)
            svg_path = os.path.join(output_dir, stem + ".svg")

            if len(content) > max_content_chars:
                # Too long for the model: still emit a (placeholder) output so
                # every query has a file, and record the query as skipped.
                with open(svg_path, "w", encoding="utf-8") as svg_file:
                    svg_file.write(_EMPTY_SVG)
                print(f"SVG文件已保存至: {svg_path}")
                skipped.append(stem)
                continue

            inputs = tokenizer(content, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=max_length,  # prompt + completion tokens combined
                    do_sample=True,  # sampling gives more varied outputs
                    top_k=50,
                    top_p=0.95,
                )
            # Decode only the newly generated tokens (drop the echoed prompt).
            generated_text = tokenizer.decode(
                generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
            )
            print("raw_code:")
            print(generated_text)

            final_code = extract_svg_from_text(generated_text)
            print(f"\n=== Input: {stem} ===")
            print("// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            with open(svg_path, "w", encoding="utf-8") as svg_file:
                svg_file.write(final_code)
            print(f"SVG文件已保存至: {svg_path}")

        if skipped:
            # Written once per task so every skipped query is listed (reopening
            # in "w" mode per skip would keep only the last name).
            # NOTE(review): previously this went to a sibling
            # "SVGEditBench_clean/<task>" tree; it now sits next to the outputs.
            with open(os.path.join(task_dir, "skipped_file.txt"), "w", encoding="utf-8") as log:
                log.writelines(f"{name}\n" for name in skipped)


if __name__ == "__main__":
    import fire

    # With no CLI arguments this is the same as main_infer(); fire also exposes
    # its keyword parameters as --flags.
    fire.Fire(main_infer)