Spaces:
Running
on
Zero
Running
on
Zero
| ############################################################## | |
| ### Tools for creating preference tests (MUSHRA, ABX, etc) ### | |
| ############################################################## | |
| import copy | |
| import csv | |
| import random | |
| import sys | |
| import traceback | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import List | |
| import gradio as gr | |
| from audiotools.core.util import find_audio | |
| ################################################################ | |
| ### Logic for audio player, and adding audio / play buttons. ### | |
| ################################################################ | |
| WAVESURFER = """<div id="waveform"></div><div id="wave-timeline"></div>""" | |
| CUSTOM_CSS = """ | |
| .gradio-container { | |
| max-width: 840px !important; | |
| } | |
| region.wavesurfer-region:before { | |
| content: attr(data-region-label); | |
| } | |
| block { | |
| min-width: 0 !important; | |
| } | |
| #wave-timeline { | |
| background-color: rgba(0, 0, 0, 0.8); | |
| } | |
| .head.svelte-1cl284s { | |
| display: none; | |
| } | |
| """ | |
| load_wavesurfer_js = """ | |
| function load_wavesurfer() { | |
| function load_script(url) { | |
| const script = document.createElement('script'); | |
| script.src = url; | |
| document.body.appendChild(script); | |
| return new Promise((res, rej) => { | |
| script.onload = function() { | |
| res(); | |
| } | |
| script.onerror = function () { | |
| rej(); | |
| } | |
| }); | |
| } | |
| function create_wavesurfer() { | |
| var options = { | |
| container: '#waveform', | |
| waveColor: '#F2F2F2', // Set a darker wave color | |
| progressColor: 'white', // Set a slightly lighter progress color | |
| loaderColor: 'white', // Set a slightly lighter loader color | |
| cursorColor: 'black', // Set a slightly lighter cursor color | |
| backgroundColor: '#00AAFF', // Set a black background color | |
| barWidth: 4, | |
| barRadius: 3, | |
| barHeight: 1, // the height of the wave | |
| plugins: [ | |
| WaveSurfer.regions.create({ | |
| regionsMinLength: 0.0, | |
| dragSelection: { | |
| slop: 5 | |
| }, | |
| color: 'hsla(200, 50%, 70%, 0.4)', | |
| }), | |
| WaveSurfer.timeline.create({ | |
| container: "#wave-timeline", | |
| primaryLabelInterval: 5.0, | |
| secondaryLabelInterval: 1.0, | |
| primaryFontColor: '#F2F2F2', | |
| secondaryFontColor: '#F2F2F2', | |
| }), | |
| ] | |
| }; | |
| wavesurfer = WaveSurfer.create(options); | |
| wavesurfer.on('region-created', region => { | |
| wavesurfer.regions.clear(); | |
| }); | |
| wavesurfer.on('finish', function () { | |
| var loop = document.getElementById("loop-button").textContent.includes("ON"); | |
| if (loop) { | |
| wavesurfer.play(); | |
| } | |
| else { | |
| var button_elements = document.getElementsByClassName('playpause') | |
| var buttons = Array.from(button_elements); | |
| for (let j = 0; j < buttons.length; j++) { | |
| buttons[j].classList.remove("primary"); | |
| buttons[j].classList.add("secondary"); | |
| buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") | |
| } | |
| } | |
| }); | |
| wavesurfer.on('region-out', function () { | |
| var loop = document.getElementById("loop-button").textContent.includes("ON"); | |
| if (!loop) { | |
| var button_elements = document.getElementsByClassName('playpause') | |
| var buttons = Array.from(button_elements); | |
| for (let j = 0; j < buttons.length; j++) { | |
| buttons[j].classList.remove("primary"); | |
| buttons[j].classList.add("secondary"); | |
| buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") | |
| } | |
| wavesurfer.pause(); | |
| } | |
| }); | |
| console.log("Created WaveSurfer object.") | |
| } | |
| load_script('https://unpkg.com/[email protected]') | |
| .then(() => { | |
| load_script("https://unpkg.com/[email protected]/dist/plugin/wavesurfer.timeline.min.js") | |
| .then(() => { | |
| load_script('https://unpkg.com/[email protected]/dist/plugin/wavesurfer.regions.min.js') | |
| .then(() => { | |
| console.log("Loaded regions"); | |
| create_wavesurfer(); | |
| document.getElementById("start-survey").click(); | |
| }) | |
| }) | |
| }); | |
| } | |
| """ | |
| play = lambda i: """ | |
| function play() { | |
| var audio_elements = document.getElementsByTagName('audio'); | |
| var button_elements = document.getElementsByClassName('playpause') | |
| var audio_array = Array.from(audio_elements); | |
| var buttons = Array.from(button_elements); | |
| var src_link = audio_array[{i}].getAttribute("src"); | |
| console.log(src_link); | |
| var loop = document.getElementById("loop-button").textContent.includes("ON"); | |
| var playing = buttons[{i}].textContent.includes("Stop"); | |
| for (let j = 0; j < buttons.length; j++) { | |
| if (j != {i} || playing) { | |
| buttons[j].classList.remove("primary"); | |
| buttons[j].classList.add("secondary"); | |
| buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") | |
| } | |
| else { | |
| buttons[j].classList.remove("secondary"); | |
| buttons[j].classList.add("primary"); | |
| buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop") | |
| } | |
| } | |
| if (playing) { | |
| wavesurfer.pause(); | |
| wavesurfer.seekTo(0.0); | |
| } | |
| else { | |
| wavesurfer.load(src_link); | |
| wavesurfer.on('ready', function () { | |
| var region = Object.values(wavesurfer.regions.list)[0]; | |
| if (region != null) { | |
| region.loop = loop; | |
| region.play(); | |
| } else { | |
| wavesurfer.play(); | |
| } | |
| }); | |
| } | |
| } | |
| """.replace( | |
| "{i}", str(i) | |
| ) | |
| clear_regions = """ | |
| function clear_regions() { | |
| wavesurfer.clearRegions(); | |
| } | |
| """ | |
| reset_player = """ | |
| function reset_player() { | |
| wavesurfer.clearRegions(); | |
| wavesurfer.pause(); | |
| wavesurfer.seekTo(0.0); | |
| var button_elements = document.getElementsByClassName('playpause') | |
| var buttons = Array.from(button_elements); | |
| for (let j = 0; j < buttons.length; j++) { | |
| buttons[j].classList.remove("primary"); | |
| buttons[j].classList.add("secondary"); | |
| buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") | |
| } | |
| } | |
| """ | |
| loop_region = """ | |
| function loop_region() { | |
| var element = document.getElementById("loop-button"); | |
| var loop = element.textContent.includes("OFF"); | |
| console.log(loop); | |
| try { | |
| var region = Object.values(wavesurfer.regions.list)[0]; | |
| region.loop = loop; | |
| } catch {} | |
| if (loop) { | |
| element.classList.remove("secondary"); | |
| element.classList.add("primary"); | |
| element.textContent = "Looping ON"; | |
| } else { | |
| element.classList.remove("primary"); | |
| element.classList.add("secondary"); | |
| element.textContent = "Looping OFF"; | |
| } | |
| } | |
| """ | |
| class Player: | |
| def __init__(self, app): | |
| self.app = app | |
| self.app.load(_js=load_wavesurfer_js) | |
| self.app.css = CUSTOM_CSS | |
| self.wavs = [] | |
| self.position = 0 | |
| def create(self): | |
| gr.HTML(WAVESURFER) | |
| gr.Markdown( | |
| "Click and drag on the waveform above to select a region for playback. " | |
| "Once created, the region can be moved around and resized. " | |
| "Clear the regions using the button below. Hit play on one of the buttons below to start!" | |
| ) | |
| with gr.Row(): | |
| clear = gr.Button("Clear region") | |
| loop = gr.Button("Looping OFF", elem_id="loop-button") | |
| loop.click(None, _js=loop_region) | |
| clear.click(None, _js=clear_regions) | |
| gr.HTML("<hr>") | |
| def add(self, name: str = "Play"): | |
| i = self.position | |
| self.wavs.append( | |
| { | |
| "audio": gr.Audio(visible=False), | |
| "button": gr.Button(name, elem_classes=["playpause"]), | |
| "position": i, | |
| } | |
| ) | |
| self.wavs[-1]["button"].click(None, _js=play(i)) | |
| self.position += 1 | |
| return self.wavs[-1] | |
| def to_list(self): | |
| return [x["audio"] for x in self.wavs] | |
| ############################################################ | |
| ### Keeping track of users, and CSS for the progress bar ### | |
| ############################################################ | |
| load_tracker = lambda name: """ | |
| function load_name() { | |
| function setCookie(name, value, exp_days) { | |
| var d = new Date(); | |
| d.setTime(d.getTime() + (exp_days*24*60*60*1000)); | |
| var expires = "expires=" + d.toGMTString(); | |
| document.cookie = name + "=" + value + ";" + expires + ";path=/"; | |
| } | |
| function getCookie(name) { | |
| var cname = name + "="; | |
| var decodedCookie = decodeURIComponent(document.cookie); | |
| var ca = decodedCookie.split(';'); | |
| for(var i = 0; i < ca.length; i++){ | |
| var c = ca[i]; | |
| while(c.charAt(0) == ' '){ | |
| c = c.substring(1); | |
| } | |
| if(c.indexOf(cname) == 0){ | |
| return c.substring(cname.length, c.length); | |
| } | |
| } | |
| return ""; | |
| } | |
| name = getCookie("{name}"); | |
| if (name == "") { | |
| name = Math.random().toString(36).slice(2); | |
| console.log(name); | |
| setCookie("name", name, 30); | |
| } | |
| name = getCookie("{name}"); | |
| return name; | |
| } | |
| """.replace( | |
| "{name}", name | |
| ) | |
| # Progress bar | |
| progress_template = """ | |
| <!DOCTYPE html> | |
| <html> | |
| <head> | |
| <title>Progress Bar</title> | |
| <style> | |
| .progress-bar { | |
| background-color: #ddd; | |
| border-radius: 4px; | |
| height: 30px; | |
| width: 100%; | |
| position: relative; | |
| } | |
| .progress { | |
| background-color: #00AAFF; | |
| border-radius: 4px; | |
| height: 100%; | |
| width: {PROGRESS}%; /* Change this value to control the progress */ | |
| } | |
| .progress-text { | |
| position: absolute; | |
| top: 50%; | |
| left: 50%; | |
| transform: translate(-50%, -50%); | |
| font-size: 18px; | |
| font-family: Arial, sans-serif; | |
| font-weight: bold; | |
| color: #333 !important; | |
| text-shadow: 1px 1px #fff; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div class="progress-bar"> | |
| <div class="progress"></div> | |
| <div class="progress-text">{TEXT}</div> | |
| </div> | |
| </body> | |
| </html> | |
| """ | |
| def create_tracker(app, cookie_name="name"): | |
| user = gr.Text(label="user", interactive=True, visible=False, elem_id="user") | |
| app.load(_js=load_tracker(cookie_name), outputs=user) | |
| return user | |
| ################################################################# | |
| ### CSS and HTML for labeling sliders for both ABX and MUSHRA ### | |
| ################################################################# | |
| slider_abx = """ | |
| <!DOCTYPE html> | |
| <html> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <title>Labels Example</title> | |
| <style> | |
| body { | |
| margin: 0; | |
| padding: 0; | |
| } | |
| .labels-container { | |
| display: flex; | |
| justify-content: space-between; | |
| align-items: center; | |
| width: 100%; | |
| height: 40px; | |
| padding: 0px 12px 0px; | |
| } | |
| .label { | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| width: 33%; | |
| height: 100%; | |
| font-weight: bold; | |
| text-transform: uppercase; | |
| padding: 10px; | |
| font-family: Arial, sans-serif; | |
| font-size: 16px; | |
| font-weight: 700; | |
| letter-spacing: 1px; | |
| line-height: 1.5; | |
| } | |
| .label-a { | |
| background-color: #00AAFF; | |
| color: #333 !important; | |
| } | |
| .label-tie { | |
| background-color: #f97316; | |
| color: #333 !important; | |
| } | |
| .label-b { | |
| background-color: #00AAFF; | |
| color: #333 !important; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div class="labels-container"> | |
| <div class="label label-a">Prefer A</div> | |
| <div class="label label-tie">Toss-up</div> | |
| <div class="label label-b">Prefer B</div> | |
| </div> | |
| </body> | |
| </html> | |
| """ | |
| slider_mushra = """ | |
| <!DOCTYPE html> | |
| <html> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <title>Labels Example</title> | |
| <style> | |
| body { | |
| margin: 0; | |
| padding: 0; | |
| } | |
| .labels-container { | |
| display: flex; | |
| justify-content: space-between; | |
| align-items: center; | |
| width: 100%; | |
| height: 30px; | |
| padding: 10px; | |
| } | |
| .label { | |
| display: flex; | |
| justify-content: center; | |
| align-items: center; | |
| width: 20%; | |
| height: 100%; | |
| font-weight: bold; | |
| text-transform: uppercase; | |
| padding: 10px; | |
| font-family: Arial, sans-serif; | |
| font-size: 13.5px; | |
| font-weight: 700; | |
| line-height: 1.5; | |
| } | |
| .label-bad { | |
| background-color: #ff5555; | |
| color: #333 !important; | |
| } | |
| .label-poor { | |
| background-color: #ffa500; | |
| color: #333 !important; | |
| } | |
| .label-fair { | |
| background-color: #ffd700; | |
| color: #333 !important; | |
| } | |
| .label-good { | |
| background-color: #97d997; | |
| color: #333 !important; | |
| } | |
| .label-excellent { | |
| background-color: #04c822; | |
| color: #333 !important; | |
| } | |
| </style> | |
| </head> | |
| <body> | |
| <div class="labels-container"> | |
| <div class="label label-bad">bad</div> | |
| <div class="label label-poor">poor</div> | |
| <div class="label label-fair">fair</div> | |
| <div class="label label-good">good</div> | |
| <div class="label label-excellent">excellent</div> | |
| </div> | |
| </body> | |
| </html> | |
| """ | |
| ######################################################### | |
| ### Handling loading audio and tracking session state ### | |
| ######################################################### | |
| class Samples: | |
| def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None): | |
| files = find_audio(folder) | |
| samples = defaultdict(lambda: defaultdict()) | |
| for f in files: | |
| condition = f.parent.stem | |
| samples[f.name][condition] = f | |
| self.samples = samples | |
| self.names = list(samples.keys()) | |
| self.filtered = False | |
| self.current = 0 | |
| if shuffle: | |
| random.shuffle(self.names) | |
| self.n_samples = len(self.names) if n_samples is None else n_samples | |
| def get_updates(self, idx, order): | |
| key = self.names[idx] | |
| return [gr.update(value=str(self.samples[key][o])) for o in order] | |
| def progress(self): | |
| try: | |
| pct = self.current / len(self) * 100 | |
| except: # pragma: no cover | |
| pct = 100 | |
| text = f"On {self.current} / {len(self)} samples" | |
| pbar = ( | |
| copy.copy(progress_template) | |
| .replace("{PROGRESS}", str(pct)) | |
| .replace("{TEXT}", str(text)) | |
| ) | |
| return gr.update(value=pbar) | |
| def __len__(self): | |
| return self.n_samples | |
| def filter_completed(self, user, save_path): | |
| if not self.filtered: | |
| done = [] | |
| if Path(save_path).exists(): | |
| with open(save_path, "r") as f: | |
| reader = csv.DictReader(f) | |
| done = [r["sample"] for r in reader if r["user"] == user] | |
| self.names = [k for k in self.names if k not in done] | |
| self.names = self.names[: self.n_samples] | |
| self.filtered = True # Avoid filtering more than once per session. | |
| def get_next_sample(self, reference, conditions): | |
| random.shuffle(conditions) | |
| if reference is not None: | |
| self.order = [reference] + conditions | |
| else: | |
| self.order = conditions | |
| try: | |
| updates = self.get_updates(self.current, self.order) | |
| self.current += 1 | |
| done = gr.update(interactive=True) | |
| pbar = self.progress() | |
| except: | |
| traceback.print_exc() | |
| updates = [gr.update() for _ in range(len(self.order))] | |
| done = gr.update(value="No more samples!", interactive=False) | |
| self.current = len(self) | |
| pbar = self.progress() | |
| return updates, done, pbar | |
| def save_result(result, save_path): | |
| with open(save_path, mode="a", newline="") as file: | |
| writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys()))) | |
| if file.tell() == 0: | |
| writer.writeheader() | |
| writer.writerow(result) | |