import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import Dataset, load_dataset
import re
from tqdm import tqdm
import numpy as np

new_model_max_length = 32768
# 1. Load model and tokenizer
model_name = "fdtn-ai/Foundation-Sec-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"修改前的 model_max_length: {tokenizer.model_max_length}")
tokenizer.model_max_length = new_model_max_length  # 修改最大长度
print(f"修改后的 model_max_length: {tokenizer.model_max_length}")

# 2. Check the model's maximum position embeddings
config = AutoConfig.from_pretrained(model_name)
print(f"Model max_position_embeddings: {config.max_position_embeddings}")

# If needed, update the model config (use with caution: raising this value
# alone does not teach the model to handle longer contexts)
if config.max_position_embeddings < new_model_max_length:
    config.max_position_embeddings = new_model_max_length
    print(f"Updated model max_position_embeddings to: {config.max_position_embeddings}")

# 3. Load dataset
def get_opencode_instructions():
    ocr_ds = load_dataset("nvidia/OpenCodeReasoning", "split_0")
    return ocr_ds

dataset = get_opencode_instructions()
print(f"Dataset keys: {dataset.keys()}")
train_dataset = dataset["split_0"]  # access the 'split_0' split
print(f"Train dataset length: {len(train_dataset)}")

# 4. Formatting function
def formatting_question_answer(i, question, answer):
    text = None
    think_match = re.match(r"<think>(.*?)</think>\n(.*)", answer, re.DOTALL)
    if think_match:
        think_content, assistant_content = think_match.groups()
        content = f"<think>\n{think_content.strip()}\n</think>\n\n{assistant_content.strip()}"
        chat_messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": content}
        ]
        
        text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=False
            # Remove template_kwargs unless required by the model
        )
    else:
        print(f"error:{i},{question}")
    
    return text
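
# Quick self-test of the formatter on a synthetic record (the question and
# answer strings here are made-up placeholders, not dataset content):
_demo = formatting_question_answer(
    -1,
    "Write a function that reverses a string.",
    "<think>\nSlicing with step -1 reverses a sequence.\n</think>\n"
    "def rev(s):\n    return s[::-1]",
)
print(f"Formatter self-test:\n{_demo}")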

# 5. Process the dataset and collect token-length statistics
def process_dataset(train_dataset, max_process_num):
    formatted_texts = []
    token_lengths = []
    for i in tqdm(range(min(max_process_num, len(train_dataset))), desc="Processing dataset"):
        record = train_dataset[i]
        question = record.get("input", "")
        answer = record.get("output", "")
        formatted_text = formatting_question_answer(i, question, answer)
        if formatted_text:
            formatted_texts.append(formatted_text)

            full_tokens = tokenizer.encode(formatted_text, add_special_tokens=True)
            token_lengths.append(len(full_tokens))
        else:
            print(f"Failed to format record {i}: {question[:50]}...")

    if not token_lengths:
        raise ValueError("No records were successfully formatted; check the dataset fields.")
    max_length = max(token_lengths)
    mean_length = np.mean(token_lengths)
    percentile_95 = np.percentile(token_lengths, 95)
    percentile_99 = np.percentile(token_lengths, 99)
    
    print("Token Length Statistics (with chat template):")
    print(f"Full sequence max length: {max_length}")
    print(f"Mean sequence length: {mean_length:.1f}")
    print(f"95th percentile: {percentile_95:.1f}")
    print(f"99th percentile: {percentile_99:.1f}")
    
    thresholds = [128, 192, 256, 512, 1024, 2048, 3072, 4096, 8192, 16384, 32768]
    recommended_length = thresholds[-1]  # fall back to the largest bucket
    for threshold in thresholds:
        if max_length <= threshold:
            recommended_length = threshold
            break
    print(f"Recommended max_seq_length: {recommended_length}")
    
    new_data = {"text": formatted_texts}
    new_dataset = Dataset.from_dict(new_data)

    del formatted_texts, token_lengths
    return new_dataset, recommended_length

train_dataset, max_seq_length = process_dataset(train_dataset, len(train_dataset))

print(f"Using max_seq_length: {max_seq_length}")

# 6. Print results
print(f"\nTotal formatted records (out of 100): {len(train_dataset)}")
if train_dataset:
    print(f"\nFirst formatted text:\n{train_dataset[0]}\n")
else:
    print("No records were successfully formatted.")
    
print(f"Dataset : {train_dataset}")
print(f"Dataset length2: {len(train_dataset)}")