Update README.md
Browse files
README.md
CHANGED
@@ -30,4 +30,56 @@ Model achieves:
|
|
30 |
|
31 |
Class-wise accuracies:
|
32 |
- *shot scale*: ECS - 90.92%, CS - 83.2%, MS - 85.0%, FS - 89.71%, LS - 94.55%
|
33 |
-
- *shot movement*: Static - 94.6%, Motion - 87.7%, Pull - 57.5%, Push - 66.82%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
Class-wise accuracies:
|
32 |
- *shot scale*: ECS - 90.92%, CS - 83.2%, MS - 85.0%, FS - 89.71%, LS - 94.55%
|
33 |
+
- *shot movement*: Static - 94.6%, Motion - 87.7%, Pull - 57.5%, Push - 66.82%
|
34 |
+
|
35 |
+
|
36 |
+
## Model Definition
|
37 |
+
```python
|
38 |
+
import torch.nn as nn
import torch.nn.functional as F
from transformers import VideoMAEImageProcessor, VideoMAEModel, VideoMAEConfig, PreTrainedModel
|
39 |
+
|
40 |
+
class CustomVideoMAEConfig(VideoMAEConfig):
    """VideoMAE configuration extended with label maps for two classification heads.

    Args:
        scale_label2id: Mapping from scale label name to class index.
        scale_id2label: Mapping from class index to scale label name.
        movement_label2id: Mapping from movement label name to class index.
        movement_id2label: Mapping from class index to movement label name.
        **kwargs: Forwarded unchanged to ``VideoMAEConfig.__init__``.

    Any map left as ``None`` is stored as an empty dict.
    """

    def __init__(self, scale_label2id=None, scale_id2label=None, movement_label2id=None, movement_id2label=None, **kwargs):
        super().__init__(**kwargs)
        self.scale_label2id = scale_label2id if scale_label2id is not None else {}
        # JSON object keys are always strings, so an id->label map that was
        # saved to config.json comes back keyed by "0", "1", ...  Cast the keys
        # to int so lookups by integer class index keep working after a reload.
        self.scale_id2label = (
            {int(idx): label for idx, label in scale_id2label.items()}
            if scale_id2label is not None
            else {}
        )
        self.movement_label2id = movement_label2id if movement_label2id is not None else {}
        self.movement_id2label = (
            {int(idx): label for idx, label in movement_id2label.items()}
            if movement_id2label is not None
            else {}
        )
|
47 |
+
|
48 |
+
|
49 |
+
class CustomModel(PreTrainedModel):
    """VideoMAE backbone feeding two linear classification heads (scale and movement)."""

    config_class = CustomVideoMAEConfig

    def __init__(self, config, model_name, scale_num_classes, movement_num_classes):
        """Create the backbone, the optional pooling norm, and both heads.

        Args:
            config: Configuration supplying ``hidden_size`` and ``use_mean_pooling``.
            model_name: Identifier handed to ``VideoMAEModel.from_pretrained``.
            scale_num_classes: Output width of the scale head.
            movement_num_classes: Output width of the movement head.
        """
        super().__init__(config)
        hidden_dim = config.hidden_size

        self.vmae = VideoMAEModel.from_pretrained(model_name, ignore_mismatched_sizes=True)

        # The LayerNorm only exists when mean pooling is enabled; forward()
        # checks for None to pick the pooling strategy.
        if config.use_mean_pooling:
            self.fc_norm = nn.LayerNorm(hidden_dim)
        else:
            self.fc_norm = None

        self.scale_cf = nn.Linear(hidden_dim, scale_num_classes)
        self.movement_cf = nn.Linear(hidden_dim, movement_num_classes)

    def forward(self, pixel_values, scale_labels=None, movement_labels=None):
        """Run the backbone and both heads.

        Args:
            pixel_values: Input passed straight to the VideoMAE backbone.
            scale_labels: Optional targets for the scale head.
            movement_labels: Optional targets for the movement head.

        Returns:
            A dict with ``"scale_logits"`` and ``"movement_logits"``. When both
            label arguments are given it also holds ``"loss"``, the sum of the
            two cross-entropy losses.
        """
        hidden_states = self.vmae(pixel_values)[0]

        # Pool along dim 1 (presumably the token axis -- confirm against the
        # backbone's output layout): normalised mean, or the first position.
        if self.fc_norm is None:
            pooled = hidden_states[:, 0]
        else:
            pooled = self.fc_norm(hidden_states.mean(1))

        scale_logits = self.scale_cf(pooled)
        movement_logits = self.movement_cf(pooled)
        outputs = {"scale_logits": scale_logits, "movement_logits": movement_logits}

        # A loss is produced only when targets for BOTH heads are available.
        if scale_labels is None or movement_labels is None:
            return outputs

        scale_loss = F.cross_entropy(scale_logits, scale_labels)
        movement_loss = F.cross_entropy(movement_logits, movement_labels)
        return {"loss": scale_loss + movement_loss, **outputs}
|
76 |
+
|
77 |
+
|
78 |
+
# Label vocabularies for the two heads; a label's index is its position in the tuple.
scale_lab2id = {label: idx for idx, label in enumerate(("ECS", "CS", "MS", "FS", "LS"))}
scale_id2lab = dict(zip(scale_lab2id.values(), scale_lab2id.keys()))
movement_lab2id = {label: idx for idx, label in enumerate(("Static", "Motion", "Pull", "Push"))}
movement_id2lab = dict(zip(movement_lab2id.values(), movement_lab2id.keys()))

config = CustomVideoMAEConfig(
    scale_label2id=scale_lab2id,
    scale_id2label=scale_id2lab,
    movement_label2id=movement_lab2id,
    movement_id2label=movement_id2lab,
)
# NOTE(review): `model_name` is not defined anywhere in this snippet; it must be
# set to the backbone checkpoint identifier before this line runs.
model = CustomModel(config, model_name, len(scale_lab2id), len(movement_lab2id))
|
85 |
+
```
|