Spaces:
Running
Running
| """ | |
| T5 model trained to generate text from text | |
| --------------------------------------------------------------------- | |
| """ | |
| import json | |
| import os | |
| import torch | |
| import transformers | |
| from textattack.model_args import TEXTATTACK_MODELS | |
| from textattack.models.tokenizers import T5Tokenizer | |
class T5ForTextToText(torch.nn.Module):
    """A T5 model trained to generate text from text.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    For usage information, see HuggingFace Transformers documentation section
    on text-to-text with T5:
    https://huggingface.co/transformers/usage.html.

    Args:
        mode (string): Name of the T5 task to use; it is handed to
            ``T5Tokenizer``. The underlying weights are always loaded from
            the ``"t5-base"`` checkpoint.
        output_max_length (int): The max length of the sequence to be generated.
            Between 1 and infinity.
        input_max_length (int): Max length of the input sequence.
        num_beams (int): Number of beams for beam search. Must be between 1 and
            infinity. 1 means no beam search.
        early_stopping (bool): if set to `True` beam search is stopped when at
            least `num_beams` sentences finished per batch. Defaults to `True`.
    """

    def __init__(
        self,
        mode="english_to_german",
        output_max_length=20,
        input_max_length=64,
        num_beams=1,
        early_stopping=True,
    ):
        super().__init__()
        self.model = transformers.T5ForConditionalGeneration.from_pretrained("t5-base")
        self.model.eval()
        # NOTE(review): the tokenizer is built with `output_max_length`, while
        # `input_max_length` is only stored and saved, never used here.
        # Confirm whether the tokenizer should use `input_max_length` instead.
        self.tokenizer = T5Tokenizer(mode, max_length=output_max_length)

        self.mode = mode
        self.output_max_length = output_max_length
        self.input_max_length = input_max_length
        self.num_beams = num_beams
        self.early_stopping = early_stopping

    def __call__(self, *args, **kwargs):
        """Generate output text for the given model inputs.

        All positional and keyword arguments are forwarded to
        ``self.model.generate``. ``max_length``, ``num_beams`` and
        ``early_stopping`` are always supplied from this wrapper's settings,
        so passing any of them in ``kwargs`` raises ``TypeError``.

        Returns:
            list: One decoded output per generated ID sequence.
        """
        # NOTE(review): defining `__call__` directly (rather than `forward`)
        # bypasses `torch.nn.Module.__call__` and therefore any registered
        # forward hooks -- presumably intentional; confirm before changing.
        # Generate IDs from the model.
        output_ids_list = self.model.generate(
            *args,
            **kwargs,
            max_length=self.output_max_length,
            num_beams=self.num_beams,
            early_stopping=self.early_stopping,
        )
        # Convert ID tensor to string and return.
        return [self.tokenizer.decode(ids) for ids in output_ids_list]

    def save_pretrained(self, output_dir):
        """Save the wrapper settings and the underlying model to a directory.

        Args:
            output_dir (str): Target directory; created (including parents)
                if it does not exist yet.
        """
        # `exist_ok=True` avoids the check-then-create race of a separate
        # `os.path.exists` test.
        os.makedirs(output_dir, exist_ok=True)
        config = {
            "mode": self.mode,
            "output_max_length": self.output_max_length,
            "input_max_length": self.input_max_length,
            "num_beams": self.num_beams,
            # This key must match the attribute name: `from_pretrained` restores
            # every key with `setattr`, and `__call__` reads `early_stopping`.
            "early_stopping": self.early_stopping,
        }
        # We don't save it as `config.json` b/c that name conflicts with HuggingFace's `config.json`.
        with open(os.path.join(output_dir, "t5-wrapper-config.json"), "w") as f:
            json.dump(config, f)
        self.model.save_pretrained(output_dir)

    @classmethod
    def from_pretrained(cls, name_or_path):
        """Load trained T5 model by name or from path.

        Args:
            name_or_path (str): Name of the model (e.g. "t5-en-de") or model saved via `save_pretrained`.

        Returns:
            T5ForTextToText: The loaded wrapper.
        """
        if name_or_path in TEXTATTACK_MODELS:
            return cls(TEXTATTACK_MODELS[name_or_path])

        config_path = os.path.join(name_or_path, "t5-wrapper-config.json")
        with open(config_path, "r") as f:
            config = json.load(f)

        # Older checkpoints were written with a misspelled key; map it onto
        # the real attribute name so they remain loadable.
        if "early_stoppping" in config:
            config.setdefault("early_stopping", config.pop("early_stoppping"))

        # Skip `cls.__init__` so the "t5-base" weights are not loaded only to
        # be replaced, but still initialise `torch.nn.Module`: it refuses to
        # register a sub-module (`t5.model`) before its own `__init__` has run.
        t5 = cls.__new__(cls)
        torch.nn.Module.__init__(t5)
        for key, value in config.items():
            setattr(t5, key, value)

        t5.model = transformers.T5ForConditionalGeneration.from_pretrained(
            name_or_path
        )
        # Mirror `__init__`, which also puts the model in evaluation mode.
        t5.model.eval()
        t5.tokenizer = T5Tokenizer(t5.mode, max_length=t5.output_max_length)
        return t5

    def get_input_embeddings(self):
        """Return the input embeddings of the wrapped model."""
        return self.model.get_input_embeddings()