# tmp / sample_llama3-8B.py
# steve329's picture
# Upload sample_llama3-8B.py
# dd89c43 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import fire
import json
import re
from collections import defaultdict
from datasets import load_dataset
from typing import Optional, List
from llama import Llama
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def extract_svg_from_text(text: str) -> Optional[str]:
    """
    Pull the first complete <svg>...</svg> element out of ``text``.

    Matching is case-insensitive and spans newlines. If ``text`` contains no
    SVG element, a default empty SVG is returned instead.
    """
    svg_regex = re.compile(r"<svg\b[^>]*>.*?</svg>", re.DOTALL | re.IGNORECASE)
    found = svg_regex.search(text)
    if found is None:
        # Fallback placeholder so callers always receive a parsable SVG.
        return '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>'
    return found.group(0)
def code_style_prompt(desc: str) -> str:
    """Build the code-comment styled SVG generation prompt for ``desc``.

    The prompt consists of a task header, a fixed wheelchair SVG example and
    a numbered instruction list; ``desc`` is substituted into the header and
    into the final "GENERATION START" line. The text ends with a newline.
    """
    header_lines = [
        "// SVG CODE GENERATION TASK FOR CODELLAMA",
        "// OBJECTIVE: Create simple yet accurate SVG contour drawing",
        f"// DESCRIPTION: {desc}",
        "// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):",
    ]
    example_lines = [
        '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">',
        '<!-- Wheelchair seat -->',
        '<path d="M30,40 L50,40 L50,60 L30,60 Z" fill="#555"/>',
        '<!-- Wheelchair back -->',
        '<path d="M30,40 L20,30 L20,20 L30,20 L30,40" fill="#555"/>',
        '<!-- Large wheel -->',
        '<circle cx="65" cy="65" r="25" stroke="#333" stroke-width="3" fill="none"/>',
        '<circle cx="65" cy="65" r="5" fill="#333"/>',
        '<!-- Small wheel -->',
        '<circle cx="30" cy="70" r="10" stroke="#333" stroke-width="3" fill="none"/>',
        '<circle cx="30" cy="70" r="3" fill="#333"/>',
        '<!-- Wheel spokes (large wheel) -->',
        '<line x1="65" y1="65" x2="80" y2="65" stroke="#333" stroke-width="2"/>',
        '<line x1="65" y1="65" x2="65" y2="80" stroke="#333" stroke-width="2"/>',
        '<line x1="65" y1="65" x2="55" y2="75" stroke="#333" stroke-width="2"/>',
        '<line x1="65" y1="65" x2="55" y2="55" stroke="#333" stroke-width="2"/>',
        '<!-- Wheel spokes (small wheel) -->',
        '<line x1="30" y1="70" x2="38" y2="70" stroke="#333" stroke-width="2"/>',
        '<line x1="30" y1="70" x2="30" y2="78" stroke="#333" stroke-width="2"/>',
        '</svg>',
    ]
    instruction_lines = [
        "// CODE GENERATION INSTRUCTIONS:",
        "1. Figure out the main parts of the object(animal) according to the DESCRIPTION",
        "1. Fill path data for main-outline using basic commands",
        "2. Position eye element at logical position",
        "3. Keep all coordinates within viewBox",
        "4. Use 2 decimal precision for coordinates",
        "5. Close all path elements properly",
        f"// {desc} GENERATION START FROM HERE:",
    ]
    # Every line, including the last one, is newline-terminated.
    return "\n".join(header_lines + example_lines + instruction_lines) + "\n"
def post_process(code: str) -> str:
"""针对代码模型的输出优化后处理"""
# 提取闭合的SVG代码块
svg_match = re.search(r'<svg.*?</svg>', code, re.DOTALL)
if svg_match:
code = svg_match.group(0)
# 确保XML声明
if '<?xml' not in code:
code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code
# 验证必要元素
required_elements = {
'<svg': 1,
'</svg>': 1,
'<path': 1,
'<circle': 1
}
for elem, count in required_elements.items():
if code.count(elem) < count:
code = code.replace('</svg>',
f'<!-- Auto-added {elem} -->\n<{elem} />\n</svg>')
return code.strip()
def strict_svg_postprocess(raw_code: str, max_repeats: int = 1) -> str:
    """Rebuild an SVG document from the single-tag lines found in ``raw_code``.

    Processing:
      1. Split the input into lines and strip each one.
      2. Drop every line that starts with an opening ``<svg`` or a closing
         ``</svg>`` tag (case-insensitive).
      3. Keep a line only if it consists of exactly one ``<...>`` tag and the
         identical line has been kept fewer than ``max_repeats`` times.
      4. Wrap the surviving lines in a fixed 100x100 ``<svg>`` header and a
         closing ``</svg>``.

    Args:
        raw_code: Raw text produced by the model.
        max_repeats: Maximum number of times an identical line may be kept.
            The default of 1 keeps only the first occurrence.

    Returns:
        The reassembled SVG text.
    """
    kept_lines = []
    times_kept = {}

    for line in (part.strip() for part in raw_code.strip().split('\n')):
        # The wrapper tags are re-added below, so drop any the model produced.
        if re.match(r'</\s*svg\s*>', line, re.IGNORECASE):
            continue
        if re.match(r'<\s*svg', line, re.IGNORECASE):
            continue

        # Completeness check: the whole line must be one <...> tag.
        if not re.fullmatch(r'\s*<[^>]+/?>\s*', line, re.IGNORECASE):
            continue

        # De-duplication: cap how often the same line may appear.
        already_kept = times_kept.get(line, 0)
        if already_kept >= max_repeats:
            continue
        times_kept[line] = already_kept + 1
        kept_lines.append(line)

    core_content = '\n'.join(kept_lines)
    return (
        '<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" viewBox="0 0 100 100">\n'
        f'{core_content}\n'
        '</svg>'
    )
def load_label_names(json_path: str) -> dict:
    """Load the label-name table from a dataset-info JSON file.

    Reads the UTF-8 JSON file at ``json_path`` and returns the value stored at
    ``dataset_info -> features -> [0] -> dtype -> class_label -> names``.
    """
    with open(json_path, 'r', encoding='utf-8') as handle:
        node = json.load(handle)
    # Descend one level at a time along the fixed key path.
    for step in ('dataset_info', 'features', 0, 'dtype', 'class_label', 'names'):
        node = node[step]
    return node
def main_infer(
    root: str = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean_llama8b",
    max_prompt_chars: int = 4383,
):
    """Run the LoRA-tuned Llama model over every query file and save the SVGs.

    For each task directory under ``root`` the files in ``<task>/query`` are
    read, sent to the model as prompts, and the first ``<svg>...</svg>``
    element of the generated text is written to
    ``<task>/generated_svg/<query name>.svg``.

    Queries longer than ``max_prompt_chars`` characters are not sent to the
    model: an empty placeholder SVG is written instead and the query name is
    appended to the task's ``skipped_file.txt``.

    Args:
        root: Benchmark directory containing one sub-directory per task.
        max_prompt_chars: Character-length cut-off above which a query is
            skipped.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable; when it is unset the library's default credential lookup is
    used.
    """
    base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
    lora_adapter_name = "steve329/llama3-8B-edit-lora-12k"
    # NOTE(review): the skip log is written under a different benchmark root
    # than `root`; kept as in the original -- confirm this is intended.
    skipped_log_root = "/export/home2/zhanjun001/codellama/codellama/SVGEditBench_clean/"
    empty_svg = '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 36 36"></svg>'
    # Entries in the benchmark root that are not task directories.
    non_task_entries = {"LICENSE-CODE", "LICENSE-IMAGES", "README.md", "CaseGenerator.py"}

    # An empty-string token would be sent as an invalid credential, so fall
    # back to None (default credential lookup) when HF_TOKEN is not set.
    hf_token = os.environ.get("HF_TOKEN") or None

    # Load the base model, then attach the LoRA adapter weights on top of it.
    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, token=hf_token)
    model = PeftModel.from_pretrained(base_model, lora_adapter_name)
    model.eval()

    # The tokenizer must match the base model.
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for task_name in sorted(os.listdir(root)):
        print(task_name)
        if task_name in non_task_entries:
            continue

        query_dir = os.path.join(root, task_name, 'query')
        if not os.path.isdir(query_dir):
            # Stray file or directory without queries: nothing to process.
            print(f"No query directory found, skipping: {query_dir}")
            continue

        output_dir = os.path.join(root, task_name, 'generated_svg')
        os.makedirs(output_dir, exist_ok=True)

        for query_file in sorted(os.listdir(query_dir)):
            file_name = os.path.splitext(query_file)[0]
            svg_path = os.path.join(output_dir, file_name + '.svg')

            with open(os.path.join(query_dir, query_file), "r", encoding="utf-8") as query_handle:
                content = query_handle.read()
            print(content)

            if len(content) > max_prompt_chars:
                # Prompt too long: write a placeholder SVG and record the skip.
                with open(svg_path, 'w', encoding='utf-8') as svg_file:
                    svg_file.write(empty_svg)
                print(f"SVG文件已保存至: {svg_path}")
                skipped_log = os.path.join(skipped_log_root, task_name, 'skipped_file.txt')
                # Append so earlier skipped names are not overwritten.
                with open(skipped_log, 'a', encoding='utf-8') as log_file:
                    log_file.write(f"{file_name}\n")
                continue

            inputs = tokenizer(content, return_tensors="pt")
            input_ids = inputs.input_ids.to(device)
            attention_mask = inputs.attention_mask.to(device)

            # Sample a completion; gradients are not needed for inference.
            with torch.no_grad():
                generated_ids = model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_length=4096,  # maximum length of the generated sequence
                    do_sample=True,   # sampling gives more varied results
                    top_k=50,         # top-k sampling parameter
                    top_p=0.95,       # top-p (nucleus) sampling parameter
                )

            # Decode only the tokens that follow the prompt.
            prompt_len = inputs.input_ids.shape[-1]
            generated_text = tokenizer.decode(
                generated_ids[0][prompt_len:], skip_special_tokens=True
            )

            print("raw_code:")
            print(generated_text)
            final_code = extract_svg_from_text(generated_text)

            print(f"\n=== Input: {file_name} ===")
            print("// Generated SVG Code:")
            print(final_code)
            print("\n" + "=" * 40 + "\n")

            with open(svg_path, 'w', encoding='utf-8') as svg_file:
                svg_file.write(final_code)
            print(f"SVG文件已保存至: {svg_path}")
# Script entry point: run the inference loop only when executed directly.
if __name__ == "__main__":
    main_infer()