Spaces:
Runtime error
Runtime error
import json | |
import math | |
import os | |
import random | |
import uuid | |
from datetime import datetime | |
import gradio as gr | |
import jsonlines | |
import pyarrow as pa | |
import s3fs | |
from datasets import Dataset | |
from huggingface_hub import HfApi | |
# Authenticated S3 filesystem handle; the key and secret are read from the
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY environment variables.
S3 = s3fs.S3FileSystem(
    anon=False,
    key=os.getenv("AWS_ACCESS_KEY_ID"),
    secret=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

# Prefix under which each dataset has its own directory of sample files.
BASE_S3_DIR = "s3://geclm-datasets/samples/"

# Message displayed once a dataset's samples are exhausted. The UI callbacks
# also store the formatted message as the "current sample" and compare against
# it, so it doubles as an end-of-data marker.
LABELLING_COMPLETE_TEXT = "Completed the labelling the sample for the {} dataset. Please consider labelling other datasets."

# Dataset names offered in the UI; each is appended to BASE_S3_DIR to locate
# its sample files.
DATASETS = [
    "c4",
    "bigcode_python_code",
    "bigcode_python_github_issues",
    "bigcode_python_jupyter_markdowned_clean_dedup",
    "books3",
    "gutenberg_raw",
    "reddit_threaded",
    "enwiki_data",
    "s2orc_dedup",
    "stackexchange2",
    "commoncrawl",
]
def get_parquet_lines(dataset, sample_size=1000):
    """Return up to ``sample_size`` shuffled rows from one random sample file.

    Globs ``BASE_S3_DIR + dataset + "/*"``, picks one of the matching files at
    random, reads every row group of it into memory, shuffles the rows and
    returns the first ``sample_size`` of them.

    Args:
        dataset: Name of the dataset directory under ``BASE_S3_DIR``.
        sample_size: Maximum number of rows to return.

    Returns:
        A list of row dicts (one per parquet row, as produced by
        ``Table.to_pylist()``).

    Raises:
        FileNotFoundError: If nothing matches the dataset's glob pattern.
    """
    # Import the submodule explicitly rather than relying on ``pa.parquet``
    # being reachable as an attribute of the top-level package.
    # NOTE(review): attribute access presumably only worked because another
    # import had already loaded pyarrow.parquet - confirm.
    import pyarrow.parquet as pq

    pattern = BASE_S3_DIR + dataset + "/*"
    s3_paths = S3.glob(pattern)
    if not s3_paths:
        # The original message referenced an undefined name and raised
        # NameError instead of the intended FileNotFoundError.
        raise FileNotFoundError(f"Nothing found at {pattern}")
    print("Number of parquet files", len(s3_paths))

    s3_path = random.choice(s3_paths)
    print("Reading", s3_path)

    lines = []
    with S3.open(s3_path) as f:
        pf = pq.ParquetFile(f)
        # Read the file one row group at a time so that rows come back as
        # whole records rather than as a raw byte stream.
        for ix_row_group in range(pf.metadata.num_row_groups):
            table = pf.read_row_group(ix_row_group)
            lines.extend(table.to_pylist())

    random.shuffle(lines)
    return lines[:sample_size]
def get_local_lines(dataset):
    """Load every record of the local JSON Lines file for ``dataset``.

    The file is expected at ``data/<dataset>_examples_with_stats.json``.

    Returns:
        A list with one decoded record per line of the file.
    """
    local_path = "data/{}_examples_with_stats.json".format(dataset)
    with jsonlines.open(local_path, "r") as reader:
        return list(reader)
def line_generator(lines_dict, dataset):
    """Yield the entries of ``lines_dict[dataset]`` one at a time, in order.

    The lookup happens lazily, on the first ``next()`` call.
    """
    yield from lines_dict[dataset]
# Local-file variant, currently disabled:
# local_lines = {dataset: get_local_lines(dataset) for dataset in DATASETS}
# line_generators_local = {dataset: line_generator(local_lines, dataset) for dataset in DATASETS}

# Samples for every dataset are fetched one after another when the module is
# loaded. Parallelize the below ?
s3_lines = dict(zip(DATASETS, map(get_parquet_lines, DATASETS)))

# One module-level generator per dataset; every caller advances the same one.
line_generators_s3 = {name: line_generator(s3_lines, name) for name in DATASETS}
def send_report(sample, dataset, reason, annotator, campaign):
    """Upload one flagged sample as a single-record JSON Lines file.

    The record holds the dataset name, a document id, the sample text, the
    remaining sample fields serialized as a JSON string, the flagging reason,
    the annotator, the campaign and the current local timestamp. It is
    uploaded to the ``HuggingFaceGECLM/data_feedback`` dataset repository under
    a unique ``report-<uuid4>.jsonl`` name.

    Args:
        sample: Dict for the flagged row. Must contain a ``"text"`` or
            ``"content"`` key. It is not modified.
        dataset: Name of the dataset the sample came from.
        reason: Free-text reason supplied by the annotator.
        annotator: Annotator name (may be empty).
        campaign: Annotation campaign name (may be empty).

    Raises:
        KeyError: If the sample has neither a ``"text"`` nor a ``"content"`` key.
    """
    # Work on a shallow copy: the caller passes the UI's state object, and the
    # original implementation stripped fields from it in place.
    metadata = dict(sample)

    text_col = "text" if "text" in metadata else "content"
    text = metadata.pop(text_col)
    # NOTE(review): presumably dropped because the value is not JSON
    # serializable - confirm against the parquet schema.
    metadata.pop("record_timestamp", None)

    # Prefer the explicit id, fall back to the title, else leave it empty.
    if "id" in metadata:
        sample_id = metadata["id"]
    elif "title" in metadata:
        sample_id = metadata["title"]
    else:
        sample_id = ""

    record = {
        "dataset": dataset,
        "docid": sample_id,
        "text": text,
        "metadata": json.dumps(metadata),
        "reason": reason,
        "annotator": annotator,
        "campaign": campaign,
        "timestamp": str(datetime.now()),
    }

    # Upload the serialized line straight from memory. The previous version
    # wrote a fixed "report.jsonl" in the working directory first, which
    # concurrent reports could overwrite before it was uploaded.
    payload = (json.dumps(record, ensure_ascii=False) + "\n").encode("utf-8")

    api = HfApi()
    api.upload_file(
        path_or_fileobj=payload,
        path_in_repo="report-{}.jsonl".format(uuid.uuid4()),
        repo_id="HuggingFaceGECLM/data_feedback",
        repo_type="dataset",
        token=os.environ.get("geclm_token"),
    )
def get_title_and_text_for_line(next_line):
    """Extract the display text and a short label from a sample row.

    The text is taken from the ``"text"`` key, or ``"content"`` when ``"text"``
    is absent. The label is chosen from the first of these keys that exists:

    * ``"title"``: the title, followed by ``" | <url>"`` when a url is present.
    * ``"metadata"``: the first element when it is a non-empty list, or the
      ``"document_url"`` entry when it is a string holding a JSON object.
    * ``"url"``: the url itself.

    Keys whose value is ``None`` (e.g. null parquet columns) contribute
    nothing, and metadata that is not valid JSON or not a JSON object is
    ignored rather than raising.

    Args:
        next_line: Dict for one sample row.

    Returns:
        A ``(text, label)`` tuple; ``label`` is ``""`` when nothing usable is
        found.

    Raises:
        KeyError: If the row has neither a ``"text"`` nor a ``"content"`` key.
    """
    text_col = "text"
    if text_col not in next_line:
        text_col = "content"
    text = next_line[text_col]

    label = ""
    if "title" in next_line:
        title = next_line["title"]
        url = next_line["url"] if "url" in next_line else None
        # A None title or url used to make the string concatenation raise
        # TypeError; treat None the same as a missing value.
        if title is None:
            label = "" if url is None else url
        elif url is None:
            label = title
        else:
            label = title + " | " + url
    elif "metadata" in next_line:
        metadata = next_line["metadata"]
        if metadata is not None:
            print(metadata)
            if isinstance(metadata, list) and len(metadata) > 0:
                label = metadata[0]
            elif isinstance(metadata, str):
                try:
                    parsed = json.loads(metadata)
                except json.JSONDecodeError:
                    # Free-form metadata strings are not an error; just leave
                    # the label empty.
                    parsed = None
                if isinstance(parsed, dict) and "document_url" in parsed:
                    label = parsed["document_url"]
    elif "url" in next_line:
        if next_line["url"] is not None:
            label = next_line["url"]
    return text, label
if __name__ == "__main__":
    demo = gr.Blocks()

    with demo:
        # Holds the sample currently shown (a dict), or the completion message
        # string once the selected dataset has no samples left.
        current_sample_state = gr.State(dict())

        description = gr.Markdown(
            value="""GecLM annotations. All annotations are recorded in the [data_feedback](https://huggingface.co/datasets/HuggingFaceGECLM/data_feedback) dataset.
            """,
        )
        with gr.Row():
            annotator = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide your name here if you'd like it to be recorded.",
                label="Annotator",
            )
            campaign = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Optionally provide the name of the annotation campagin for ease of filtering the reports.",
                label="Annotation campaign",
            )
        with gr.Row():
            # NOTE(review): the initial value is a prompt, not one of the
            # choices - confirm the installed Gradio version accepts that.
            dataset = gr.Dropdown(
                choices=DATASETS,
                value="Pick a dataset below",
                label="Dataset",
            )
        with gr.Row():
            reason_txt = gr.Textbox(
                label="Flagging reason",
                placeholder="Provide the reason for flagging if you think the sample is bad.",
                visible=False,
            )
        # The flagging controls and the sample box start hidden; they are made
        # visible by get_next_line once a dataset is picked.
        with gr.Row():
            bad_btn = gr.Button("Bad β", visible=False)
            good_btn = gr.Button("Next β ", visible=False)
        with gr.Row():
            text = gr.Textbox(visible=False, label="Datapoint", lines=500, max_lines=500)

        def _advance(dataset):
            """Return ``(sample, text, label)`` for the dataset's next sample.

            When the dataset's generator is exhausted, both ``sample`` and
            ``text`` are the completion message and the label falls back to
            the text box's default label. Previously ``label`` was left
            unassigned on that path, so the callbacks raised
            UnboundLocalError as soon as a dataset ran out.
            """
            try:
                next_line = next(line_generators_s3[dataset])
            except StopIteration:
                done_text = LABELLING_COMPLETE_TEXT.format(dataset)
                return done_text, done_text, "Datapoint"
            sample_text, sample_label = get_title_and_text_for_line(next_line)
            return next_line, sample_text, sample_label

        def get_next_line(dataset):
            """Show the next sample and reveal the reason box and both buttons."""
            next_line, sample_text, sample_label = _advance(dataset)
            return [
                gr.update(
                    value=sample_text,
                    visible=True,
                    label=sample_label,
                ),
                next_line,
                gr.update(visible=True),
                gr.update(visible=True),
                gr.update(visible=True),
            ]

        def report_bad_line_and_next(current_sample, dataset, reason, annotator, campaign):
            """Report the current sample as bad, then show the next one.

            Nothing is reported when the current "sample" is the completion
            message. The reason box is cleared afterwards.
            """
            if current_sample != LABELLING_COMPLETE_TEXT.format(dataset):
                send_report(current_sample, dataset, reason, annotator, campaign)

            next_line, sample_text, sample_label = _advance(dataset)
            return [
                gr.update(
                    value=sample_text,
                    visible=True,
                    label=sample_label,
                ),
                gr.update(
                    value="",
                    placeholder="Provide the reason for flagging if you think the sample is bad.",
                ),
                next_line,
            ]

        # "Next" and picking a dataset both simply load the next sample.
        good_btn.click(
            get_next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        dataset.change(
            get_next_line,
            inputs=dataset,
            outputs=[text, current_sample_state, reason_txt, good_btn, bad_btn],
        )
        bad_btn.click(
            report_bad_line_and_next,
            inputs=[current_sample_state, dataset, reason_txt, annotator, campaign],
            outputs=[text, reason_txt, current_sample_state],
        )

    demo.launch(enable_queue=False, debug=True)