import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import Dataset, load_dataset
import re
from tqdm import tqdm
import numpy as np
new_model_max_length = 32768  # target context window for fine-tuning
# 1. Load model and tokenizer
model_name = "fdtn-ai/Foundation-Sec-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"修改前的 model_max_length: {tokenizer.model_max_length}")
tokenizer.model_max_length = new_model_max_length # 修改最大长度
print(f"修改后的 model_max_length: {tokenizer.model_max_length}")
# 2. Check the model's maximum position embeddings
config = AutoConfig.from_pretrained(model_name)
print(f"Model max_position_embeddings: {config.max_position_embeddings}")
# Raise the limit in the config if needed (do this with care)
if config.max_position_embeddings < new_model_max_length:
    config.max_position_embeddings = new_model_max_length
    print(f"Updated model max_position_embeddings to: {config.max_position_embeddings}")
# 3. Load dataset
def get_opencode_instructions():
    ocr_ds = load_dataset("nvidia/OpenCodeReasoning", "split_0")
    return ocr_ds
dataset = get_opencode_instructions()
print(f"Dataset keys: {dataset.keys()}")
train_dataset = dataset["split_0"]  # this dataset exposes a single "split_0" split
print(f"Train dataset length: {len(train_dataset)}")
# 4. Formatting function
def formatting_question_answer(i, question, answer):
    text = None
    # OpenCodeReasoning outputs wrap the reasoning trace in <think>...</think>
    # and put the final solution after the closing tag.
    think_match = re.match(r"<think>(.*?)</think>\n(.*)", answer, re.DOTALL)
    if think_match:
        think_content, assistant_content = think_match.groups()
        content = f"<think>\n{think_content.strip()}\n</think>\n\n{assistant_content.strip()}"
        chat_messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": content},
        ]
        text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=False,
        )
    else:
        print(f"error: could not parse <think> block in record {i}: {question[:50]}...")
    return text
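# Smoke test with a synthetic record (not from the dataset) to confirm the
# <think> parsing and chat-template wiring before the full pass:
demo_text = formatting_question_answer(
    -1,
    "Print hello world in Python.",
    "<think>\nA one-liner is enough.\n</think>\nprint('hello world')",
)
print(demo_text)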
# 5. Process the dataset and collect token-length statistics
def process_dataset(train_dataset, max_process_num):
    formatted_texts = []
    token_lengths = []
    for i in tqdm(range(min(max_process_num, len(train_dataset))), desc="Processing dataset"):
        record = train_dataset[i]
        question = record.get("input", "")
        answer = record.get("output", "")
        formatted_text = formatting_question_answer(i, question, answer)
        if formatted_text:
            formatted_texts.append(formatted_text)
            full_tokens = tokenizer.encode(formatted_text, add_special_tokens=True)
            token_lengths.append(len(full_tokens))
        else:
            print(f"Failed to format record {i}: {question[:50]}...")
    max_length = max(token_lengths)
    mean_length = np.mean(token_lengths)
    percentile_95 = np.percentile(token_lengths, 95)
    percentile_99 = np.percentile(token_lengths, 99)
    print("Token Length Statistics (with chat template):")
    print(f"Full sequence max length: {max_length}")
    print(f"Mean sequence length: {mean_length:.1f}")
    print(f"95th percentile: {percentile_95:.1f}")
    print(f"99th percentile: {percentile_99:.1f}")
    # Pick the smallest bucket that covers the longest sequence; fall back to
    # the largest bucket if even 32768 is exceeded.
    recommended_length = 32768
    for threshold in [128, 192, 256, 512, 1024, 2048, 3072, 4096, 8192, 16384, 32768]:
        if max_length <= threshold:
            recommended_length = threshold
            break
    print(f"Recommended max_seq_length: {recommended_length}")
    new_data = {"text": formatted_texts}
    new_dataset = Dataset.from_dict(new_data)
    del formatted_texts, token_lengths
    return new_dataset, recommended_length
train_dataset, max_seq_length = process_dataset(train_dataset, len(train_dataset))
print(f"Using max_seq_length: {max_seq_length}")
# 6. Print results
print(f"\nTotal formatted records: {len(train_dataset)}")
if len(train_dataset) > 0:
    print(f"\nFirst formatted text:\n{train_dataset[0]['text']}\n")
else:
    print("No records were successfully formatted.")
print(f"Dataset: {train_dataset}")
print(f"Dataset length: {len(train_dataset)}")