# File size: 9,761 Bytes
# dd89c43
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import fire
import json
import re
from collections import defaultdict
from datasets import load_dataset
from typing import Optional, List
from llama import Llama
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def extract_svg_from_text(text: str) -> Optional[str]:
    """Return the first complete ``<svg ...>...</svg>`` element found in *text*.

    The search is case-insensitive and lets the element span several lines.
    When no element is found, an empty placeholder SVG with a 36x36 viewBox
    is returned instead, so callers always receive a string.
    """
    found = re.search(r"<svg\b[^>]*>.*?</svg>", text, re.DOTALL | re.IGNORECASE)
    if found is None:
        # Nothing usable in the text: fall back to the empty placeholder.
        return """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
    return found.group(0)
def code_style_prompt(desc: str) -> str:
    """Build the code-comment-styled SVG generation prompt for *desc*.

    The prompt states the task, embeds a worked wheelchair SVG as a one-shot
    example, lists the generation instructions, and ends with a marker line
    after which the model is expected to continue. *desc* is interpolated
    twice: in the DESCRIPTION line and in the final marker line.
    """
    prompt = f"""\
// SVG CODE GENERATION TASK FOR CODELLAMA
// OBJECTIVE: Create simple yet accurate SVG contour drawing
// DESCRIPTION: {desc}
// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
<!-- Wheelchair seat -->
<path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>
<!-- Wheelchair back -->
<path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>
<!-- Large wheel -->
<circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>
<circle cx="65" cy="65" r="5" fill="#333"/>
<!-- Small wheel -->
<circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>
<circle cx="30" cy="70" r="3" fill="#333"/>
<!-- Wheel spokes (large wheel) -->
<line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>
<!-- Wheel spokes (small wheel) -->
<line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>
<line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>
</svg>
// CODE GENERATION INSTRUCTIONS:
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
1. Fill path data for main-outline using basic commands
2. Position eye element at logical position
3. Keep all coordinates within viewBox
4. Use 2 decimal precision for coordinates
5. Close all path elements properly
// {desc} GENERATION START FROM HERE:
"""
    return prompt
def post_process(code: str) -> str:
    """Clean up raw model output into a standalone SVG document.

    Steps:
      1. If the text contains a complete ``<svg ...</svg>`` span, keep only
         the first such span; otherwise the text is left as it is.
      2. Prepend an XML declaration unless one is already present.
      3. If no ``<path`` or no ``<circle`` element is present, insert an
         empty one (preceded by an explanatory comment) just before each
         ``</svg>``. Text without a closing ``</svg>`` is left untouched by
         this step.

    Args:
        code: Raw text produced by the model.

    Returns:
        The processed text with surrounding whitespace stripped.
    """
    # Extract the first closed SVG block, if any.
    svg_match = re.search(r'<svg.*?</svg>', code, re.DOTALL)
    if svg_match:
        code = svg_match.group(0)

    # Make sure the document carries an XML declaration.
    if '<?xml' not in code:
        code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code

    # Only child elements can sensibly be auto-inserted. Use the bare tag
    # name so the inserted markup is `<path />`, not the malformed `<<path />`.
    for tag in ('path', 'circle'):
        if f'<{tag}' not in code:
            code = code.replace(
                '</svg>',
                f'<!-- Auto-added {tag} -->\n<{tag} />\n</svg>',
            )

    return code.strip()
def strict_svg_postprocess(raw_code: str) -> str:
    """Rebuild an SVG document from the single-tag lines of *raw_code*.

    Processing:
      1. Strip the text, split it into lines and strip every line; start at
         the first line that does not begin with ``<svg`` (or at the first
         line when every line does).
      2. Walk the remaining lines, dropping opening/closing ``svg`` tags,
         lines that are not exactly one ``<...>`` tag, and any line that has
         already been kept (each distinct line is kept at most once).
      3. Wrap the kept lines in a fixed 100x100 ``<svg>`` header and a
         closing ``</svg>``.
    """
    stripped_lines = [piece.strip() for piece in raw_code.strip().split('\n')]

    # Stage 1: index of the first line that is not an <svg ...> line (0 if none).
    first_content = next(
        (
            idx
            for idx, text in enumerate(stripped_lines)
            if not text.lower().startswith("<svg")
        ),
        0,
    )

    # Stage 2: filter line by line.
    kept = []
    seen = set()
    for text in stripped_lines[first_content:]:
        # The wrapper tags are re-added at the end, so never keep them here.
        if re.match(r'</\s*svg\s*>', text, re.IGNORECASE):
            continue
        if re.match(r'<\s*svg', text, re.IGNORECASE):
            continue
        # Skip exact duplicates of a line that was already accepted.
        if text in seen:
            continue
        # Accept only lines made of a single complete tag.
        if re.fullmatch(r'\s*<[^>]+/?>\s*', text, re.IGNORECASE):
            kept.append(text)
            seen.add(text)

    # Stage 3: assemble the final document.
    inner = '\n'.join(kept)
    return f'''<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
{inner}
</svg>'''
def load_label_names(json_path: str) -> dict:
    """Load the label-name table from a JSON file.

    The file is read as UTF-8 and the value stored under
    ``dataset_info -> features[0] -> dtype -> class_label -> names`` is
    returned unchanged. A missing key or an empty ``features`` list
    propagates as ``KeyError`` / ``IndexError``.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)
    first_feature = payload['dataset_info']['features'][0]
    return first_feature['dtype']['class_label']['names']
def main_infer():
    """Generate an SVG for every query file under the benchmark root.

    Loads the base causal LM, applies the LoRA adapter, and then, for each
    task directory under ``root``:

      * reads every file in ``<task>/query`` as a UTF-8 prompt,
      * for prompts longer than ``max_prompt_chars`` characters, writes an
        empty placeholder SVG and records the file name in the task's
        ``skipped_file.txt`` instead of running the model,
      * otherwise samples a completion, extracts the first ``<svg>`` element
        from it and writes it to ``<task>/generated_svg/<name>.svg``.

    Progress, prompts and raw completions are printed to stdout.
    """
    root = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b"
    # The skip log is written under a different dataset tree than `root`,
    # exactly as before. NOTE(review): confirm this is intentional and that
    # the per-task directories exist there.
    skipped_root = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean"
    # Top-level entries of the benchmark that are not task directories.
    non_task_entries = {"LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py"}
    # Prompts longer than this many characters are not sent to the model.
    max_prompt_chars = 4383
    placeholder_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""

    # Load the base model (adjust the name or path for your setup).
    # NOTE(review): the access token is an empty string; supply a real one
    # if the checkpoint requires authentication.
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token="")
    # Apply the LoRA weights on top of the base model.
    model = PeftModel.from_pretrained(base_model, "steve329/llama3-8B-edit-lora-12k")
    model.eval()
    # The tokenizer must match the base model.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for task_name in os.listdir(root):
        print(task_name)
        if task_name in non_task_entries:
            continue
        task_dir = os.path.join(root, task_name)
        if not os.path.isdir(task_dir):
            # Any other stray file would make os.listdir() below fail.
            continue

        output_dir = os.path.join(task_dir, 'generated_svg')
        os.makedirs(output_dir, exist_ok=True)
        query_dir = os.path.join(task_dir, 'query')
        skipped_log = os.path.join(skipped_root, task_name, 'skipped_file.txt')
        # The first skip in a task truncates the log, later skips append, so
        # the log lists every skipped file of this run rather than the last.
        skipped_mode = 'w'

        for query_file in os.listdir(query_dir):
            file_name = os.path.splitext(query_file)[0]
            svg_path = os.path.join(output_dir, file_name + '.svg')
            with open(os.path.join(query_dir, query_file), "r", encoding="utf-8") as prompt_file:
                content = prompt_file.read()
            print(content)

            if len(content) > max_prompt_chars:
                # Prompt too long: emit a placeholder and log the skip.
                with open(svg_path, 'w', encoding='utf-8') as svg_file:
                    svg_file.write(placeholder_svg)
                print(f"SVG文件已保存至: {svg_path}")
                with open(skipped_log, skipped_mode, encoding='utf-8') as f:
                    f.write(f"{file_name}" + "\n")
                skipped_mode = 'a'
                continue

            inputs = tokenizer(content, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)
            # Sample a completion (tune the generation parameters as needed).
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=4096,  # upper bound on prompt + generated tokens
                    do_sample=True,   # sample for more varied results
                    top_k=50,         # top-k sampling
                    top_p=0.95,       # top-p (nucleus) sampling
                )
            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_len = inputs.input_ids.shape[-1]
            generated_text = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

            print("raw_code:")
            print(generated_text)
            final_code = extract_svg_from_text(generated_text)

            print(f"\n=== Input: {file_name} ===")
            print(f"// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            with open(svg_path, 'w', encoding='utf-8') as svg_file:
                svg_file.write(final_code)
            print(f"SVG文件已保存至: {svg_path}")
# Script entry point: run the inference loop when executed directly.
if __name__ == "__main__":
    main_infer()