Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -640,7 +640,7 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 640 |
|
| 641 |
def forward(
|
| 642 |
self,
|
| 643 |
-
multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
|
| 644 |
projection_english_tokens_ids: torch.Tensor,
|
| 645 |
projected_bio_embeddings: torch.Tensor = None,
|
| 646 |
) -> dict[str, torch.Tensor]:
|
|
@@ -671,8 +671,9 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 671 |
"""
|
| 672 |
english_token_ids, bio_token_ids = multi_omics_tokens_ids
|
| 673 |
english_token_ids = english_token_ids.clone()
|
| 674 |
-
bio_token_ids = bio_token_ids.clone()
|
| 675 |
projection_english_tokens_ids = projection_english_tokens_ids.clone()
|
|
|
|
|
|
|
| 676 |
if projected_bio_embeddings is not None:
|
| 677 |
projected_bio_embeddings = projected_bio_embeddings.clone()
|
| 678 |
|
|
|
|
| 640 |
|
| 641 |
def forward(
|
| 642 |
self,
|
| 643 |
+
multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None],
|
| 644 |
projection_english_tokens_ids: torch.Tensor,
|
| 645 |
projected_bio_embeddings: torch.Tensor = None,
|
| 646 |
) -> dict[str, torch.Tensor]:
|
|
|
|
| 671 |
"""
|
| 672 |
english_token_ids, bio_token_ids = multi_omics_tokens_ids
|
| 673 |
english_token_ids = english_token_ids.clone()
|
|
|
|
| 674 |
projection_english_tokens_ids = projection_english_tokens_ids.clone()
|
| 675 |
+
if bio_token_ids is not None:
|
| 676 |
+
bio_token_ids = bio_token_ids.clone()
|
| 677 |
if projected_bio_embeddings is not None:
|
| 678 |
projected_bio_embeddings = projected_bio_embeddings.clone()
|
| 679 |
|