Antonio
commited on
Commit
·
f50e742
1
Parent(s):
357c13c
Cleaned Names
Browse files
app.py
CHANGED
|
@@ -184,14 +184,18 @@ decision_frameworks = {
|
|
| 184 |
def predict(video_file, video_model_name, audio_model_name, framework_name):
|
| 185 |
|
| 186 |
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
model_id = "facebook/wav2vec2-large"
|
| 190 |
config = AutoConfig.from_pretrained(model_id, num_labels=6)
|
| 191 |
audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
| 192 |
audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
|
| 193 |
-
|
| 194 |
-
|
|
|
|
| 195 |
|
| 196 |
delete_directory_path = "./temp/"
|
| 197 |
|
|
@@ -219,8 +223,8 @@ def predict(video_file, video_model_name, audio_model_name, framework_name):
|
|
| 219 |
|
| 220 |
inputs = [
|
| 221 |
gr.File(label="Upload Video"),
|
| 222 |
-
gr.Dropdown(["
|
| 223 |
-
gr.Dropdown(["
|
| 224 |
gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
|
| 225 |
]
|
| 226 |
|
|
|
|
| 184 |
def predict(video_file, video_model_name, audio_model_name, framework_name):
|
| 185 |
|
| 186 |
image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
|
| 187 |
+
if video_model_name == "60% Accuracy":
|
| 188 |
+
video_model = torch.load("video_model_60_acc.pth", map_location=torch.device('cpu'))
|
| 189 |
+
elif video_model_name == "80% Accuracy":
|
| 190 |
+
video_model = torch.load("video_model_80_acc.pth", map_location=torch.device('cpu'))
|
| 191 |
|
| 192 |
model_id = "facebook/wav2vec2-large"
|
| 193 |
config = AutoConfig.from_pretrained(model_id, num_labels=6)
|
| 194 |
audio_processor = AutoFeatureExtractor.from_pretrained(model_id)
|
| 195 |
audio_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id, config=config)
|
| 196 |
+
if audio_model_name == "60% Accuracy":
|
| 197 |
+
audio_model.load_state_dict(torch.load("audio_model_state_dict_6e.pth", map_location=torch.device('cpu')))
|
| 198 |
+
audio_model.eval()
|
| 199 |
|
| 200 |
delete_directory_path = "./temp/"
|
| 201 |
|
|
|
|
| 223 |
|
| 224 |
inputs = [
|
| 225 |
gr.File(label="Upload Video"),
|
| 226 |
+
gr.Dropdown(["60% Accuracy", "80% Accuracy"], label="Select Video Model"),
|
| 227 |
+
gr.Dropdown(["60% Accuracy"], label="Select Audio Model"),
|
| 228 |
gr.Dropdown(list(decision_frameworks.keys()), label="Select Decision Framework")
|
| 229 |
]
|
| 230 |
|