steve329 committed on
Commit
dd89c43
·
verified ·
1 Parent(s): 240f5fd

Upload sample_llama3-8B.py

Browse files
Files changed (1) hide show
  1. sample_llama3-8B.py +270 -0
sample_llama3-8B.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3
+ import os
4
+ import fire
5
+ import json
6
+ import re
7
+ from collections import defaultdict
8
+ from datasets import load_dataset
9
+ from typing import Optional, List
10
+ from llama import Llama
11
+ from peft import PeftModel
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ import torch
14
+
15
+
16
+
17
def extract_svg_from_text(text: str) -> str:
    """Return the first complete ``<svg>...</svg>`` element found in *text*.

    The search is case-insensitive and spans newlines. The match is
    non-greedy, so it ends at the first closing ``</svg>`` tag after the
    opening tag.

    Args:
        text: Arbitrary text (e.g. raw model output) that may embed SVG markup.
            ``None`` or an empty string is treated as "no SVG present".

    Returns:
        The matched SVG markup, or a minimal empty 36x36 SVG document when
        no complete element is found. A string is always returned.
    """
    fallback_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""

    # Guard against None/empty input so callers always get a writable string
    # instead of a TypeError from the regex engine.
    if not text:
        return fallback_svg

    # Only the first match is ever used, so stop scanning as soon as it is
    # found instead of collecting every match in the text.
    match = re.search(r"<svg\b[^>]*>.*?</svg>", text, re.DOTALL | re.IGNORECASE)
    if match is None:
        return fallback_svg
    return match.group(0)
28
+
29
+
30
def code_style_prompt(desc: str) -> str:
    """Build the code-comment-style SVG generation prompt for *desc*.

    The prompt consists of a task header, a fixed wheelchair SVG used as a
    one-shot example, a numbered list of generation instructions, and a
    trailing marker line after which the model is expected to emit its SVG.

    Args:
        desc: Description of the object to draw; it is interpolated into the
            ``DESCRIPTION`` header line and into the final marker line.

    Returns:
        The complete prompt text, ending with a newline.
    """
    # The instruction list is numbered 1-6; the earlier template repeated
    # "1." for the first two items, which made the steps ambiguous.
    return f"""\
// SVG CODE GENERATION TASK FOR CODELLAMA
// OBJECTIVE: Create simple yet accurate SVG contour drawing
// DESCRIPTION: {desc}

// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
<!-- Wheelchair seat -->
<path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>

<!-- Wheelchair back -->
<path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>

<!-- Large wheel -->
<circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>
<circle cx="65" cy="65" r="5" fill="#333"/>

<!-- Small wheel -->
<circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>
<circle cx="30" cy="70" r="3" fill="#333"/>

<!-- Wheel spokes (large wheel) -->
<line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>
<line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>

<!-- Wheel spokes (small wheel) -->
<line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>
<line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>
</svg>


// CODE GENERATION INSTRUCTIONS:
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
2. Fill path data for main-outline using basic commands
3. Position eye element at logical position
4. Keep all coordinates within viewBox
5. Use 2 decimal precision for coordinates
6. Close all path elements properly

// {desc} GENERATION START FROM HERE:
"""
74
+
75
+
76
def post_process(code: str) -> str:
    """Normalize raw model output into a standalone SVG document string.

    Steps:
        1. If a complete ``<svg ...>...</svg>`` block is present, keep only
           the first one and discard any surrounding text.
        2. Prepend an XML declaration when the text does not contain one.
        3. Make sure at least one ``<path`` and one ``<circle`` element
           exist by inserting an empty placeholder element before the
           closing ``</svg>`` tag for each one that is missing.

    Args:
        code: Raw text produced by the model.

    Returns:
        The processed text with surrounding whitespace stripped. If the input
        has no closing ``</svg>`` tag, no placeholders can be inserted and
        only steps 1-2 apply.
    """
    # Keep only the first closed SVG block, if there is one.
    svg_match = re.search(r'<svg.*?</svg>', code, re.DOTALL)
    if svg_match:
        code = svg_match.group(0)

    # Ensure the document starts with an XML declaration.
    if '<?xml' not in code:
        code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code

    # Insert placeholders for missing child elements. The tag name is stored
    # without the leading '<' so the generated element is '<path />' rather
    # than the malformed '<<path />'. The '<svg' / '</svg>' wrapper tags are
    # deliberately not listed here: a missing wrapper cannot be repaired by
    # rewriting '</svg>', so trying only produced garbage or did nothing.
    for tag in ('path', 'circle'):
        if code.count(f'<{tag}') < 1:
            code = code.replace(
                '</svg>',
                f'<!-- Auto-added {tag} -->\n<{tag} />\n</svg>',
            )

    return code.strip()
100
+
101
+
102
def strict_svg_postprocess(raw_code: str, max_repeats: int = 1) -> str:
    """Rebuild an SVG document from the well-formed tag lines of *raw_code*.

    Processing is line based:
        1. Every line is stripped of surrounding whitespace.
        2. Lines that begin with an opening ``<svg`` or closing ``</svg>``
           tag are dropped, because a standard wrapper is added at the end.
        3. A line is kept only if the whole line is a single tag
           (``<...>`` or ``<... />``); partial tags and free text are dropped.
        4. Identical lines are kept at most ``max_repeats`` times.
        5. The surviving lines are wrapped in a fixed 100x100 ``<svg>``
           header and a closing ``</svg>``.

    Args:
        raw_code: Raw (possibly truncated or repetitive) SVG text.
        max_repeats: Maximum number of times an identical line may appear in
            the output. The default of 1 removes all exact duplicates.

    Returns:
        The reassembled SVG document as a string.
    """
    svg_open = re.compile(r'<\s*svg', re.IGNORECASE)
    svg_close = re.compile(r'</\s*svg\s*>', re.IGNORECASE)
    # One complete tag occupying the entire line.
    single_tag = re.compile(r'\s*<[^>]+/?>\s*')

    kept_lines = []
    times_kept = {}

    for raw_line in raw_code.strip().split('\n'):
        line = raw_line.strip()

        # Wrapper tags are regenerated below, so never copy them through.
        # Leading "<svg" lines need no separate pre-scan: they are dropped here.
        if svg_close.match(line) or svg_open.match(line):
            continue

        # Drop incomplete tags and anything that is not a lone tag.
        if not single_tag.fullmatch(line):
            continue

        # Limit runaway repetition of the same element.
        already_kept = times_kept.get(line, 0)
        if already_kept < max_repeats:
            kept_lines.append(line)
            times_kept[line] = already_kept + 1

    header = '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">'
    return '\n'.join([header, '\n'.join(kept_lines), '</svg>'])
151
+
152
+
153
def load_label_names(json_path: str) -> dict:
    """Load the class-label name table from a dataset-info JSON file.

    The file is parsed as UTF-8 JSON and the value stored at
    ``dataset_info -> features -> [0] -> dtype -> class_label -> names``
    is returned unchanged.

    Args:
        json_path: Path to the JSON file.

    Returns:
        Whatever object is stored under the ``names`` key.
    """
    key_path = ('dataset_info', 'features', 0, 'dtype', 'class_label', 'names')

    with open(json_path, 'r', encoding='utf-8') as handle:
        node = json.load(handle)

    # Descend one level per key/index until the names table is reached.
    for step in key_path:
        node = node[step]
    return node
158
+
159
+
160
def main_infer(
    root: str = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b",
    max_prompt_chars: int = 4383,
):
    """Run the LoRA-adapted Llama model over every query file under *root*.

    For each task directory in *root*, every file in its ``query``
    sub-directory is read and used verbatim as the prompt. The first
    ``<svg>...</svg>`` element extracted from the generated text is written
    to ``generated_svg/<query-file-stem>.svg`` in the same task directory.

    Queries longer than ``max_prompt_chars`` characters are not sent to the
    model: an empty placeholder SVG is written instead and the file stem is
    appended to that task's ``skipped_file.txt``.

    Args:
        root: Benchmark directory containing one sub-directory per task.
        max_prompt_chars: Character-length threshold above which a query is
            skipped.
    """
    placeholder_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
    # Repository housekeeping files that live next to the task directories.
    non_task_entries = {"LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py"}

    # Load the base model, then attach the LoRA adapter weights on top of it.
    # NOTE(review): the access token is an empty string here — supply a real
    # token if the base checkpoint requires authentication.
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token="")
    model = PeftModel.from_pretrained(base_model, "steve329/llama3-8B-edit-lora-12k")
    model.eval()
    # The tokenizer must match the base model.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for task_name in os.listdir(root):
        print(task_name)
        task_dir = os.path.join(root, task_name)
        # Skip known non-task files and anything else that is not a directory.
        if task_name in non_task_entries or not os.path.isdir(task_dir):
            continue

        output_dir = os.path.join(task_dir, 'generated_svg')
        os.makedirs(output_dir, exist_ok=True)

        query_dir = os.path.join(task_dir, 'query')
        for query_file in os.listdir(query_dir):
            file_name = os.path.splitext(query_file)[0]
            with open(os.path.join(query_dir, query_file), "r", encoding="utf-8") as query_handle:
                content = query_handle.read()
            print(content)

            svg_path = os.path.join(output_dir, file_name + '.svg')

            if len(content) > max_prompt_chars:
                # Over-long query: emit a placeholder so every query still has
                # an output file, and record that it was skipped.
                with open(svg_path, 'w', encoding='utf-8') as svg_file:
                    svg_file.write(placeholder_svg)
                print(f"SVG文件已保存至: {svg_path}")

                # Append, so earlier skipped names are not overwritten.
                # NOTE(review): this log lives under "SVGEditBench_clean", not
                # under *root* — confirm that the different directory is intended.
                skipped_log = '/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean/' + task_name + '/skipped_file.txt'
                with open(skipped_log, 'a', encoding='utf-8') as skipped_handle:
                    skipped_handle.write(f"{file_name}" + "\n")
                continue

            inputs = tokenizer(content, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=4096,  # upper bound on sequence length passed to generate()
                    do_sample=True,   # sample instead of greedy decoding
                    top_k=50,         # top-k sampling cutoff
                    top_p=0.95,       # nucleus (top-p) sampling cutoff
                )

            # Decode only the tokens that follow the prompt.
            prompt_token_count = inputs.input_ids.shape[-1]
            generated_text = tokenizer.decode(
                generated_ids[0][prompt_token_count:],
                skip_special_tokens=True,
            )

            print("raw_code:")
            print(generated_text)
            final_code = extract_svg_from_text(generated_text)

            print(f"\n=== Input: {file_name} ===")
            print("// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            with open(svg_path, 'w', encoding='utf-8') as svg_file:
                svg_file.write(final_code)

            print(f"SVG文件已保存至: {svg_path}")
267
+
268
+
269
+ # Script entry point: run inference only when executed directly, not on import.
+ if __name__ == "__main__":
270
+ main_infer()