"""Fine-tune a causal language model on a JSONL corpus with the Hugging Face Trainer.

The script runs as one process per GPU and reads ``LOCAL_RANK``, ``RANK`` and
``WORLD_SIZE`` from the environment (e.g. as exported by ``torchrun``).
"""
import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

DATA_FILE = '../../data/m2_250514_1150.jsonl'
MODEL_NAME = "FacebookAI/roberta-base"
SAVE_DIR = "./fine_tuned_model"
MAX_LENGTH = 512
# Every rank computes the train/validation split on its own, so the split must
# be seeded; otherwise each rank draws a different partition.
SPLIT_SEED = 42


def _load_model_and_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, ready for causal LM training."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # RoBERTa checkpoints are encoders. Without is_decoder=True the model keeps
    # bidirectional attention, so a next-token objective could attend to the
    # very tokens it is asked to predict.
    model = AutoModelForCausalLM.from_pretrained(model_name, is_decoder=True)

    # The collator pads each batch, which requires a pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _load_splits(tokenizer):
    """Tokenize the JSONL corpus and return a 90/10 train/test split."""
    dataset = load_dataset('json', data_files=DATA_FILE, split='train')

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=MAX_LENGTH)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    return tokenized_dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)


def _build_trainer(model, tokenizer, split_dataset):
    """Create the Trainer for causal language modeling on ``split_dataset``."""
    # Pads the inputs to the longest sequence in the batch.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # mlm=False: causal language modeling
    )

    # NOTE(review): eval_steps only takes effect when the evaluation strategy
    # is "steps". None is set here, so with the library default no periodic
    # evaluation runs. The keyword name differs between transformers versions
    # (evaluation_strategy / eval_strategy), so it is left for the maintainer
    # to enable against the installed version.
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        dataloader_num_workers=8,
        eval_steps=500,
        save_steps=1000,
        warmup_steps=500,
        prediction_loss_only=True,
        logging_dir="./logs",
        logging_steps=100,
        learning_rate=5e-5,
        fp16=True,  # true for GPU
    )

    return Trainer(
        model=model,
        args=training_args,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
    )


def main():
    """Run distributed fine-tuning and save the result from global rank 0.

    Raises:
        KeyError: if ``LOCAL_RANK``, ``RANK`` or ``WORLD_SIZE`` is not set.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    torch.distributed.init_process_group("nccl")
    try:
        print(f"Local Rank = {local_rank}/{world_size}")

        model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)
        split_dataset = _load_splits(tokenizer)
        trainer = _build_trainer(model, tokenizer, split_dataset)

        # Start training
        trainer.train()
    finally:
        # Tear the process group down even when setup or training fails.
        torch.distributed.destroy_process_group()

    # Save the model and tokenizer. Only one process writes: every rank holds
    # the same weights, and concurrent writes to one directory would race.
    if rank == 0:
        model.save_pretrained(SAVE_DIR)
        tokenizer.save_pretrained(SAVE_DIR)


if __name__ == "__main__":
    main()