Spaces:
Running
Running
| import gradio as gr | |
| import itertools | |
| import random | |
| import json | |
| from .config import sents | |
| from typing import List, Tuple, Set, Dict | |
| from hashlib import md5, sha1 | |
| import spaces | |
| # from .synth import clear_stuff | |
class User:
    """In-memory record of a single voter."""

    def __init__(self, user_id: str):
        # Keys of the sample pairs this user has already voted on;
        # each key is a 2-tuple of strings.
        self.voted_pairs: Set[Tuple[str, str]] = set()
        self.user_id = user_id
class Sample:
    """One synthesized audio clip: its file, the text spoken, and the model that produced it."""

    def __init__(self, filename: str, transcript: str, modelName: str):
        self.filename, self.transcript, self.modelName = filename, transcript, modelName

    def to_dict(self):
        """Return the sample as a plain dict whose keys mirror the constructor arguments."""
        return {
            field: getattr(self, field)
            for field in ('filename', 'transcript', 'modelName')
        }
# Cache of audio samples that can be handed out for quick voting.
cached_samples: List[Sample] = []
# cached_samples.append(Sample("audio1.mp3", "Hello, how are you?", "model1"))
# cached_samples.append(Sample("audio2.mp3", "Hello, how are you?", "model2"))

# Restore the samples persisted in "_cached_samples.json", if that file exists.
# `json_data` is kept as a module-level name because the original module defined it.
json_data = ''
try:
    with open("_cached_samples.json", "r") as read:
        loaded_samples = json.load(read)
    cached_samples = [Sample(**entry) for entry in loaded_samples]
except FileNotFoundError:
    # Nothing has been cached yet; start with an empty cache.
    pass
except (OSError, ValueError, TypeError) as err:
    # Unreadable file, malformed JSON (JSONDecodeError is a ValueError) or entries
    # that do not match Sample's constructor: keep the empty cache, but say so
    # instead of hiding the problem.
    print(f'Could not load cached samples: {err}')
def asr_cached_for_dataset():
    """Placeholder: walk the cached samples without acting on them, then report success."""
    for _sample in cached_samples:
        # No per-sample work is implemented yet.
        continue
    return True
# In-memory voting state: userid as the key and a User instance as the value.
voting_users = dict()

# Meant to hold (Sample, Sample) pairs, i.e. List[Tuple[Sample, Sample]]; starts empty.
all_pairs = list()
def get_userid(session_hash: str, request):
    """Derive an anonymised user id as a SHA-1 hex digest.

    The first available source wins:
      1. `session_hash` (the JS cookie value), when it is not the empty string;
      2. `request.username` - only set when `auth` is enabled, which denies
         access to anonymous users;
      3. `request.session_hash` - not a cookie; it changes on page reload.

    Alternatives noted by the original author and not used: the client IP
    (unreliable when Gradio runs inside an HTML iframe) and the browser session
    cookie (needs access to the parent page's session from within the iframe).

    Args:
        session_hash: Session token supplied by the front end, or '' if absent.
        request: Request object exposing `username` and `session_hash`.

    Returns:
        A 40-character hexadecimal string.
    """
    def _digest(value: str) -> str:
        # UTF-8 gives the same bytes as ASCII for ASCII input, but does not raise
        # UnicodeEncodeError on other characters. SHA-1 is used only as an
        # identifier here, hence usedforsecurity=False.
        return sha1(value.encode('utf-8'), usedforsecurity=False).hexdigest()

    if session_hash != '':
        return _digest(session_hash)
    if request.username:
        return _digest(request.username)
    return _digest(request.session_hash)
def cache_sample(path, text, model):
    """Add a synthesized sample to the voting cache and persist the cache to disk.

    Only sentences listed in `sents` are cached, and each (transcript, model)
    combination is stored at most once.

    Args:
        path: Location of the audio file; stored as the sample's `filename`.
        text: Transcript that was synthesized.
        model: Name of the model that produced the audio.

    Returns:
        False if `text` is not a hardcoded sentence or the sample could not be
        created; True if an equivalent sample is already cached; None after a
        new sample has been appended (even when writing the JSON file failed).
    """
    # skip caching if not a hardcoded sentence
    if text not in sents:
        return False

    # TODO: replace cached sample with a newer version?
    if any(
        existing.transcript == text and existing.modelName == model
        for existing in cached_samples
    ):
        return True

    try:
        cached_samples.append(Sample(path, text, model))
    except Exception:
        print('Error when trying to cache sample')
        return False

    # Persist the whole list so it can be reloaded on the next start.
    cached_sample_dict = [entry.to_dict() for entry in cached_samples]
    try:
        with open("_cached_samples.json", "w") as write:
            json.dump(cached_sample_dict, write)
    except (OSError, TypeError, ValueError) as err:
        # Persistence is best effort: the in-memory cache stays usable, but the
        # failure is reported instead of being silently discarded.
        print(f'Error when trying to save cached samples: {err}')
    return None
# Give user a cached audio sample pair they have yet to vote on
def give_cached_sample(session_hash: str, autoplay: bool, request: gr.Request):
    """Pick a cached sample pair the caller has not voted on and build the UI updates.

    Args:
        session_hash: Session token from the front end ('' if absent); passed to
            `get_userid`.
        autoplay: Whether the first audio of the pair should start playing.
        request: Gradio request, used by `get_userid` as a fallback identity.

    Returns:
        15 values in both cases - presumably one per Gradio output component;
        the wiring is outside this function. When no unvoted pair is left, a
        list of 14 no-op updates followed by an update that disables the
        "fetch cached" button. Otherwise a tuple describing the pair to show.
    """
    # Voting state is stored only in RAM, keyed by the hashed user id.
    userid = get_userid(session_hash, request)
    if userid not in voting_users:
        voting_users[userid] = User(userid)

    def vote_key(sample):
        # Same key format as voted_on_cached: md5 over model name + transcript.
        # UTF-8 matches the ASCII bytes for ASCII text and does not raise on
        # other characters; the digest is an identifier, not a security measure.
        text = sample.modelName + sample.transcript
        return md5(text.encode('utf-8'), usedforsecurity=False).hexdigest()

    def get_next_pair(user: User):
        # FIXME: all_pairs var out of scope - pairs are regenerated on every call.
        for pair in generate_matching_pairs(cached_samples):
            pair_key = (vote_key(pair[0]), vote_key(pair[1]))
            # A vote counts regardless of the order in which the pair was shown.
            if (
                pair_key not in user.voted_pairs
                and pair_key[::-1] not in user.voted_pairs
            ):
                return pair
        return None

    pair = get_next_pair(voting_users[userid])
    if pair is None:
        # Leave the first 14 outputs unchanged and disable the get-cached-sample button.
        return [
            *[gr.update() for _ in range(14)],
            # *clear_stuff(),
            gr.update(interactive=False),
        ]

    first, second = pair
    return (
        gr.update(visible=True, value=first.transcript, elem_classes=['blurred-text']),
        "Synthesize 🐢",
        gr.update(visible=True),  # r2
        first.modelName,  # model1
        second.modelName,  # model2
        gr.update(visible=True, value=first.filename, interactive=False, autoplay=autoplay),  # aud1
        gr.update(visible=True, value=second.filename, interactive=False, autoplay=False),  # aud2
        gr.update(visible=True, interactive=False),  # abetter
        gr.update(visible=True, interactive=False),  # bbetter
        gr.update(visible=False),  # prevmodel1
        gr.update(visible=False),  # prevmodel2
        gr.update(visible=True),  # nxt round btn
        # reset aplayed, bplayed audio playback events
        False,  # aplayed
        False,  # bplayed
        # fetch cached btn
        gr.update(interactive=True),
    )
def generate_matching_pairs(samples: List[Sample]) -> List[Tuple[Sample, Sample]]:
    """Return every 2-sample combination that shares a transcript.

    The samples are shuffled first, so the order of the pairs (and of the two
    samples inside a pair) varies from call to call.
    """
    by_transcript: Dict[str, List[Sample]] = {}
    # random.sample with k=len(samples) yields a shuffled copy; the caller's
    # list is left untouched.
    for item in random.sample(samples, k=len(samples)):
        by_transcript.setdefault(item.transcript, []).append(item)

    return [
        duo
        for bucket in by_transcript.values()
        for duo in itertools.combinations(bucket, 2)
    ]
# note the vote on cached sample pair
def voted_on_cached(modelName1: str, modelName2: str, transcript: str, session_hash: str, request: gr.Request):
    """Record that the caller has voted on a cached sample pair.

    Args:
        modelName1: Model name of the first sample in the pair.
        modelName2: Model name of the second sample in the pair.
        transcript: Transcript shared by both samples.
        session_hash: Session token from the front end ('' if absent); passed to
            `get_userid`.
        request: Gradio request, used by `get_userid` as a fallback identity.

    Returns:
        An empty list.
    """
    userid = get_userid(session_hash, request)
    # print(f'userid voted on cached: {userid}')
    if userid not in voting_users:
        voting_users[userid] = User(userid)

    def vote_key(model_name: str) -> str:
        # md5 over model name + transcript, the key format give_cached_sample
        # checks against. UTF-8 matches the ASCII bytes for ASCII text and does
        # not raise on other characters; the digest is only an identifier.
        return md5((model_name + transcript).encode('utf-8'), usedforsecurity=False).hexdigest()

    voting_users[userid].voted_pairs.add((vote_key(modelName1), vote_key(modelName2)))
    return []