Update text_generation.py
text_generation.py  +13 -8
@@ -55,19 +55,24 @@ class TextGenerationPipeline(Pipeline):
             truncation=True,
             max_length=english_tokens_max_length,
         ).input_ids
-        bio_tokens = self.bio_tokenizer(
-            dna_sequences,
-            return_tensors="pt",
-            padding="max_length",
-            max_length=bio_tokens_max_length,
-            truncation=True,
-        ).input_ids.unsqueeze(0)
+        if len(dna_sequences) == 0:
+            bio_tokens = None
+        else:
+            bio_tokens = self.bio_tokenizer(
+                dna_sequences,
+                return_tensors="pt",
+                padding="max_length",
+                max_length=bio_tokens_max_length,
+                truncation=True,
+            ).input_ids.unsqueeze(0)
 
         return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
 
     def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
         english_tokens = model_inputs["english_tokens"].clone()
-        bio_tokens = model_inputs["bio_tokens"].clone()
+        bio_tokens = model_inputs["bio_tokens"]
+        if bio_tokens is not None:
+            bio_tokens = bio_tokens.clone()
         projected_bio_embeddings = None
 
         actual_num_steps = 0
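
For reference, a minimal standalone sketch of the new preprocessing branch. The names encode_bio and FakeBioTokenizer are invented for illustration (the real pipeline uses self.bio_tokenizer); only the empty-list check, the else-branch tokenization, and the unsqueeze(0) mirror the diff above.

    from types import SimpleNamespace

    import torch


    class FakeBioTokenizer:
        """Illustrative stand-in for self.bio_tokenizer; returns dummy token ids."""

        def __call__(self, seqs, return_tensors, padding, max_length, truncation):
            return SimpleNamespace(
                input_ids=torch.zeros(len(seqs), max_length, dtype=torch.long)
            )


    def encode_bio(dna_sequences, bio_tokenizer, bio_tokens_max_length=8):
        # Mirrors the updated preprocessing: an empty dna_sequences list now yields
        # bio_tokens=None instead of calling the bio tokenizer on an empty batch.
        if len(dna_sequences) == 0:
            return None
        # Otherwise tokenize and add a leading batch dimension, as in the diff.
        return bio_tokenizer(
            dna_sequences,
            return_tensors="pt",
            padding="max_length",
            max_length=bio_tokens_max_length,
            truncation=True,
        ).input_ids.unsqueeze(0)


    print(encode_bio([], FakeBioTokenizer()))                          # None
    print(encode_bio(["ATTCCG", "GGAACT"], FakeBioTokenizer()).shape)  # torch.Size([1, 2, 8])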
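
The matching guard in _forward can be exercised in isolation. clone_inputs below is a hypothetical helper reproducing only the three changed lines: with it, a prompt that contains no DNA sequences no longer hits an AttributeError from calling .clone() on None.

    import torch


    def clone_inputs(model_inputs: dict) -> tuple:
        # english_tokens is always a tensor and is cloned unconditionally.
        english_tokens = model_inputs["english_tokens"].clone()
        # bio_tokens may now be None (no DNA sequences), so only clone when present.
        bio_tokens = model_inputs["bio_tokens"]
        if bio_tokens is not None:
            bio_tokens = bio_tokens.clone()
        return english_tokens, bio_tokens


    english = torch.ones(1, 4, dtype=torch.long)

    # English-only prompt: bio_tokens stays None, nothing to clone.
    print(clone_inputs({"english_tokens": english, "bio_tokens": None})[1])  # None

    # Prompt with one DNA sequence: bio_tokens is cloned like english_tokens.
    bio = torch.zeros(1, 1, 8, dtype=torch.long)
    print(clone_inputs({"english_tokens": english, "bio_tokens": bio})[1].shape)  # torch.Size([1, 1, 8])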