Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import csv | |
import datetime | |
import io | |
import json | |
import os | |
from abc import ABC, abstractmethod | |
from typing import TYPE_CHECKING, Any, List, Optional | |
import gradio as gr | |
from gradio import encryptor, utils | |
if TYPE_CHECKING: | |
from gradio.components import Component | |
class FlaggingCallback(ABC):
    """
    An abstract class for defining the methods that any FlaggingCallback should have.
    Concrete subclasses must implement both `setup()` and `flag()`; a subclass that
    omits either cannot be instantiated.
    """

    @abstractmethod
    def setup(self, components: List[Component], flagging_dir: str):
        """
        This method should be overridden and ensure that everything is set up correctly for flag().
        This method gets called once at the beginning of the Interface.launch() method.
        Parameters:
            components: Set of components that will provide flagged data.
            flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """
        pass

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
        This gets called every time the <flag> button is pressed.
        Parameters:
            flag_data: The data to be flagged.
            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
            flag_index (optional): The index of the sample that is being flagged.
            username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
            (int) The total number of samples that have been flagged.
        """
        pass
class SimpleCSVLogger(FlaggingCallback):
    """
    A simple example implementation of the FlaggingCallback abstract class
    provided for illustrative purposes.

    Each call to flag() appends one row to `log.csv` in the flagging directory.
    No header row is written, so every row in the file is one flagged sample.
    """

    def setup(self, components: List[Component], flagging_dir: str):
        """
        Parameters:
            components: Components whose values make up each logged row.
            flagging_dir: Directory that will hold `log.csv`; created if missing.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Append one row built from `flag_data` to `log.csv`.

        Each sample is passed through its component's `save_flagged` (without an
        encryption key) and the results are written as one CSV row.
        `flag_option`, `flag_index` and `username` are accepted for interface
        compatibility but are not recorded by this logger.

        Returns:
            (int) The number of rows (flagged samples) now in the log.
        """
        flagging_dir = self.flagging_dir
        log_filepath = os.path.join(flagging_dir, "log.csv")

        csv_data = [
            component.save_flagged(flagging_dir, component.label, sample, None)
            for component, sample in zip(self.components, flag_data)
        ]

        with open(log_filepath, "a", newline="") as csvfile:
            writer = csv.writer(csvfile, quoting=csv.QUOTE_NONNUMERIC, quotechar="'")
            writer.writerow(csv_data)

        # Read back with the same quotechar (and newline="") that the writer used,
        # so a field containing a comma or newline still counts as a single row.
        # This logger writes no header, so every row is a flagged sample.
        with open(log_filepath, "r", newline="") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile, quotechar="'"))
        return line_count
class CSVLogger(FlaggingCallback):
    """
    The default implementation of the FlaggingCallback abstract class.
    Logs the input and output data to a CSV file. Supports encryption.

    The log is `log.csv` inside the flagging directory. Its first row is a header
    (component labels followed by "flag", "username", "timestamp"); every other row
    is one flagged sample. When an encryption key is set, the whole file is stored
    encrypted via `encryptor` and is rewritten on every call.
    """

    # Quote character used for every write AND every read of the log. Rows are
    # written with QUOTE_NONNUMERIC, so all string fields (header names included)
    # are wrapped in this character. A reader using csv's default '"' would keep
    # the quotes (so the "flag" header could not be found) and would mis-split
    # fields that contain commas or newlines.
    _QUOTECHAR = "'"

    def setup(
        self,
        components: List[Component],
        flagging_dir: str,
        encryption_key: Optional[str] = None,
    ):
        """
        Parameters:
            components: Components whose values make up each logged row.
            flagging_dir: Directory that will hold `log.csv`; created if missing.
            encryption_key: If given, the log (and saved samples) are encrypted with it.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        self.encryption_key = encryption_key
        os.makedirs(flagging_dir, exist_ok=True)

    def _csv_writer(self, fileobj):
        """Return a csv writer that uses the log's quoting convention."""
        return csv.writer(
            fileobj, quoting=csv.QUOTE_NONNUMERIC, quotechar=self._QUOTECHAR
        )

    def _count_samples(self, lines) -> int:
        """Return the number of data rows (all rows except the header) in CSV `lines`."""
        row_count = sum(1 for _ in csv.reader(lines, quotechar=self._QUOTECHAR))
        return max(row_count - 1, 0)

    def _replace_flag_at_index(
        self, file_content: str, flag_index: int, flag_option: Optional[str]
    ) -> str:
        """
        Return `file_content` with the "flag" column of row `flag_index` set to
        `flag_option`. `flag_index` indexes all rows, where row 0 is the header.
        """
        content = list(
            csv.reader(io.StringIO(file_content), quotechar=self._QUOTECHAR)
        )
        flag_col_index = content[0].index("flag")
        content[flag_index][flag_col_index] = flag_option
        output = io.StringIO()
        self._csv_writer(output).writerows(content)
        return output.getvalue()

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Record a flag in the log.

        If `flag_index` is None, a new row is appended. It is built from
        `flag_data` (each non-None sample passed through its component's
        `save_flagged`, None samples logged as ""), followed by `flag_option`,
        `username` and the current local timestamp. The header row is written
        first if the log does not exist yet.
        Otherwise, the "flag" column of the existing row at `flag_index` is
        replaced by `flag_option`, and `flag_data` is ignored.

        Returns:
            (int) The number of data rows in the log after the update.
        """
        flagging_dir = self.flagging_dir
        log_filepath = os.path.join(flagging_dir, "log.csv")
        is_new = not os.path.exists(log_filepath)

        if flag_index is None:
            csv_data = [
                component.save_flagged(
                    flagging_dir,
                    component.label,
                    sample,
                    self.encryption_key,
                )
                if sample is not None
                else ""
                for component, sample in zip(self.components, flag_data)
            ]
            csv_data.append(flag_option if flag_option is not None else "")
            csv_data.append(username if username is not None else "")
            csv_data.append(str(datetime.datetime.now()))
            headers = [component.label for component in self.components] + [
                "flag",
                "username",
                "timestamp",
            ]

        if self.encryption_key:
            # The whole log is decrypted, modified in memory, and re-encrypted.
            output = io.StringIO()
            if not is_new:
                with open(log_filepath, "rb") as csvfile:
                    encrypted_csv = csvfile.read()
                file_content = encryptor.decrypt(
                    self.encryption_key, encrypted_csv
                ).decode()
                if flag_index is not None:
                    file_content = self._replace_flag_at_index(
                        file_content, flag_index, flag_option
                    )
                output.write(file_content)
            if flag_index is None:
                writer = self._csv_writer(output)
                if is_new:
                    writer.writerow(headers)
                writer.writerow(csv_data)
            plaintext = output.getvalue()
            with open(log_filepath, "wb") as csvfile:
                csvfile.write(encryptor.encrypt(self.encryption_key, plaintext.encode()))
            # The file on disk is ciphertext, so count rows from the plaintext.
            return self._count_samples(io.StringIO(plaintext))

        if flag_index is None:
            with open(log_filepath, "a", newline="") as csvfile:
                writer = self._csv_writer(csvfile)
                if is_new:
                    writer.writerow(headers)
                writer.writerow(csv_data)
        else:
            with open(log_filepath, newline="") as csvfile:
                file_content = csvfile.read()
            file_content = self._replace_flag_at_index(
                file_content, flag_index, flag_option
            )
            # newline parameter needed for Windows
            with open(log_filepath, "w", newline="") as csvfile:
                csvfile.write(file_content)

        with open(log_filepath, "r", newline="") as csvfile:
            return self._count_samples(csvfile)
class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A FlaggingCallback that saves flagged data to a HuggingFace dataset.

    Flagged rows are appended to `data.csv` in a local clone of the dataset repo,
    and the repo is pushed to the Hub after every flag.
    """

    def __init__(
        self,
        hf_foken: str,
        dataset_name: str,
        organization: Optional[str] = None,
        private: bool = False,
        verbose: bool = True,
    ):
        """
        Params:
            hf_foken (str): The token to use to access the huggingface API.
                (The parameter keeps this spelling for backward compatibility.)
            dataset_name (str): The name of the dataset to save the data to, e.g.
                "image-classifier-1". May also be a full "namespace/name" id.
            organization (str): The name of the organization to which to attach
                the dataset. It is prefixed to `dataset_name` when `dataset_name`
                has no "/" of its own. If None, the dataset attaches to the user.
            private (bool): If the dataset does not already exist, whether it
                should be created as a private dataset or public. Private datasets
                may require paid huggingface.co accounts
            verbose (bool): Whether to print out the status of the dataset
                creation. NOTE: stored but not currently read by this class.
        """
        self.hf_foken = hf_foken
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private
        self.verbose = verbose

    def setup(self, components: List[Component], flagging_dir: str):
        """
        Create (or reuse) the dataset repo on the Hub, clone it locally and pull
        the latest changes.

        Params:
            components: Components whose values make up each logged row.
            flagging_dir (str): local directory where the dataset is cloned,
                updated, and pushed from.
        Raises:
            ImportError: if `huggingface_hub` is not installed.
        """
        try:
            import huggingface_hub
        except ImportError as err:  # also covers ModuleNotFoundError (a subclass)
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
            ) from err

        # create_repo takes the full repo id; see
        # https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py
        repo_id = self.dataset_name
        if self.organization_name and "/" not in repo_id:
            repo_id = "{}/{}".format(self.organization_name, repo_id)
        path_to_dataset_repo = huggingface_hub.create_repo(
            repo_id=repo_id,
            token=self.hf_foken,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_foken,
        )
        self.repo.git_pull()

        # Should filename be user-specified?
        self.log_file = os.path.join(self.dataset_dir, "data.csv")
        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Append one row for `flag_data` to `data.csv` and push the repo.

        When `data.csv` does not exist yet, a header row is written first and
        `dataset_infos.json` is created. `flag_index` and `username` are
        accepted for interface compatibility but are not used.

        Returns:
            (int) The number of data rows (excluding the header) in `data.csv`.
        """
        is_new = not os.path.exists(self.log_file)
        infos = {"flagged": {"features": {}}}

        # File previews for certain input and output types: such components get an
        # extra "<label> file" column holding a URL to the saved file in the repo.
        file_preview_types = {
            gr.inputs.Audio: "Audio",
            gr.outputs.Audio: "Audio",
            gr.inputs.Image: "Image",
            gr.outputs.Image: "Image",
        }
        preview_classes = tuple(file_preview_types)

        with open(self.log_file, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)

            # Generate the headers and dataset_infos. Exactly one header per
            # value written in the data row below, so columns stay aligned.
            if is_new:
                headers = []
                for component, _sample in zip(self.components, flag_data):
                    headers.append(component.label)
                    infos["flagged"]["features"][component.label] = {
                        "dtype": "string",
                        "_type": "Value",
                    }
                    if isinstance(component, preview_classes):
                        headers.append(component.label + " file")
                        for _component, _type in file_preview_types.items():
                            if isinstance(component, _component):
                                infos["flagged"]["features"][
                                    component.label + " file"
                                ] = {"_type": _type}
                                break
                headers.append("flag")
                infos["flagged"]["features"]["flag"] = {
                    "dtype": "string",
                    "_type": "Value",
                }
                writer.writerow(headers)

            # Generate the row corresponding to the flagged sample
            csv_data = []
            for component, sample in zip(self.components, flag_data):
                filepath = component.save_flagged(
                    self.dataset_dir, component.label, sample, None
                )
                csv_data.append(filepath)
                if isinstance(component, preview_classes):
                    csv_data.append(
                        "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
                    )
            csv_data.append(flag_option if flag_option is not None else "")
            writer.writerow(csv_data)

        if is_new:
            with open(self.infos_file, "w") as infos_fp:
                json.dump(infos, infos_fp)

        with open(self.log_file, "r", newline="") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1

        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
        return line_count