Merge pull request #29 from borisdayma/load_checkpoint

seq2seq/run_seq2seq_flax.py CHANGED (+44 -23)
@@ -125,6 +125,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Loads a pretrained wandb checkpoint. Use artifact reference."
+        },
+    )
 
 
 @dataclass
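Note: the new `from_checkpoint` field is an ordinary dataclass argument, so it is surfaced on the command line by `HfArgumentParser` like the existing fields. A minimal sketch of how it parses, using a trimmed stand-in for the script's `ModelArguments` and a purely illustrative artifact reference:

    from dataclasses import dataclass, field
    from typing import Optional
    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:
        # Trimmed stand-in for the script's dataclass; only the new field is shown.
        from_checkpoint: Optional[str] = field(
            default=None,
            metadata={"help": "Loads a pretrained wandb checkpoint. Use artifact reference."},
        )

    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses(
        args=["--from_checkpoint", "user/project/checkpoint:latest"]  # illustrative reference
    )
    assert model_args.from_checkpoint == "user/project/checkpoint:latest"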
@@ -424,36 +430,51 @@ def main():
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
     # Load pretrained model and tokenizer
-    base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
-        model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
     )
 
-    # Set up our new model config
-    config = BartConfig.from_pretrained(model_args.model_name_or_path)
-    config.tie_word_embeddings = False
-    config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
-    config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
-    config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
-    config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
-    config.forced_bos_token_id = None  # we don't need this token
-    config.forced_eos_token_id = None  # we don't need this token
-    config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
-    config.min_length = data_args.max_target_length
-    config.max_length = data_args.max_target_length
-
-    # Create a custom model and initialize it randomly
-    model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
-
-    # Use pre-trained weights for encoder
-    model.params['model']['encoder'] = base_model.params['model']['encoder']
-    model.params['model']['shared'] = base_model.params['model']['shared']
-    del base_model
+    if model_args.from_checkpoint is not None:
+        artifact = wandb.run.use_artifact(model_args.from_checkpoint)
+        artifact_dir = artifact.download()
+        model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
+
+        # some models will try to change bos (because of force_bos_token_to_be_generated)
+        # we ensure bos and eos are not forced
+        model.config.force_bos_token_to_be_generated = False
+        model.config.forced_bos_token_id = None
+        model.config.forced_eos_token_id = None
+
+        # used in the preprocessing function
+        config = model.config
+
+    else:
+        base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
+            model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        )
+        # Set up our new model config
+        config = BartConfig.from_pretrained(model_args.model_name_or_path)
+        config.tie_word_embeddings = False
+        config.decoder_start_token_id = BOS_TOKEN_ID  # for first token
+        config.bos_token_id = BOS_TOKEN_ID  # should not be used (due to forced_bos_token_id)
+        config.pos_token_id = BOS_TOKEN_ID  # should not be needed (as we generate until max_length)
+        config.eos_token_id = BOS_TOKEN_ID + 1  # unreachable
+        config.forced_bos_token_id = None  # we don't need this token
+        config.forced_eos_token_id = None  # we don't need this token
+        config.force_bos_token_to_be_generated = False  # otherwise it sets bos_token_id at loading
+        config.min_length = data_args.max_target_length
+        config.max_length = data_args.max_target_length
+
+        # Create a custom model and initialize it randomly
+        model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
+
+        # Use pre-trained weights for encoder
+        model.params['model']['encoder'] = base_model.params['model']['encoder']
+        model.params['model']['shared'] = base_model.params['model']['shared']
+        del base_model
+
+    print(f"TPUs: {jax.device_count()}")
+    assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
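Note: the `if` branch assumes an earlier run already logged the model directory as a wandb artifact; that producer side is not part of this PR. A plausible sketch of it, with all names (project, artifact name, type) illustrative:

    import wandb
    from transformers import FlaxAutoModelForSeq2SeqLM

    # Stand-in for the model trained by the script; any saved HF/Flax model works.
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
    model.save_pretrained("output/model")

    run = wandb.init(project="your-project")  # assumes you are logged in to wandb
    artifact = wandb.Artifact("checkpoint", type="bart_model")  # names are illustrative
    artifact.add_dir("output/model")  # config.json + Flax weight files
    run.log_artifact(artifact)
    run.finish()

    # A later run can then pass --from_checkpoint "user/your-project/checkpoint:latest",
    # which the patched script resolves via wandb.run.use_artifact(...).download().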
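Note on the `else` branch: grafting the pretrained encoder and shared embeddings onto a randomly initialized model works because Flax stores parameters as plain nested dicts (pytrees), so whole subtrees with matching shapes can be reassigned; `del base_model` then releases the unused pretrained decoder weights. A toy illustration of the same grafting, with made-up shapes:

    import jax
    import jax.numpy as jnp

    # Toy pytrees standing in for base_model.params and model.params.
    base_params = {"model": {"encoder": {"kernel": jnp.ones((4, 4))},
                             "shared": {"embedding": jnp.ones((10, 4))}}}
    new_params = {"model": {"encoder": {"kernel": jnp.zeros((4, 4))},
                            "shared": {"embedding": jnp.zeros((10, 4))},
                            "decoder": {"kernel": jnp.zeros((4, 4))}}}

    # Same grafting as in the diff: reuse the pretrained encoder and shared
    # embeddings, keep the freshly initialized decoder.
    new_params["model"]["encoder"] = base_params["model"]["encoder"]
    new_params["model"]["shared"] = base_params["model"]["shared"]

    # Encoder/shared leaves now average 1.0 (pretrained); decoder stays 0.0 (random init).
    print(jax.tree_util.tree_map(lambda x: float(x.mean()), new_params))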