# File size: 9,761 Bytes
# dd89c43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import fire
import json
import re
from collections import defaultdict
from datasets import load_dataset
from typing import Optional, List
from llama import Llama
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch



def extract_svg_from_text(text: str) -> Optional[str]:
    """
    Pull the first complete ``<svg ...>...</svg>`` element out of ``text``.

    Matching is case-insensitive and spans newlines. When ``text`` holds no
    SVG element, a default empty SVG (viewBox ``0 0 36 36``) is returned
    instead, so the result is always a string.
    """
    fallback_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
    first_svg = re.search(r"<svg\b[^>]*>.*?</svg>", text, re.DOTALL | re.IGNORECASE)
    return first_svg.group(0) if first_svg else fallback_svg


def code_style_prompt(desc: str) -> str:
    """
    Build the code-comment-styled prompt asking the model to draw ``desc`` as SVG.

    The prompt contains the task header, the description, one worked example
    (a wheelchair), a numbered list of generation instructions and a final
    marker line after which the model is expected to continue.

    Args:
        desc: Short text description of the object to draw; it is inserted
            verbatim in the DESCRIPTION line and in the final marker line.

    Returns:
        The full prompt text, ending with a newline.
    """
    # The instruction list is numbered 1-6; it previously carried two "1." items.
    return f"""\
// SVG CODE GENERATION TASK FOR CODELLAMA
// OBJECTIVE: Create simple yet accurate SVG contour drawing
// DESCRIPTION: {desc}

// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
  <!-- Wheelchair seat -->
  <path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>

  <!-- Wheelchair back -->
  <path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>

  <!-- Large wheel -->
  <circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>
  <circle cx="65" cy="65" r="5" fill="#333"/>

  <!-- Small wheel -->
  <circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>
  <circle cx="30" cy="70" r="3" fill="#333"/>

  <!-- Wheel spokes (large wheel) -->
  <line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>

  <!-- Wheel spokes (small wheel) -->
  <line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>
  <line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>
</svg>


// CODE GENERATION INSTRUCTIONS:
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
2. Fill path data for main-outline using basic commands
3. Position eye element at logical position
4. Keep all coordinates within viewBox
5. Use 2 decimal precision for coordinates
6. Close all path elements properly

// {desc} GENERATION START FROM HERE:
"""


def post_process(code: str) -> str:
    """
    Normalize raw code-model output into a standalone SVG document.

    Steps:
    1. Keep only the first complete ``<svg ...>...</svg>`` element. If there
       is none, repair the wrapper: a missing opening tag is prepended and a
       missing closing tag is appended.
    2. Make sure at least one ``<path`` and one ``<circle`` element exist by
       inserting an empty placeholder right before the closing ``</svg>``.
    3. Prepend an XML declaration when the text does not contain one.

    Args:
        code: Raw text produced by the model.

    Returns:
        The cleaned SVG text with surrounding whitespace stripped.
    """
    svg_pattern = re.compile(r'<svg.*?</svg>', re.DOTALL)

    svg_match = svg_pattern.search(code)
    if svg_match is None:
        # No closed element: drop a leading XML declaration first so the
        # synthesized opening tag is not inserted ahead of it, then add
        # whichever half of the wrapper is missing.
        repaired = re.sub(r'^\s*<\?xml.*?\?>', '', code, count=1, flags=re.DOTALL).strip()
        if '<svg' not in repaired:
            repaired = '<svg xmlns="http://www.w3.org/2000/svg">\n' + repaired
        if svg_pattern.search(repaired) is None:
            repaired += '\n</svg>'
        svg_match = svg_pattern.search(repaired)
    code = svg_match.group(0)

    # Insert placeholders for required drawing elements. The tag name is
    # interpolated without its "<" so the output is "<path />" rather than
    # the malformed "<<path />".
    for tag in ('path', 'circle'):
        if f'<{tag}' not in code:
            body, closing, rest = code.rpartition('</svg>')
            code = f'{body}<!-- Auto-added {tag} -->\n<{tag} />\n{closing}{rest}'

    # Ensure the XML declaration is present.
    if '<?xml' not in code:
        code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code

    return code.strip()


def strict_svg_postprocess(raw_code: str) -> str:
    """
    Rebuild model output as an SVG made only of clean, unique single-tag lines.

    Processing:
    1. Strip the input, split it into lines and strip every line.
    2. Drop every line that begins an ``<svg`` tag or is a ``</svg>`` tag.
    3. Keep a line only if the whole line is one tag (``<...>``) and the
       identical line has not been kept before.
    4. Wrap the surviving lines in a fixed 100x100 ``<svg>`` header and a
       closing ``</svg>``.
    """
    svg_close = re.compile(r'</\s*svg\s*>', re.IGNORECASE)
    svg_open = re.compile(r'<\s*svg', re.IGNORECASE)
    single_tag = re.compile(r'\s*<[^>]+/?>\s*', re.IGNORECASE)

    kept = []
    seen = set()
    for raw_line in raw_code.strip().split('\n'):
        candidate = raw_line.strip()
        # The wrapper is regenerated below, so any svg open/close line goes.
        if svg_close.match(candidate) or svg_open.match(candidate):
            continue
        # Only whole-line tags survive, and each distinct line only once.
        if candidate in seen or not single_tag.fullmatch(candidate):
            continue
        seen.add(candidate)
        kept.append(candidate)

    header = '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">'
    return '\n'.join([header, '\n'.join(kept), '</svg>'])


def load_label_names(json_path: str) -> dict:
    """
    Load the label mapping table from a dataset-info JSON file.

    Reads the UTF-8 file at ``json_path`` and returns the value stored under
    ``dataset_info -> features[0] -> dtype -> class_label -> names``.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        payload = json.loads(handle.read())
    first_feature = payload['dataset_info']['features'][0]
    return first_feature['dtype']['class_label']['names']


def main_infer(
    root: str = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b",
    max_prompt_chars: int = 4383,
) -> None:
    """
    Run the LoRA-tuned Llama model over every query file and save the SVG output.

    Expected layout: each task directory ``<root>/<task>/query/`` holds one
    prompt per file; the result for ``<name>.<ext>`` is written to
    ``<root>/<task>/generated_svg/<name>.svg``. Prompts longer than
    ``max_prompt_chars`` characters are not sent to the model: an empty
    placeholder SVG is written instead and the file name is appended to the
    task's ``skipped_file.txt``.

    Args:
        root: Benchmark root directory containing one sub-directory per task.
        max_prompt_chars: Maximum prompt length (in characters) that is still
            passed to the model.
    """
    # Load the base model, then apply the LoRA weights on top of it.
    # NOTE(review): the access token is an empty string here — supply a real
    # one if the base model repository requires authentication.
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token="")
    model = PeftModel.from_pretrained(base_model, "steve329/llama3-8B-edit-lora-12k")
    model.eval()
    # The tokenizer must match the base model.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    placeholder_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
    # Top-level repository files that are not task directories.
    non_task_entries = {"LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py"}
    # NOTE(review): the skip log lives under a different tree than ``root``
    # (no "_llama8b" suffix) — kept as in the original; confirm it is intended.
    skip_log_root = '/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean/'

    for task_name in os.listdir(root):
        print(task_name)
        if task_name in non_task_entries:
            continue

        task_dir = os.path.join(root, task_name)
        query_dir = os.path.join(task_dir, 'query')
        if not os.path.isdir(query_dir):
            # Not a task directory (stray file, or no queries): nothing to do.
            continue

        output_dir = os.path.join(task_dir, 'generated_svg')
        os.makedirs(output_dir, exist_ok=True)

        for query_name in os.listdir(query_dir):
            file_name = os.path.splitext(query_name)[0]
            svg_path = os.path.join(output_dir, file_name + '.svg')

            with open(os.path.join(query_dir, query_name), "r", encoding="utf-8") as query_file:
                content = query_file.read()
            print(content)

            if len(content) > max_prompt_chars:
                # Prompt too long: emit an empty SVG and record the skip.
                with open(svg_path, 'w', encoding='utf-8') as svg_file:
                    svg_file.write(placeholder_svg)
                print(f"SVG文件已保存至: {svg_path}")

                # Append, so earlier skipped names are not overwritten.
                skip_log = os.path.join(skip_log_root, task_name, 'skipped_file.txt')
                with open(skip_log, 'a', encoding='utf-8') as log_file:
                    log_file.write(f"{file_name}" + "\n")
                continue

            inputs = tokenizer(content, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            # Sample a completion. NOTE(review): max_length bounds prompt plus
            # generated tokens together — confirm 4096 leaves enough room.
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=4096,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                )

            # Decode only the newly generated tokens (everything after the prompt).
            prompt_len = input_ids.shape[-1]
            generated_text = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

            print("raw_code:")
            print(generated_text)
            final_code = extract_svg_from_text(generated_text)

            print(f"\n=== Input: {file_name} ===")
            print(f"// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            with open(svg_path, 'w', encoding='utf-8') as svg_file:
                svg_file.write(final_code)

            print(f"SVG文件已保存至: {svg_path}")


# Script entry point: run the inference loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main_infer()