Spaces:
Runtime error
Runtime error
feat: use resize transform
Browse files- src/augmentations.py +12 -0
src/augmentations.py
CHANGED
|
@@ -7,6 +7,17 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
| 7 |
from torchvision import transforms
|
| 8 |
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class GroupNormalize:
|
| 11 |
def __init__(self, mean: List[float], std: List[float]) -> None:
|
| 12 |
self.mean = mean
|
|
@@ -109,6 +120,7 @@ class TubeMaskingGenerator:
|
|
| 109 |
def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
|
| 110 |
return transforms.Compose(
|
| 111 |
[
|
|
|
|
| 112 |
GroupCenterCrop(input_size),
|
| 113 |
Stack(roll=False),
|
| 114 |
ToTorchFormatTensor(div=True),
|
|
|
|
| 7 |
from torchvision import transforms
|
| 8 |
|
| 9 |
|
| 10 |
+
class GroupResize:
|
| 11 |
+
def __init__(self, size: int = 256) -> None:
|
| 12 |
+
self.transform = transforms.Resize(size)
|
| 13 |
+
|
| 14 |
+
def __call__(
|
| 15 |
+
self, img_tuple: Tuple[torch.Tensor, torch.Tensor]
|
| 16 |
+
) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
| 17 |
+
img_group, label = img_tuple
|
| 18 |
+
return [self.transform(img) for img in img_group], label
|
| 19 |
+
|
| 20 |
+
|
| 21 |
class GroupNormalize:
|
| 22 |
def __init__(self, mean: List[float], std: List[float]) -> None:
|
| 23 |
self.mean = mean
|
|
|
|
| 120 |
def get_videomae_transform(input_size: int = 224) -> "transforms.Compose":
|
| 121 |
return transforms.Compose(
|
| 122 |
[
|
| 123 |
+
GroupResize(size=384),
|
| 124 |
GroupCenterCrop(input_size),
|
| 125 |
Stack(roll=False),
|
| 126 |
ToTorchFormatTensor(div=True),
|