# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import fire
import json
import re
from collections import defaultdict
from datasets import load_dataset
from typing import Optional, List
from llama import Llama
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def extract_svg_from_text(text: str) -> Optional[str]:
    """Extract the first complete ``<svg ...>...</svg>`` element from *text*.

    Matching is case-insensitive and may span multiple lines. Because the
    match is non-greedy, the first closing ``</svg>`` ends the element, so a
    nested ``<svg>`` would be truncated at its own closing tag.

    Args:
        text: Arbitrary text (e.g. raw model output) that may contain SVG.

    Returns:
        The matched SVG element, or a minimal empty SVG document when no
        closed SVG element is found or *text* is empty/None.
    """
    empty_svg = '<svg xmlns="http://www.w3.org/2000/svg"></svg>'
    if not text:
        return empty_svg
    # `\b` keeps e.g. "<svgfoo>" from counting as an opening <svg> tag.
    match = re.search(r"<svg\b.*?</svg>", text, re.DOTALL | re.IGNORECASE)
    return match.group(0) if match else empty_svg
def code_style_prompt(desc: str) -> str:
    """Build the CodeLlama prompt that asks for an SVG contour drawing.

    The description is interpolated twice: in the ``// DESCRIPTION:`` header
    and in the final ``GENERATION START FROM HERE`` marker, after which the
    model is expected to continue. Every line, including the last, is
    terminated by a newline.

    Args:
        desc: Free-text description of the object to draw.

    Returns:
        The complete prompt text.
    """
    # NOTE(review): the instruction list numbers "1." twice, and the
    # "SVG Example" line is not followed by an example in this revision.
    # Both are kept verbatim because the prompt text is model input.
    prompt_lines = (
        "// SVG CODE GENERATION TASK FOR CODELLAMA",
        "// OBJECTIVE: Create simple yet accurate SVG contour drawing",
        f"// DESCRIPTION: {desc}",
        "// SVG Example(DESCRIPTION=wheelchair)(you do not need to generate an example as well):",
        "// CODE GENERATION INSTRUCTIONS:",
        "1. Figure out the main parts of the object(animal) according to the DESCRIPTION",
        "1. Fill path data for main-outline using basic commands",
        "2. Position eye element at logical position",
        "3. Keep all coordinates within viewBox",
        "4. Use 2 decimal precision for coordinates",
        "5. Close all path elements properly",
        f"// {desc} GENERATION START FROM HERE:",
    )
    return "\n".join(prompt_lines) + "\n"
def post_process(code: str) -> str:
    """Post-process raw code-model output into a standalone SVG document.

    Steps:
    1. Keep only the first closed ``<svg ...>...</svg>`` block when one is
       present, dropping surrounding chatter and code fences.
    2. Prepend an XML declaration if the text does not already contain one.
    3. Add the SVG namespace to the first ``<svg`` tag when it lacks an
       ``xmlns=`` attribute. The tag's original letter case is preserved so
       the opening and closing tags still match.

    Args:
        code: Raw text generated by the model.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    # 1. Extract the closed SVG block; DOTALL lets the match span lines.
    svg_match = re.search(r"<svg\b.*?</svg>", code, re.DOTALL | re.IGNORECASE)
    if svg_match:
        code = svg_match.group(0)

    # 2. Ensure an XML declaration is present.
    if "<?xml" not in code:
        code = '<?xml version="1.0" encoding="UTF-8"?>\n' + code

    # 3. Validate required parts.
    # NOTE(review): the original required-elements table was not recoverable
    # from this revision; only the root namespace requirement is enforced.
    root_tag = re.search(r"<svg\b[^>]*>", code, re.IGNORECASE)
    if root_tag and not re.search(r"\bxmlns\s*=", root_tag.group(0)):
        # Insert right after "<svg" so the original tag spelling is kept.
        insert_at = root_tag.start() + len("<svg")
        code = (
            code[:insert_at]
            + ' xmlns="http://www.w3.org/2000/svg"'
            + code[insert_at:]
        )
    return code.strip()
def strict_svg_postprocess(raw_code: str) -> str:
"""
严格按照需求设计的SVG后处理器
处理逻辑:
1. 按行处理,找到第一个不以