import os
import sys
import logging
import pathlib
import typing
import warnings

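# Surface any SLURM-provided environment variables to help debug cluster runs.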
SLURM_ENV = {k: v for k, v in os.environ.items() if 'SLURM' in k}
if SLURM_ENV:
    print(f"SLURM_ENV: {SLURM_ENV}")

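# Make the repository root importable before loading first-party mllm modules.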
project_path = pathlib.Path(__file__).parent.parent.parent
sys.path.append(str(project_path))

import torch
import torch.cuda

from mllm.config import prepare_args
from mllm.models import load_pretrained
from mllm.utils import print_trainable_params
from mllm.engine import prepare_trainer_collator
from mllm.dataset import prepare_data, prepare_target_processor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def main():
    cfg, training_args = prepare_args()
    model, preprocessor = load_pretrained(cfg.model_args, training_args)

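    # prepare_target_processor may update both the model and the preprocessor in place.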
    model, preprocessor = prepare_target_processor(model, preprocessor, cfg.model_args, training_args)
    print_trainable_params(model)

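    # Build the trainer class, its data collator(s), and the datasets/metrics.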
    collator_kwargs = cfg.data_args.collator_kwargs
    trainer_cls, data_collator_dict = prepare_trainer_collator(cfg.model_args, preprocessor, collator_kwargs)
    dataset, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor)

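    # Initialize the trainer; datasets are passed only when the matching stage is enabled.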
    trainer = trainer_cls(
        model=model,
        args=training_args,
        tokenizer=preprocessor['text'],
        train_dataset=dataset['train'] if training_args.do_train else None,
        eval_dataset=dataset['validation'] if training_args.do_eval else None,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        **data_collator_dict,
    )

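    # Training: resume from the latest checkpoint unless the output dir is being overwritten.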
    if training_args.do_train:
        try:
            if (not training_args.overwrite_output_dir) and list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
                train_result = trainer.train(resume_from_checkpoint=True)
            else:
                train_result = trainer.train()
            trainer.log_metrics("train", train_result.metrics)
            trainer.save_metrics("train", train_result.metrics)
            trainer.save_model()
        except RuntimeError as e:
            print(f"got RuntimeError: {e.args}")
            try:
                print(f"#### device {training_args.local_rank} summary ####\n{torch.cuda.memory_summary(training_args.local_rank)}")
            except Exception as inner_e:
                print(f"got Exception while printing the cuda memory summary: {inner_e.args}")
            raise
        finally:
            trainer.save_state()
            trainer.plot_loss()

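    # Persist the resolved config next to the checkpoints for reproducibility.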
    try:
        output_dir = training_args.output_dir
        pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
        cfg.dump(os.path.join(output_dir, "cfg.py"))
    except Exception as e:
        warnings.warn(f'tried to save cfg to output_dir, but got exception {e.args}')

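    # Generation kwargs shared by all predict/eval calls below.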
    gen_kwargs = dict(cfg.data_args.gen_kwargs)
    gen_kwargs.setdefault('use_cache', True)

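    # Some models need explicit pad/bos/eos ids for generation; opt in via model_args flags.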
    if getattr(cfg.model_args, 'gen_kwargs_set_pad_token_id', False):
        gen_kwargs['pad_token_id'] = preprocessor['text'].pad_token_id
    if getattr(cfg.model_args, 'gen_kwargs_set_bos_token_id', False):
        gen_kwargs['bos_token_id'] = preprocessor['text'].bos_token_id
    if getattr(cfg.model_args, 'gen_kwargs_set_eos_token_id', False):
        gen_kwargs['eos_token_id'] = preprocessor['text'].eos_token_id

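    # Evaluation on the validation split (goes through trainer.predict to obtain generations).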
    if training_args.do_eval:
        if hasattr(trainer, '_test_collator') and hasattr(trainer, '_eval_collator') \
                and trainer._test_collator != trainer._eval_collator:
            warnings.warn('Different collators are configured for eval and test, but do_eval and '
                          'do_predict both go through trainer.predict, so only the test collator is used.')
        eval_results = trainer.predict(dataset['validation'], metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", eval_results.metrics)
        trainer.save_metrics("eval", eval_results.metrics)
        trainer.save_prediction(eval_results, file_key_prefix='eval')

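    # Prediction on the test split.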
    if training_args.do_predict:
        predict_results = trainer.predict(dataset['test'], metric_key_prefix="test", **gen_kwargs)
        trainer.log_metrics("test", predict_results.metrics)
        trainer.save_metrics("test", predict_results.metrics)
        trainer.save_prediction(predict_results, file_key_prefix='test')

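    # Run prediction over every named test set, swapping in each set's own metric function.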
    if training_args.do_multi_predict:
        old_compute_metrics = trainer.compute_metrics
        multitest = dataset['multitest']
        multitest = typing.cast(dict, multitest)
        for _idx, (k, item) in enumerate(multitest.items()):
            print(f'processing multitest set {_idx}/{len(multitest)}: {k}')
            _ds = item['dataset']
            _compute_metrics = item['compute_metric']
            _prefix = f"multitest_{k}"

            trainer.compute_metrics = _compute_metrics
            _pred_results = trainer.predict(_ds, metric_key_prefix=_prefix, **gen_kwargs)
            trainer.log_metrics(_prefix, _pred_results.metrics)
            trainer.save_metrics(_prefix, _pred_results.metrics)
            trainer.save_prediction(_pred_results, file_key_prefix=_prefix)
        trainer.compute_metrics = old_compute_metrics


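# Multiprocessing entry point, presumably for launchers like torch_xla's xla_spawn (index = process index).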
def _mp_fn(index):
    main()


if __name__ == "__main__":
    main()