# Nexa_Data_Studio/Tokenization/Build_tokenizer.py
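"""Build a token-budgeted instruction-tuning corpus for QLoRA.

Reads a JSONL corpus of {"instruction", "input", "output"} records, computes
dataset statistics, formats and entropy-ranks the samples, filters them by
length limits, and writes validated samples to JSONL until the corpus-type
token budget (TOKEN_TARGETS) is reached.
"""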
import json
from pathlib import Path
from typing import Dict
from transformers import AutoTokenizer
from Tokenization.Entropy_ranker import EntropyRanker
from Tokenization.Label_tokens import MIN_WORDS, MAX_TOKENS, MAX_TOTAL_TOKENS, TOKEN_TARGETS
from Tokenization.pretraining.Dataset_stats import DatasetAnalyzer
from Tokenization.pretraining.Instruction_formatter import InstructionFormatter


class QLoRAPreprocessor:
    """Filter and package instruction-tuning samples for QLoRA fine-tuning."""

    def __init__(self, model_name: str = "facebook/opt-350m", corpus_type: str = "warm_start"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.analyzer = DatasetAnalyzer(model_name)
        self.formatter = InstructionFormatter()
        self.ranker = EntropyRanker()
        self.token_target = TOKEN_TARGETS[corpus_type]  # token budget for this corpus type
        self.current_tokens = 0  # running total of tokens accepted so far

    def track_tokens(self, text: str) -> bool:
        """Add ``text``'s token count to the running total; return True while under budget."""
        tokens = self.tokenizer.encode(text)
        self.current_tokens += len(tokens)
        return self.current_tokens <= self.token_target
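
    # Note: track_tokens counts a sample's tokens before checking the budget,
    # so the sample that first crosses the target is still added to
    # current_tokens but rejected by process_dataset, which stops there.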

    def validate_sample(self, sample: Dict) -> bool:
        """Check that a sample has the required keys and falls within length limits."""
        if not all(k in sample for k in ["instruction", "input", "output"]):
            return False
        total_text = f"{sample['instruction']} {sample['input']} {sample['output']}"
        tokens = self.tokenizer.encode(total_text)
        words = total_text.split()
        return (len(words) >= MIN_WORDS and
                len(tokens) <= MAX_TOKENS and
                len(tokens) <= MAX_TOTAL_TOKENS)
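
    # Note: MAX_TOKENS and MAX_TOTAL_TOKENS are both compared against the same
    # per-sample token count, so the smaller of the two is the effective limit.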

    def process_dataset(self, input_path: str, output_path: str):
        """Read a JSONL corpus, filter and rank it, and write the result as JSONL."""
        # Load data, skipping blank lines and malformed JSON
        data = []
        with open(input_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"Skipping line {i}: {e}")

        # Analyze dataset
        stats = self.analyzer.get_dataset_stats(data)
        print(f"Dataset stats: {stats}")

        # Format samples
        formatted_samples = [
            self.formatter.format_sample(sample)
            for sample in data
        ]

        # Rank samples so the highest-value ones are considered first
        ranked_samples = self.ranker.rank_samples(formatted_samples)

        # Keep valid samples until the token budget is exhausted
        valid_samples = []
        for sample in ranked_samples:
            if not self.validate_sample(sample):
                continue
            sample_text = f"{sample['instruction']} {sample['input']} {sample['output']}"
            if not self.track_tokens(sample_text):
                break  # budget reached; stop accepting samples
            valid_samples.append(sample)

        # Save to JSONL
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            for sample in valid_samples:
                f.write(json.dumps(sample) + '\n')
        print(f"Processed {len(valid_samples)} samples; saved to {output_path}")


if __name__ == "__main__":
    preprocessor = QLoRAPreprocessor()
    preprocessor.process_dataset(
        "C:/Users/kunya/PycharmProjects/DataVolt/Tokenizers/combined_scientific_papers.json",
        "nexa_scientific_instruction_300k.jsonl"
    )
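
# A hedged usage sketch: the constructor accepts any Hugging Face model id and
# any key present in TOKEN_TARGETS. "facebook/opt-1.3b" is a real model id, but
# the "cold_start" corpus type and the paths below are illustrative assumptions,
# not values confirmed by this repository.
#
#   preprocessor = QLoRAPreprocessor(model_name="facebook/opt-1.3b",
#                                    corpus_type="cold_start")
#   preprocessor.process_dataset("data/corpus.jsonl", "out/corpus_filtered.jsonl")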