|
|
|
|
|
import os |
|
import fire |
|
import json |
|
import re |
|
from collections import defaultdict |
|
from datasets import load_dataset |
|
from typing import Optional, List |
|
from llama import Llama |
|
from peft import PeftModel |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
|
|
|
|
|
|
def extract_svg_from_text(text: str) -> Optional[str]:
    """Return the first complete ``<svg ...>...</svg>`` element found in *text*.

    Matching is case-insensitive and spans newlines. The match is non-greedy,
    so it ends at the first ``</svg>`` after the opening tag. When no element
    is found, a minimal empty 36x36 SVG document is returned instead, so this
    function never actually returns ``None`` despite the annotation.
    """
    fallback_svg = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""
    first = re.search(r"<svg\b[^>]*>.*?</svg>", text, flags=re.DOTALL | re.IGNORECASE)
    return first.group(0) if first is not None else fallback_svg
|
|
|
|
|
def code_style_prompt(desc: str) -> str:
    """Build a code-completion style prompt asking the model to draw *desc* as SVG.

    The prompt consists of:

    * a task header and the description,
    * a hand-written wheelchair SVG used as a one-shot example,
    * a numbered list of drawing instructions (numbered 1-6; an earlier
      version repeated "1."),
    * a final ``// {desc} GENERATION START FROM HERE:`` line so the model
      continues directly with SVG code.

    Args:
        desc: Free-text description of the object to draw.

    Returns:
        The prompt text, ending with a newline.
    """
    return f"""\
// SVG CODE GENERATION TASK FOR CODELLAMA
// OBJECTIVE: Create simple yet accurate SVG contour drawing
// DESCRIPTION: {desc}

// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):
<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">
  <!-- Wheelchair seat -->
  <path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>

  <!-- Wheelchair back -->
  <path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>

  <!-- Large wheel -->
  <circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>
  <circle cx="65" cy="65" r="5" fill="#333"/>

  <!-- Small wheel -->
  <circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>
  <circle cx="30" cy="70" r="3" fill="#333"/>

  <!-- Wheel spokes (large wheel) -->
  <line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>
  <line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>

  <!-- Wheel spokes (small wheel) -->
  <line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>
  <line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>
</svg>


// CODE GENERATION INSTRUCTIONS:
1. Figure out the main parts of the object(animal) according to the DESCRIPTION
2. Fill path data for main-outline using basic commands
3. Position eye element at logical position
4. Keep all coordinates within viewBox
5. Use 2 decimal precision for coordinates
6. Close all path elements properly

// {desc} GENERATION START FROM HERE:
"""
|
|
|
|
|
def post_process(code: str) -> str:
    """Best-effort repair of raw code-model output into a standalone SVG document.

    Steps:

    1. Keep only the first ``<svg ...>...</svg>`` span, if one exists.
    2. If no ``<svg`` tag is present at all, wrap the text in a minimal root
       element. If the root is opened but never closed, append ``</svg>``.
    3. Prepend an XML declaration unless one is already present.
    4. Ensure at least one ``<path`` and one ``<circle`` element exist by
       inserting empty placeholder elements just before the final ``</svg>``.

    Args:
        code: Raw text generated by the model.

    Returns:
        The repaired SVG text with surrounding whitespace stripped.
    """
    svg_match = re.search(r'<svg.*?</svg>', code, re.DOTALL)
    if svg_match:
        code = svg_match.group(0)

    # Repair the root element first so the placeholders below always have a
    # closing tag to anchor to.
    if '<svg' not in code:
        # Drop any stray declaration so it cannot end up inside the new root.
        body = re.sub(r'<\?xml.*?\?>', '', code, flags=re.DOTALL).strip()
        code = f'<svg xmlns="http://www.w3.org/2000/svg">\n{body}\n</svg>'
    elif '</svg>' not in code:
        code = code.rstrip() + '\n</svg>'

    if '<?xml' not in code:
        code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code

    for tag in ('path', 'circle'):
        if f'<{tag}' not in code:
            # Insert before the LAST closing tag only (rpartition); a plain
            # str.replace would hit every nested </svg>.
            head, closing, tail = code.rpartition('</svg>')
            code = f'{head}<!-- Auto-added <{tag} -->\n<{tag} />\n{closing}{tail}'

    return code.strip()
|
|
|
|
|
def strict_svg_postprocess(raw_code: str, max_repeats: int = 1) -> str:
    """Rebuild a clean SVG document from raw model output, line by line.

    Rules, applied to each stripped line of *raw_code*:

    1. Lines that open or close the root ``<svg>`` element are dropped; a
       fresh standard root is written around the result instead.
    2. XML declarations / processing instructions (``<?...?>``) and
       ``<!DOCTYPE ...>`` lines are dropped: they are only legal before the
       root element, so keeping them would produce invalid XML.
    3. A line is kept only if it consists of exactly one markup tag, e.g.
       ``<path d="..."/>`` or ``<!-- comment -->``. Free text and lines with
       several tags are discarded.
    4. An identical line is kept at most ``max_repeats`` times. The default
       of 1 removes exact duplicates, which suppresses repetition loops.

    Args:
        raw_code: Raw text produced by the model.
        max_repeats: Maximum number of times an identical line may be kept.
            Values below 1 keep nothing.

    Returns:
        The kept lines wrapped in a 100x100 ``<svg>`` root element.
    """
    kept_lines = []
    occurrences = {}

    for line in (raw.strip() for raw in raw_code.strip().split('\n')):
        # Root open/close tags: replaced by the standard root below.
        if re.match(r'</\s*svg\s*>', line, re.IGNORECASE):
            continue
        if re.match(r'<\s*svg', line, re.IGNORECASE):
            continue
        # Prolog-only constructs must not end up inside the root element.
        if re.match(r'<\?|<!doctype', line, re.IGNORECASE):
            continue
        # Exactly one tag per line; anything else is treated as noise.
        if not re.fullmatch(r'\s*<[^>]+/?>\s*', line, re.IGNORECASE):
            continue

        seen = occurrences.get(line, 0)
        if seen < max_repeats:
            kept_lines.append(line)
        occurrences[line] = seen + 1

    core_content = '\n'.join(kept_lines)
    return (
        '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">\n'
        f'{core_content}\n'
        '</svg>'
    )
|
|
|
|
|
def load_label_names(json_path: str) -> dict:
    """Load the class-label name table from a ``dataset_info`` JSON file.

    Returns the value stored at
    ``dataset_info -> features[0] -> dtype -> class_label -> names``.
    A missing key raises ``KeyError``; an empty ``features`` list raises
    ``IndexError``.

    NOTE(review): annotated as ``dict``, but the function returns whatever JSON
    value sits under ``names`` (possibly a list) — confirm against real files.
    """
    with open(json_path, encoding='utf-8') as handle:
        dataset_info = json.load(handle)['dataset_info']
    first_feature = dataset_info['features'][0]
    return first_feature['dtype']['class_label']['names']
|
|
|
|
|
# Top-level entries of the benchmark root that are not task directories.
_NON_TASK_ENTRIES = ("LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py")
# Placeholder written for queries that are too long to send to the model.
_PLACEHOLDER_SVG = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>"""


def _load_model_and_tokenizer(base_model_name, adapter_name, device, token):
    """Load the base LM, attach the LoRA adapter, switch to eval mode and move to *device*."""
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, token=token)
    model = PeftModel.from_pretrained(base_model, adapter_name)
    model.eval()
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=token)
    return model, tokenizer


def _generate_text(model, tokenizer, prompt, device, max_length):
    """Sample a continuation of *prompt* and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=max_length,  # total length: prompt tokens + new tokens
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
    # generate() returns prompt + continuation; slice the prompt tokens off.
    return tokenizer.decode(generated_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)


def _save_svg(path, svg_code):
    """Write *svg_code* to *path* as UTF-8 and report where it was saved."""
    with open(path, "w", encoding="utf-8") as svg_file:
        svg_file.write(svg_code)
    print(f"SVG文件已保存至: {path}")


def main_infer(
    root: str = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b",
    base_model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
    adapter_name: str = "steve329/llama3-8B-edit-lora-12k",
    max_prompt_chars: int = 4383,
    max_length: int = 4096,
    token: Optional[str] = None,
):
    """Run LoRA-adapted Llama inference over every task directory in *root*.

    Expected layout: ``<root>/<task>/query/<name>.<ext>`` holds the prompts.
    For each prompt, the first ``<svg>...</svg>`` in the model output is
    written to ``<root>/<task>/generated_svg/<name>.svg``.

    Prompts longer than *max_prompt_chars* characters are not run. Instead,
    an empty placeholder SVG is written, and their names are listed in
    ``<root>/<task>/skipped_file.txt`` (one per line, rewritten per task).

    Args:
        root: Benchmark root directory.
        base_model_name: Hugging Face id of the base causal LM.
        adapter_name: Hugging Face id of the PEFT/LoRA adapter.
        max_prompt_chars: Character-length cut-off above which a prompt is skipped.
        max_length: ``max_length`` passed to ``generate`` (prompt + new tokens).
        token: Hugging Face access token. ``None`` lets the library use its
            default credential lookup.

    Note:
        Sampling (``do_sample=True``) is not seeded, so outputs are not
        reproducible between runs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = _load_model_and_tokenizer(base_model_name, adapter_name, device, token)

    for task_name in sorted(os.listdir(root)):
        print(task_name)
        task_dir = os.path.join(root, task_name)
        if task_name in _NON_TASK_ENTRIES or not os.path.isdir(task_dir):
            continue

        query_dir = os.path.join(task_dir, "query")
        if not os.path.isdir(query_dir):
            print(f"Skipping {task_name}: no 'query' directory")
            continue

        output_dir = os.path.join(task_dir, "generated_svg")
        os.makedirs(output_dir, exist_ok=True)

        skipped_names = []
        for query_file in sorted(os.listdir(query_dir)):
            file_name = os.path.splitext(query_file)[0]
            with open(os.path.join(query_dir, query_file), "r", encoding="utf-8") as query_fh:
                content = query_fh.read()
            print(content)
            svg_path = os.path.join(output_dir, file_name + ".svg")

            if len(content) > max_prompt_chars:
                # Too long for the generation budget: emit a placeholder and record it.
                _save_svg(svg_path, _PLACEHOLDER_SVG)
                skipped_names.append(file_name)
                continue

            generated_text = _generate_text(model, tokenizer, content, device, max_length)
            print("raw_code:")
            print(generated_text)

            final_code = extract_svg_from_text(generated_text)
            print(f"\n=== Input: {file_name} ===")
            print("// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            _save_svg(svg_path, final_code)

        if skipped_names:
            # Written once per task with every skipped name, next to this
            # task's outputs under *root*.
            skipped_log = os.path.join(task_dir, "skipped_file.txt")
            with open(skipped_log, "w", encoding="utf-8") as log_fh:
                log_fh.writelines(name + "\n" for name in skipped_names)
|
|
|
|
|
if __name__ == "__main__":
    # `fire` turns main_infer's keyword arguments (if it has any) into CLI
    # flags, e.g. `--root=/path/to/bench`; with no flags this is equivalent
    # to calling main_infer() directly.
    fire.Fire(main_infer)