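"""Gradio app that tags frames from a video sprite sheet.

The app takes a sprite image (a grid of preview frames) plus the matching
WebVTT file (base64-encoded) that maps each frame to its x/y/width/height
region and timestamp. Every frame is cropped out and run through a fastai
multi-label classifier; the "tag" tab returns the best-scoring frame per tag,
and the "marker" tab returns de-duplicated per-frame scene markers.
"""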
import base64
from uuid import uuid4

import gradio as gr
from fastcore.all import *
from fastai.vision.all import *
import numpy as np
import timm  # the exported learner may wrap a timm backbone, so timm must be importable
def parent_labels(o):
    "Label `o` with the names in its parent folder, split on commas."
    return Path(o).parent.name.split(",")
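# NOTE: `parent_labels` and `LabelSmoothingBCEWithLogitsLossFlat` below look
# unused, but load_learner() unpickles the exported learner, and unpickling
# fails unless every custom function/class referenced by the model is
# defined in this module under the same name.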
class LabelSmoothingBCEWithLogitsLossFlat(BCEWithLogitsLossFlat):
    def __init__(self, eps: float = 0.1, **kwargs):
        self.eps = eps
        super().__init__(thresh=0.1, **kwargs)

    def __call__(self, inp, targ, **kwargs):
        # smooth the binary targets towards 0.5 before computing BCE
        targ_smooth = targ.float() * (1. - self.eps) + 0.5 * self.eps
        return super().__call__(inp, targ_smooth, **kwargs)
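# With the default eps=0.1 the smoothing maps hard targets to soft ones:
#   1.0 -> 1.0 * 0.9 + 0.05 = 0.95
#   0.0 -> 0.0 * 0.9 + 0.05 = 0.05
# so, e.g., a call with targ = tensor([[1.0]]) computes BCE against 0.95
# rather than 1.0, which penalizes over-confident logits during training.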
# `models.pkl` was exported with learn.export(); loading it restores the
# learner together with its DataLoaders and vocab
learn = load_learner('models.pkl')
# set a new loss function with a threshold of 0.4 to remove more false positives
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)
def predict_tags(image, vtt, threshold=0.4):
    """Return, for every tag seen in the sprite, the frame where it scored highest."""
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = []
    times = []
    images = []
    for left, top, width, height, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, width, height))
        cut_frame = sprite.crop((left, top, left + width, top + height))
        images.append(PILImage.create(np.asarray(cut_frame)))

    # create dataset
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    test_dl = learn.dls.test_dl(images, bs=64)
    # get predictions; `activations` holds the thresholded (decoded) outputs
    probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    # restore the default threshold
    learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # swivel the per-frame activations into a per-tag dict, keeping the
    # frame where each tag scored highest
    tags = {}
    for idx1, activation in enumerate(activations):
        for idx2, active in enumerate(activation):
            if not active:
                continue
            tag = learn.dls.vocab[idx2].replace("_", " ")
            if tag not in tags:
                tags[tag] = {'prob': 0, 'offset': (), 'frame': 0}
            prob = float(probabilities[idx1][idx2])
            if tags[tag]['prob'] < prob:
                tags[tag]['prob'] = prob
                tags[tag]['offset'] = offsets[idx1]
                tags[tag]['frame'] = idx1
                tags[tag]['time'] = times[idx1]
    return tags
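# Illustrative (hypothetical) return value for a two-tag sprite:
#   {"beach": {"prob": 0.93, "offset": (320, 0, 160, 90), "frame": 2, "time": 82.0},
#    "night": {"prob": 0.51, "offset": (0, 90, 160, 90), "frame": 5, "time": 205.0}}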
def predict_markers(image, vtt, threshold=0.4):
    """Return scene markers: the first frame of each run of frames sharing a tag."""
    vtt = base64.b64decode(vtt.replace("data:text/vtt;base64,", ""))
    sprite = PILImage.create(image)

    offsets = []
    times = []
    images = []
    for left, top, width, height, time_seconds in getVTToffsets(vtt):
        times.append(time_seconds)
        offsets.append((left, top, width, height))
        cut_frame = sprite.crop((left, top, left + width, top + height))
        images.append(PILImage.create(np.asarray(cut_frame)))

    # create dataset
    threshold = threshold or 0.4
    learn.loss_func = BCEWithLogitsLossFlat(thresh=threshold)
    test_dl = learn.dls.test_dl(images, bs=64)
    # get predictions
    probabilities, _, activations = learn.get_preds(dl=test_dl, with_decoded=True)
    learn.loss_func = BCEWithLogitsLossFlat(thresh=0.4)

    # collect the active tags for every frame
    all_data_per_frame = []
    for idx1, activation in enumerate(activations):
        frame_data = {'offset': offsets[idx1], 'frame': idx1, 'time': times[idx1], 'tags': []}
        ftags = []
        for idx2, active in enumerate(activation):
            if not active:
                continue
            tag = learn.dls.vocab[idx2].replace("_", " ")
            ftags.append({'label': tag, 'prob': float(probabilities[idx1][idx2])})
        if not ftags:
            continue
        frame_data['tags'] = ftags
        all_data_per_frame.append(frame_data)

    # keep only tags that persist into the next frame, to cut single-frame noise
    filtered = []
    for idx, frame_data in enumerate(all_data_per_frame):
        if idx == len(all_data_per_frame) - 1:
            break
        next_tags = {t['label'] for t in all_data_per_frame[idx + 1]['tags']}
        frame_data['tags'] = [tag for tag in frame_data['tags'] if tag['label'] in next_tags]
        if frame_data['tags']:
            filtered.append(frame_data)

    # collapse consecutive frames that share a tag into a single marker,
    # keeping the highest-probability tag of the first frame in each run
    last_tags = set()
    results = []
    for frame_data in filtered:
        tags = {t['label'] for t in frame_data['tags']}
        if tags.intersection(last_tags):
            continue
        last_tags = tags
        frame_data['tag'] = max(frame_data['tags'], key=lambda x: x['prob'])
        del frame_data['tags']
        # add a unique id to the marker
        frame_data['id'] = str(uuid4())
        results.append(frame_data)
    return results
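# Illustrative (hypothetical) marker list:
#   [{"offset": (0, 0, 160, 90), "frame": 0, "time": 0.0,
#     "tag": {"label": "beach", "prob": 0.91}, "id": "<uuid4>"}]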
def getVTToffsets(vtt):
    """Yield (left, top, width, height, start_seconds) for every cue in a sprite VTT."""
    time_seconds = 0
    left = top = width = height = None
    for line in vtt.decode("utf-8").split("\n"):
        line = line.strip()
        if "-->" in line:
            # grab the start time
            # 00:00:00.000 --> 00:00:41.000
            start = line.split("-->")[0].strip().split(":")
            # convert to seconds
            time_seconds = (
                int(start[0]) * 3600
                + int(start[1]) * 60
                + float(start[2])
            )
            left = top = width = height = None
        elif "xywh=" in line:
            left, top, width, height = (
                int(v) for v in line.split("xywh=")[-1].split(",")
            )
        else:
            continue

        # only yield once the offsets for the current cue have been seen;
        # compare against None so a legitimate left == 0 is not skipped
        if left is None:
            continue
        yield left, top, width, height, time_seconds
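# A sprite VTT cue typically looks like this (example values):
#
#   00:00:00.000 --> 00:00:41.000
#   sprite.jpg#xywh=0,0,160,90
#
# for which the generator above yields (0, 0, 160, 90, 0.0).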
# create a gradio interface with 2 tabs
tag = gr.Interface(
    fn=predict_tags,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)
marker = gr.Interface(
    fn=predict_markers,
    inputs=[
        gr.Image(),
        gr.Textbox(label="VTT file"),
        gr.Number(value=0.4, label="Threshold"),
    ],
    outputs=gr.JSON(label=""),
)
gr.TabbedInterface(
    [tag, marker], ["tag", "marker"]
).launch(server_name="0.0.0.0")
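# Gradio serves on port 7860 by default, so with server_name="0.0.0.0" the app
# is reachable from other machines at http://<host>:7860. A hypothetical
# programmatic call via gradio_client (argument order matches the inputs above;
# the api_name may differ by Gradio version):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   client.predict(sprite_path, vtt_base64_string, 0.4, api_name="/predict")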