Joshua Lochner committed · Commit 90d1f68 · Parent: 4678c9b

Add functionality to predict self-promo and interaction reminders
Changed files:
- src/evaluate.py (+2 -4)
- src/predict.py (+31 -24)
- src/preprocess.py (+111 -55)
- src/segment.py (+3 -3)
- src/shared.py (+5 -6)
- src/train.py (+15 -15)
- src/utils.py (+6 -0)
src/evaluate.py CHANGED

@@ -105,13 +105,13 @@ def calculate_metrics(labelled_words, predictions):

        if predicted_sponsor:
            # total_positive_time += duration
-           if word['
+           if word['category'] is not None:  # Is actual sponsor
                metrics['true_positive'] += duration
            else:
                metrics['false_positive'] += duration
        else:
            # total_negative_time += duration
-           if word['
+           if word['category'] is not None:  # Is actual sponsor
                metrics['false_negative'] += duration
            else:
                metrics['true_negative'] += duration

@@ -176,8 +176,6 @@ def main():
    with open(final_path) as fp:
        final_data = json.load(fp)

-   classifier, vectorizer = get_classifier_vectorizer(classifier_args)
-
    total_accuracy = 0
    total_precision = 0
    total_recall = 0
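Note on the change above: calculate_metrics now keys off word['category'] (a string such as 'sponsor', 'selfpromo' or 'interaction', or None) instead of a boolean sponsor flag. A minimal, self-contained sketch of the duration-weighted counting, using invented words and a stand-in prediction rule:

# Hypothetical labelled words; 'category' is None for unlabelled words.
words = [
    {'text': 'thanks', 'start': 10.0, 'end': 10.4, 'category': 'sponsor'},
    {'text': 'sponsor', 'start': 10.4, 'end': 11.0, 'category': 'sponsor'},
    {'text': 'hello', 'start': 42.0, 'end': 42.5, 'category': None},
]

metrics = {'true_positive': 0, 'false_positive': 0,
           'false_negative': 0, 'true_negative': 0}

for word in words:
    duration = word['end'] - word['start']
    predicted_sponsor = word['start'] < 20  # stand-in for the model's prediction
    if predicted_sponsor:
        key = 'true_positive' if word['category'] is not None else 'false_positive'
    else:
        key = 'false_negative' if word['category'] is not None else 'true_negative'
    metrics[key] += duration

print(metrics)  # true_positive is approximately 1.0, true_negative approximately 0.5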
src/predict.py CHANGED

@@ -1,4 +1,4 @@
-from 
+from utils import re_findall
 from shared import OutputArguments
 from typing import Optional
 from segment import (

@@ -11,21 +11,22 @@ from segment import (
    SegmentationArguments
)
import preprocess
-import re
from errors import TranscriptError
from model import get_classifier_vectorizer
from transformers import (
    AutoModelForSeq2SeqLM,
-   AutoTokenizer
+   AutoTokenizer,
+   HfArgumentParser
)
+from transformers.trainer_utils import get_last_checkpoint
from dataclasses import dataclass, field
-from transformers import HfArgumentParser
from shared import device
import logging


def seconds_to_time(seconds):
-   fractional = 
+   fractional = round(seconds % 1, 3)
+   fractional = '' if fractional == 0 else str(fractional)[1:]
    h, remainder = divmod(abs(int(seconds)), 3600)
    m, s = divmod(remainder, 60)
    return f"{'-' if seconds < 0 else ''}{h:02}:{m:02}:{s:02}{fractional}"

@@ -64,7 +65,7 @@ class PredictArguments(TrainingOutputArguments):
    )


-SPONSOR_MATCH_RE = fr'(?<={CustomTokens.
+SPONSOR_MATCH_RE = fr'(?<={CustomTokens.START_SEGMENT.value})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={CustomTokens.END_SEGMENT.value}|$)'

MATCH_WINDOW = 25        # Increase for accuracy, but takes longer: O(n^3)
MERGE_TIME_WITHIN = 8    # Merge predictions if they are within x seconds

@@ -97,11 +98,13 @@ class ClassifierArguments:
        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})


-def filter_predictions(predictions, classifier, vectorizer, 
+def filter_predictions(predictions, classifier_args):  # classifier, vectorizer,
    """Use classifier to filter predictions"""
    if not predictions:
        return predictions

+   classifier, vectorizer = get_classifier_vectorizer(classifier_args)
+
    transformed_segments = vectorizer.transform([
        preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
        for pred in predictions

@@ -142,9 +145,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
            words, prediction['start'], prediction['end'])

    if classifier_args is not None:
-
-       predictions = filter_predictions(
-           predictions, classifier, vectorizer, classifier_args)
+       predictions = filter_predictions(predictions, classifier_args)

    return predictions

@@ -166,13 +167,10 @@ def greedy_match(list, sublist):
    return best_i, best_j, best_k


-DEFAULT_TOKEN_PREFIX = 'summarize: '
-
-
def predict_sponsor_text(text, model, tokenizer):
    """Given a body of text, predict the words which are part of the sponsor"""
    input_ids = tokenizer(
-       f'{
+       f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}', return_tensors='pt', truncation=True).input_ids.to(device())

    # Can't be longer than input length + SAFETY_TOKENS or model input dim
    max_out_len = min(len(input_ids[0]) + SAFETY_TOKENS, model.model_dim)

@@ -183,10 +181,11 @@ def predict_sponsor_text(text, model, tokenizer):

def predict_sponsor_matches(text, model, tokenizer):
    sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-
+
+   if CustomTokens.NO_SEGMENT.value in sponsorship_text:
        return []

-   return 
+   return re_findall(SPONSOR_MATCH_RE, sponsorship_text)


def segments_to_prediction_times(segments, model, tokenizer):

@@ -202,7 +201,7 @@ def segments_to_prediction_times(segments, model, tokenizer):
        matches = predict_sponsor_matches(batch_text, model, tokenizer)

        for match in matches:
-           matched_text = match.split()
+           matched_text = match['text'].split()
            # TODO skip if too short

            i1, j1, k1 = greedy_match(

@@ -217,7 +216,8 @@ def segments_to_prediction_times(segments, model, tokenizer):

            predicted_time_ranges.append({
                'start': word_start(extracted_words[0]),
-               'end': word_end(extracted_words[-1])
+               'end': word_end(extracted_words[-1]),
+               'category': match['category']
            })

    # Necessary to sort matches by start time

@@ -225,23 +225,29 @@ def segments_to_prediction_times(segments, model, tokenizer):

    # Merge overlapping predictions and sponsorships that are close together
    # Caused by model having max input size
-
+
+   prev_prediction = None
+
    final_predicted_time_ranges = []
    for range in predicted_time_ranges:
        start_time = range['start']
        end_time = range['end']

-       if 
-
+       if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
+           start_time <= prev_prediction['end'] <= end_time or start_time -
+               prev_prediction['end'] <= MERGE_TIME_WITHIN
+       ):
+           # Ending time of last segment is in this segment or c, so we extend last prediction range
            final_predicted_time_ranges[-1]['end'] = end_time

        else:  # No overlap, is a new prediction
            final_predicted_time_ranges.append({
                'start': start_time,
                'end': end_time,
+               'category': range['category']
            })

-
+       prev_prediction = range

    return final_predicted_time_ranges

@@ -268,7 +274,7 @@ def main():

    predict_args.video_id = predict_args.video_id.strip()
    predictions = predict(predict_args.video_id, model, tokenizer,
-                         segmentation_args, classifier_args=classifier_args
+                         segmentation_args)  # TODO add back , classifier_args=classifier_args

    video_url = f'https://www.youtube.com/watch?v={predict_args.video_id}'
    if not predictions:

@@ -282,7 +288,8 @@ def main():
              ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
        print('Time:', seconds_to_time(
            prediction['start']), '-->', seconds_to_time(prediction['end']))
-       print('Probability:', prediction
+       print('Probability:', prediction.get('probability'))
+       print('Category:', prediction.get('category'))
        print()
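A quick standalone check of the new SPONSOR_MATCH_RE, showing what re_findall (added in src/utils.py) returns for a hypothetical model output. The token strings are the CustomTokens values defined in src/shared.py; the sample generated string is invented for illustration:

import re

# Token values from CustomTokens in src/shared.py
START_SEGMENT = 'START_SEGMENT_TOKEN'
END_SEGMENT = 'END_SEGMENT_TOKEN'

SPONSOR_MATCH_RE = fr'(?<={START_SEGMENT})\s*_(?P<category>\S+)\s*(?P<text>.*?)\s*(?={END_SEGMENT}|$)'


def re_findall(pattern, string):
    # Same helper as src/utils.py: one dict of named groups per match
    return [m.groupdict() for m in re.finditer(pattern, string)]


# Hypothetical generated text for one batch of transcript words
generated = 'START_SEGMENT_TOKEN_SPONSOR thanks to our sponsor END_SEGMENT_TOKEN_SPONSOR'
print(re_findall(SPONSOR_MATCH_RE, generated))
# [{'category': 'SPONSOR', 'text': 'thanks to our sponsor'}]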
src/preprocess.py CHANGED

@@ -1,5 +1,6 @@
+from datetime import datetime
 import itertools
-from typing import Optional
+from typing import Optional, List
 from datasets import load_dataset
 from model import ModelArguments
 import segment

@@ -24,8 +25,10 @@ def find(s, ch):
    return [i for i, ltr in enumerate(s) if ltr == ch]


-def wordify(transcript):
+def wordify(transcript, maximum_wps=1):
    """Try to replicate format for automatically generated transcripts"""
+
+   # Do not allow segments to be on screen for too long using maximum_wps
    words = []

    for line_index, line in enumerate(transcript):

@@ -34,9 +37,14 @@ def wordify(transcript):
            continue

        start = line['start']
-       next_start = transcript[line_index +
-
-
+       next_start = transcript[line_index + 1]['start'] \
+           if line_index < len(transcript) - 1 else float('inf')
+
+       # Use maximum wps to calculate latest end (to avoid segments which stay on screen too long)
+       longest_duration = maximum_wps * text.count(' ')
+       latest_end = start + longest_duration
+       end = min(start + line['duration'], next_start, latest_end)
+
        duration = end - start

        indices = find(text, ' ') + [len(text)]

@@ -52,9 +60,9 @@ def wordify(transcript):
            w_start = start + percentage * duration

            words.append({
-               'start': round(w_start,
-               'duration': round(w_duration,
-               'end': round(w_start + w_duration,
+               'start': round(w_start, 3),
+               'duration': round(w_duration, 3),
+               'end': round(w_start + w_duration, 3),
                'text': word,
            })

@@ -69,6 +77,10 @@ def get_manual_words(transcript_list):
    return wordify(transcript)


+PROFANITY_RAW = '[ __ ]'        # How YouTube transcribes profanity
+PROFANITY_CONVERTED = '*****'   # Safer version for tokenizing
+
+
def get_auto_words(transcript_list):
    words = []
    transcript = transcript_list.find_generated_transcript(['en'])

@@ -82,7 +94,7 @@ def get_auto_words(transcript_list):
            offset_ms = word.get('tOffsetMs', 0)

            texts = word['utf8'].replace(
-
+               PROFANITY_RAW, PROFANITY_CONVERTED
            ).strip().split()

            for text in texts:

@@ -94,7 +106,7 @@ def get_auto_words(transcript_list):
    return words


-def get_words(video_id, process=True, fallback=
+def get_words(video_id, process=True, fallback=True, transcript_type='auto'):
    """Get parsed video transcript with caching system
    returns None if not processed yet and process is False
    """

@@ -148,21 +160,31 @@ def extract_sponsors(words, min_sponsor_segment_length=5):

    paragraphs = []
    current = []
+   prev_category = None
    for word in words:
-       if 
-           continue
+       if word['category'] is None:  # and not current:
+           continue  # Skip unimportant

-       if word['
+       if word['category'] == prev_category:
            current.append(word['text'])
        else:
-           paragraphs.append(
+           paragraphs.append({
+               'words': current,
+               'category': prev_category,
+           })
            current = []
-
-
+
+       prev_category = word['category']
+
+   if current and prev_category is not None:
+       paragraphs.append({
+           'words': current,
+           'category': prev_category,
+       })

    # Remove all too short:
    paragraphs = list(filter(lambda x: len(
-       x) >= min_sponsor_segment_length, paragraphs))
+       x['words']) >= min_sponsor_segment_length, paragraphs))

    return paragraphs

@@ -203,10 +225,8 @@ def clean_text(text):
    text = re.sub(NUM_REGEX, CustomTokens.NUMBER.value, text)

    # Replace profanity with special token
-   text = text.replace(CustomTokens.
-
-   text = text.replace(CustomTokens.PROFANITY_CONVERTED.value,
-                       CustomTokens.PROFANITY.value)
+   text = text.replace(PROFANITY_RAW, CustomTokens.PROFANITY.value)
+   text = text.replace(PROFANITY_CONVERTED, CustomTokens.PROFANITY.value)

    return text.strip()

@@ -254,11 +274,25 @@ class PreprocessArguments:
    do_create: bool = field(
        default=False, metadata={'help': 'Merge sponsor segments into single file'}
    )
+
    min_votes: int = field(
        default=0, metadata={'help': 'Minimum number of votes'})
    # Downvotes will make this negative.
    # 1 = At least one positive vote

+   min_date: str = field(
+       default='20/08/2021', metadata={'help': 'Only use submissions from after this date, defaults to the release of v3.0 (https://github.com/ajayyy/SponsorBlock/releases/tag/3.0)'})
+
+   categories: str = field(
+       default_factory=lambda: ['sponsor', 'selfpromo', 'interaction'],
+       metadata={
+           'nargs': '+',
+           'choices': ['intro', 'sponsor', 'interaction',
+                       'outro', 'selfpromo', 'preview',
+                       'poi_highlight', 'filler', 'music_offtopic']  # moreCategories
+       }
+   )
+
    do_transcribe: bool = field(
        default=False, metadata={'help': 'Get transcripts for videos'}
    )

@@ -266,7 +300,7 @@ class PreprocessArguments:
        default=4, metadata={'help': 'Number of transcripts to download in parallel'})

    overwrite: bool = field(
-       default=
+       default=True, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
    )

    do_generate: bool = field(

@@ -447,14 +481,26 @@ def main():
        preprocess_args.raw_data_dir, preprocess_args.raw_data_file)

    def get_rows():
+
+       latest_time = datetime.strptime(preprocess_args.min_date, '%d/%m/%Y')
+
        with open(raw_dataset_path, newline='') as csvfile:
            reader = csv.DictReader(csvfile)
+
            for line in reader:
+               submitted_time = datetime.fromtimestamp(
+                   float(line['timeSubmitted'])/1e3)
+
+               if submitted_time < latest_time:
+                   continue
+
                if line['service'] != 'YouTube':
                    continue
+               if len(line['videoID']) != 11:
+                   continue  # Invalid youtube video ID

                # TODO add support for other categories and action types?
-               if line['category'] 
+               if line['category'] not in preprocess_args.categories:
                    continue
                if line['actionType'] != 'skip':
                    continue

@@ -463,9 +509,6 @@ def main():
                if line['hidden'] == '1' or line['shadowHidden'] == '1':
                    continue

-               if len(line['videoID']) != 11:
-                   continue  # Invalid youtube video ID
-
                # Skip those that aren't highly voted
                line['votes'] = int(line['votes'])
                # incorrect_votes = int(line['incorrectVotes'])

@@ -494,6 +537,8 @@ def main():
        for row in data_rows:
            video_ids.add(row['videoID'])

+       # TODO first set - os.listdir and do rest
+
        print('Start transcribing')
        with tqdm(total=len(video_ids)) as progress:
            def on_job_complete(job):

@@ -517,21 +562,18 @@ def main():
    final_path = os.path.join(
        processed_args.processed_dir, processed_args.processed_file)

-   if 
-       logging.info(f'{final_path} exists, opening file')
-       with open(final_path) as fp:
-           final_data = json.load(fp)
-   else:
+   if preprocess_args.do_create:
        print('Create final data')

        final_data = {}

        if data_rows is None:
            data_rows = get_rows()
+           # data_rows = itertools.islice(data_rows, 1000)  # TODO temp

        # TODO add progress bar
        # TODO parallelise?
-       for line in data_rows:
+       for index, line in enumerate(data_rows):
            video_id = line['videoID']

            if video_id not in final_data:

@@ -540,7 +582,10 @@ def main():
            segment_start = float(line['startTime'])
            segment_end = float(line['endTime'])

-           video_words = get_words(video_id, process=
+           video_words = get_words(video_id, process=False)
+           if not video_words:
+               continue
+
            segment_words = segment.extract_segment(
                video_words, segment_start, segment_end)

@@ -552,7 +597,8 @@ def main():
            wps = len(segment_words)/duration if duration > 0 else 0

            if wps < preprocess_args.min_wps:
-               print('bad segment in',
+               print(index, 'Skipping bad segment in',
+                     video_id, '| wps =', wps)
                continue

            final_data[video_id].append({

@@ -580,10 +626,16 @@ def main():
        # raw_dataset_path, final_path, preprocess_args.min_votes)
        # # TODO save metadata in final.json?

-
+   elif os.path.exists(final_path):
+       # Already exists
+       logging.info(f'{final_path} exists, opening file')
+       with open(final_path) as fp:
+           final_data = json.load(fp)
+       logging.info(f'Found {len(final_data)} videos')
+   else:
+       return  # Do not continue

    # TODO shuffle final_data
-
    # if not os.path.exists(excess_path) or preprocess_args.overwrite
    # TODO use overwrite param

@@ -610,10 +662,8 @@ def main():
    write_mode = 'w' if preprocess_args.overwrite else 'a'

    get_all = preprocess_args.max_videos is None
-
-
-   else:
-       total = preprocess_args.max_videos
+
+   total = len(final_data) if get_all else preprocess_args.max_videos

    index = 0
    data = final_data.items()

@@ -641,7 +691,7 @@ def main():
        elif count_videos >= preprocess_args.max_videos:
            break

-       words = get_words(video_id, False)
+       words = get_words(video_id, process=False)
        if not words:
            continue

@@ -662,34 +712,40 @@ def main():
            progress.update()

        for seg in segments:
-
-           segment_text = ' '.join((x['text'] for x in seg))
-
-           extracted_text = ''
-           for p in extract_sponsors(seg):
-               p_text = ' '.join(p)
-               extracted_text += f'{CustomTokens.START_SPONSOR.value} {p_text} {CustomTokens.END_SPONSOR.value}. '
-
            duration = segment.word_end(
                seg[-1]) - segment.word_start(seg[0])
            wps = len(seg)/duration if duration > 0 else 0
+
            # Ignore segments with "not enough words" in the transcript
            if wps < preprocess_args.min_wps:
                continue

+           segment_text = ' '.join((x['text'] for x in seg))
+           extracted_segments = extract_sponsors(seg)
            d = {
                'video_index': index,
                'video_id': video_id,
                'text': clean_text(segment_text),
-               'words_per_second': wps,
+               'words_per_second': round(wps, 3),
            }

-
-
-
-
-
+           if extracted_segments:
+               extracted_texts = []
+               for s in extracted_segments:
+                   w = ' '.join(s['words'])
+                   category = s['category'].upper()
+
+                   t = f"{CustomTokens.START_SEGMENT.value}_{category} {w} {CustomTokens.END_SEGMENT.value}_{category}"
+                   extracted_texts.append(t)
+
+               extracted_text = '\n'.join(extracted_texts)
+
+               d['extracted'] = clean_text(extracted_text)
+               print(json.dumps(d), file=positive)
+
+           else:
+               d['extracted'] = CustomTokens.NO_SEGMENT.value
+               print(json.dumps(d), file=negative)

    if preprocess_args.do_split:
        print('Splitting')
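To illustrate the new training-example format produced above: extract_sponsors now returns dicts with 'words' and 'category', and each labelled paragraph is wrapped in category-specific segment tokens (or marked NO_SEGMENT_FOUND when a segment contains no labelled words). A small sketch with invented data, using the token values from src/shared.py:

# Token values from CustomTokens in src/shared.py
START_SEGMENT = 'START_SEGMENT_TOKEN'
END_SEGMENT = 'END_SEGMENT_TOKEN'
NO_SEGMENT = 'NO_SEGMENT_FOUND'

# Hypothetical output of extract_sponsors(seg) for one transcript segment
extracted_segments = [
    {'words': ['thanks', 'to', 'our', 'sponsor'], 'category': 'sponsor'},
    {'words': ['like', 'and', 'subscribe'], 'category': 'interaction'},
]

if extracted_segments:
    extracted_texts = []
    for s in extracted_segments:
        w = ' '.join(s['words'])
        category = s['category'].upper()
        extracted_texts.append(
            f'{START_SEGMENT}_{category} {w} {END_SEGMENT}_{category}')
    extracted = '\n'.join(extracted_texts)
else:
    extracted = NO_SEGMENT

print(extracted)
# START_SEGMENT_TOKEN_SPONSOR thanks to our sponsor END_SEGMENT_TOKEN_SPONSOR
# START_SEGMENT_TOKEN_INTERACTION like and subscribe END_SEGMENT_TOKEN_INTERACTION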
src/segment.py CHANGED

@@ -25,7 +25,7 @@ def get_overlapping_chunks_of_tokens(tokens, size, overlap):


# Generate up to max_tokens - SAFETY_TOKENS
-SAFETY_TOKENS = 
+SAFETY_TOKENS = 12


# TODO play around with this?

@@ -36,10 +36,10 @@ def add_labels_to_words(words, sponsor_segments):

    # TODO binary search
    for word in words:
-       word['
+       word['category'] = None
        for sponsor_segment in sponsor_segments:
            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
-               word['
+               word['category'] = sponsor_segment['category']

    # TODO use extract_segment with mapping function?
    # TODO remove sponsor segments that contain mostly empty space?
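For reference, the updated add_labels_to_words attaches a SponsorBlock category to each word instead of a boolean flag. A self-contained sketch of the same logic on made-up data:

def add_labels_to_words(words, sponsor_segments):
    # Mirrors the updated function above: words are labelled in place with a
    # 'category' (e.g. 'sponsor', 'selfpromo', 'interaction') or None.
    for word in words:
        word['category'] = None
        for sponsor_segment in sponsor_segments:
            if sponsor_segment['start'] <= word['start'] <= sponsor_segment['end']:
                word['category'] = sponsor_segment['category']


# Hypothetical data
words = [{'text': 'use', 'start': 12.0}, {'text': 'code', 'start': 12.4},
         {'text': 'anyway', 'start': 80.0}]
sponsor_segments = [{'start': 10.0, 'end': 15.0, 'category': 'sponsor'}]

add_labels_to_words(words, sponsor_segments)
print([w['category'] for w in words])  # ['sponsor', 'sponsor', None]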
src/shared.py CHANGED

@@ -7,16 +7,17 @@ from typing import Optional
from dataclasses import dataclass, field
from enum import Enum

-
class CustomTokens(Enum):
+   EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
+
    URL = 'URL_TOKEN'
    HYPHENATED_URL = 'HYPHENATED_URL_TOKEN'
    NUMBER_PERCENTAGE = 'NUMBER_PERCENTAGE_TOKEN'
    NUMBER = 'NUMBER_TOKEN'

-
-
-
+   START_SEGMENT = 'START_SEGMENT_TOKEN'
+   END_SEGMENT = 'END_SEGMENT_TOKEN'
+   NO_SEGMENT = 'NO_SEGMENT_FOUND'

    SHORT_HYPHENATED = 'SHORT_HYPHENATED_TOKEN'
    LONG_WORD = 'LONG_WORD_TOKEN'

@@ -26,8 +27,6 @@ class CustomTokens(Enum):
    APPLAUSE = '[Applause]'
    LAUGHTER = '[Laughter]'

-   PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
-   PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing
    PROFANITY = 'PROFANITY_TOKEN'

    @classmethod
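The new tokens tie the other files together: EXTRACT_SEGMENTS_PREFIX is prepended to the model input in predict.py and becomes the default source_prefix in train.py, while the START/END/NO segment tokens delimit labelled spans in generated text. A minimal sketch of reading the values (the example text is invented):

from enum import Enum


class CustomTokens(Enum):
    # Subset of the tokens defined in src/shared.py
    EXTRACT_SEGMENTS_PREFIX = 'EXTRACT_SEGMENTS: '
    START_SEGMENT = 'START_SEGMENT_TOKEN'
    END_SEGMENT = 'END_SEGMENT_TOKEN'
    NO_SEGMENT = 'NO_SEGMENT_FOUND'


text = 'hypothetical transcript text'
model_input = f'{CustomTokens.EXTRACT_SEGMENTS_PREFIX.value} {text}'
# Note the prefix value already ends in a space, so the formatted input
# contains a double space before the transcript text.
print(model_input)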
src/train.py CHANGED

@@ -1,9 +1,8 @@
from preprocess import load_datasets, DatasetArguments
-from predict import ClassifierArguments, SPONSOR_MATCH_RE
-from shared import device, GeneralArguments, OutputArguments
-from model import ModelArguments
+from predict import ClassifierArguments, SPONSOR_MATCH_RE
+from shared import CustomTokens, device, GeneralArguments, OutputArguments
+from model import ModelArguments, get_model, get_tokenizer
import transformers
-from model import get_model, get_tokenizer
import logging
import os
import sys

@@ -22,7 +21,7 @@ from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
-
+from utils import re_findall
import re

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.

@@ -117,7 +116,7 @@ class DataTrainingArguments:
        },
    )
    source_prefix: Optional[str] = field(
-       default=
+       default=CustomTokens.EXTRACT_SEGMENTS_PREFIX.value, metadata={
            'help': 'A prefix to add before every source text (useful for T5 models).'}
    )

@@ -135,11 +134,11 @@ class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
    num_train_epochs: float = field(
        default=1, metadata={'help': 'Total number of training epochs to perform.'})

-   save_steps: int = field(default=
+   save_steps: int = field(default=5000, metadata={
        'help': 'Save checkpoint every X updates steps.'})
-   eval_steps: int = field(default=
+   eval_steps: int = field(default=5000, metadata={
        'help': 'Run an evaluation every X steps.'})
-   logging_steps: int = field(default=
+   logging_steps: int = field(default=5000, metadata={
        'help': 'Log every X updates steps.'})

    skip_train_transformer: bool = field(default=False, metadata={

@@ -257,8 +256,8 @@ def main():

        ngram_range=(1, 2),  # best so far
        # max_features=8000  # remove for higher accuracy?
-       max_features=50000
-
+       # max_features=50000
+       max_features=10000
    )

    train_test_data = {

@@ -277,11 +276,12 @@ def main():
        dataset = raw_datasets[ds_type]

        for row in dataset:
-
            # Get matches:
-
-
-
+           matches = re_findall(SPONSOR_MATCH_RE, row['extracted'])
+
+           return  # TODO fix
+
+           if not matches:
                matches = [row['text']]

            for match in matches:
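The classifier side of train.py (not fully shown in this diff) trains a TF-IDF plus logistic-regression model over segment texts; this commit drops the vectorizer's max_features from 50000 to 10000. A hedged, minimal sketch of that kind of pipeline with invented texts and labels, just to show the shape of the data, not the script's actual training loop:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

# Hypothetical training texts and binary labels (sponsor-like vs. not);
# how the real script derives these from the datasets is not shown here.
texts = ['thanks to our sponsor for supporting the channel',
         'today we are looking at the new update']
labels = [1, 0]

vectorizer = TfidfVectorizer(ngram_range=(1, 2), max_features=10000)
X = vectorizer.fit_transform(texts)

classifier = LogisticRegression()
classifier.fit(X, labels)

# Probability that a new segment is sponsor-like
probs = classifier.predict_proba(vectorizer.transform(['check out our sponsor']))[:, 1]
print(probs)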
src/utils.py CHANGED

@@ -1,6 +1,8 @@
+import re
import asyncio
import os

+
class Job:
    def __init__(self, function, *args, **kwargs) -> None:
        self.function = function

@@ -84,3 +86,7 @@ class InterruptibleThreadPool:
        self.loop.close()

        return self.jobs
+
+
+def re_findall(pattern, string):
+    return [m.groupdict() for m in re.finditer(pattern, string)]