import re

import numpy as np
import torch
from datasets import Dataset, load_dataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

new_model_max_length = 32768

# 1. Load model and tokenizer
model_name = "fdtn-ai/Foundation-Sec-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"model_max_length before update: {tokenizer.model_max_length}")
tokenizer.model_max_length = new_model_max_length  # raise the tokenizer's maximum length
print(f"model_max_length after update: {tokenizer.model_max_length}")

# 2. Check the model's maximum position embeddings
config = AutoConfig.from_pretrained(model_name)
print(f"Model max position embeddings: {config.max_position_embeddings}")

# If needed, update the model config (use with caution)
if config.max_position_embeddings < new_model_max_length:
    config.max_position_embeddings = new_model_max_length
    print(f"Updated model max_position_embeddings to: {config.max_position_embeddings}")

# 3. Load dataset
def get_opencode_instructions():
    ocr_ds = load_dataset("nvidia/OpenCodeReasoning", "split_0")
    return ocr_ds

dataset = get_opencode_instructions()
print(f"Dataset keys: {dataset.keys()}")
train_dataset = dataset["split_0"]  # this config exposes a single 'split_0' split
print(f"Train dataset length: {len(train_dataset)}")

# 4. Formatting function
def formatting_question_answer(i, question, answer):
    text = None
    # OpenCodeReasoning outputs wrap the reasoning trace in <think>...</think>;
    # split the trace from the final answer.
    think_match = re.match(r"<think>(.*?)</think>(.*)", answer, re.DOTALL)
    if think_match:
        think_content, assistant_content = think_match.groups()
        content = f"<think>\n{think_content.strip()}\n</think>\n\n{assistant_content.strip()}"
        chat_messages = [
            {"role": "user", "content": question},
            {"role": "assistant", "content": content},
        ]
        text = tokenizer.apply_chat_template(
            chat_messages,
            tokenize=False,
            add_generation_prompt=False,  # training example already contains the answer
        )
    else:
        print(f"error: {i}, {question[:50]}...")
    return text


def process_dataset(train_dataset, max_process_num):
    formatted_texts = []
    token_lengths = []
    for i in tqdm(range(min(max_process_num, len(train_dataset))), desc="Processing dataset"):
        record = train_dataset[i]
        question = record.get("input", "")
        answer = record.get("output", "")
        formatted_text = formatting_question_answer(i, question, answer)
        if formatted_text:
            formatted_texts.append(formatted_text)
            full_tokens = tokenizer.encode(formatted_text, add_special_tokens=True)
            token_lengths.append(len(full_tokens))
        else:
            print(f"Failed to format record {i}: {question[:50]}...")

    max_length = max(token_lengths)
    mean_length = np.mean(token_lengths)
    percentile_95 = np.percentile(token_lengths, 95)
    percentile_99 = np.percentile(token_lengths, 99)

    print("Token Length Statistics (with chat template):")
    print(f"Full sequence max length: {max_length}")
    print(f"Mean sequence length: {mean_length:.1f}")
    print(f"95th percentile: {percentile_95:.1f}")
    print(f"99th percentile: {percentile_99:.1f}")

    # Pick the smallest bucket that covers the longest sample; fall back to the
    # full context window if even 32768 is not enough.
    recommended_length = new_model_max_length
    for threshold in [128, 192, 256, 512, 1024, 2048, 3072, 4096, 8192, 16384, 32768]:
        if max_length <= threshold:
            recommended_length = threshold
            break
    print(f"Recommended max_seq_length: {recommended_length}")

    new_dataset = Dataset.from_dict({"text": formatted_texts})
    del formatted_texts, token_lengths
    return new_dataset, recommended_length


train_dataset, max_seq_length = process_dataset(train_dataset, len(train_dataset))
print(f"Using max_seq_length: {max_seq_length}")

# 5. Print results
print(f"\nTotal formatted records: {len(train_dataset)}")
if len(train_dataset) > 0:
    print(f"\nFirst formatted text:\n{train_dataset[0]['text']}\n")
else:
    print("No records were successfully formatted.")
print(f"Dataset: {train_dataset}")
print(f"Dataset length: {len(train_dataset)}")