|
|
|
import dataclasses |
|
import functools |
|
import logging |
|
import os |
|
from collections import OrderedDict |
|
from typing import List, Dict, Any |
|
|
|
import seqio |
|
from seqio import dataset_providers |
|
import tensorflow_datasets as tfds |
|
|
|
from .data_utils import _strip_metadata, build_tokenizer |
|
from .preprocesssors import * |
|
from .preprocesssors import _preprocess_scifi |
|
|
|
|
|
@dataclasses.dataclass
class TaskSpec:
    """Static description of one registered task.

    Instances are created by `add_task`, stored in `TASKS`, and turned into a
    `TaskDatasetBuilder` by `get_task`.
    """

    # Key under which the task is stored in `TASKS`
    name: str
    # Where the raw examples come from
    source: seqio.DataSourceInterface
    # Task-specific dataset transforms, applied in order
    preprocessors: List
    # Style forwarded to the model preprocessor by `get_task`.
    # NOTE(review): some registrations in this file pass a list of styles
    # rather than a single string.
    style: str
    # If set, `get_task` uses these instead of `preprocessors` when building for inference
    inference_preprocessors: Optional[List] = None
    # If True, `get_task` refuses to build the task outside of inference mode
    inference_only: bool = False
    # Forwarded to `TaskDatasetBuilder.decode_image`
    decode_image: bool = False
    # Number of leading preprocessors to apply before shuffling; None lets `get_task` pick
    shuffle_after: Optional[int] = None
    # Forwarded to `TaskDatasetBuilder.ignore_errors`
    ignore_errors: bool = False
|
|
|
|
|
# Root directory of the TFDS datasets referenced by most task registrations in this file
MULTITASK_TFDS_DATA_DIR = "/weka/oe-training-default/mm-olmo/tensorflow_datasets"

# Registry of task name -> TaskSpec; filled by `add_task` and read by `get_task`
TASKS: Dict[str, TaskSpec] = {}
|
|
|
|
|
def add_task(
    name,
    source: seqio.DataSourceInterface,
    preprocessors: List,
    style: str,
    inf_preprocessor=None,
    inf_only=False,
    decode_image=False,
    shuffle_after=None,
    ignore_errors=False
):
    """Register a task in the module-level `TASKS` registry.

    Args:
        name: Registry key for the task.
        source: Data source providing the raw examples.
        preprocessors: Task-specific dataset transforms, applied in order.
        style: Style stored on the `TaskSpec`.
        inf_preprocessor: Optional list of transforms stored as
            `TaskSpec.inference_preprocessors`.
        inf_only: Stored as `TaskSpec.inference_only`.
        decode_image: Stored as `TaskSpec.decode_image`.
        shuffle_after: Stored as `TaskSpec.shuffle_after`.
        ignore_errors: Stored as `TaskSpec.ignore_errors`.

    A second registration under the same name still replaces the first (the
    existing behaviour), but it is now logged so that accidental duplicates
    are visible.
    """
    if name in TASKS:
        logging.warning(
            "Task %r is already registered; overwriting the previous definition", name)
    TASKS[name] = TaskSpec(
        name=name,
        source=source,
        preprocessors=preprocessors,
        style=style,
        inference_preprocessors=inf_preprocessor,
        inference_only=inf_only,
        decode_image=decode_image,
        shuffle_after=shuffle_after,
        ignore_errors=ignore_errors,
    )
|
|
|
|
|
@seqio.map_over_dataset
def add_image_size(ex):
    """Decode the image if it is still encoded and record its size.

    Adds `metadata/image_size` as `[width, height]`, taken from the second and
    first dimension of the image tensor.

    Returns:
        The updated example. The original fell off the end of the function, so
        the mapped dataset received `None` instead of the example.
    """
    if ex["image"].dtype == tf.string:
        # Still an encoded byte string: decode to a 3-channel image
        ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    ex["metadata/image_size"] = [img_w, img_h]
    return ex
|
|
|
|
|
@dataclasses.dataclass
class TaskDatasetBuilder:
    """tf.data.Dataset builder for task after shuffling, sharding, and initial model pre-processing
    have been applied"""

    # Task name; used here only in log messages
    name: str
    # Data source; the "tfds_split" sharding mode requires a `seqio.TfdsDataSource`
    source: Any
    # Dataset transforms applied in order
    preprocessors: List
    # If False, `_strip_metadata` is applied right before the shuffle step
    keep_metadata: bool
    # Number of leading `preprocessors` applied before shuffling; the rest run after it
    shuffle_after: int
    # "seqio": shard through the seqio source; "tfds_split": one even TFDS sub-split per shard
    sharding: str = "tfds_split"
    # "seqio" mode only: if False, TFDS image decoding is skipped
    decode_image: bool = False
    # If True, `ds.ignore_errors(log_warning=True)` is applied to the final dataset
    ignore_errors: bool = False

    def get_dataset(
        self,
        sequence_length: Optional[Mapping[str, int]] = None,
        split: str = tfds.Split.TRAIN,
        shuffle: bool = True,
        shuffle_buffer_size: Optional[int] = 1000,
        seed: Optional[int] = None,
        shard_info: Optional[seqio.ShardInfo] = None,
        num_epochs: Optional[int] = 1,
        try_in_mem_cache: bool = True,
        trim_output_features: bool=True
    ) -> tf.data.Dataset:
        """Build the dataset for `split`.

        Args:
            sequence_length: Per-feature maximum lengths, passed to every
                preprocessor and used for the final trimming step.
            split: Split to load.
            shuffle: Whether to shuffle files/examples.
            shuffle_buffer_size: Buffer size for the example-level shuffle;
                must not be None when `shuffle` is True.
            seed: Seed for the source, the map seed manager and the shuffle.
            shard_info: Optional shard to load.
            num_epochs: Passed to `ds.repeat` unless it equals 1.
            try_in_mem_cache: Cache the raw dataset in memory when the source
                reports fewer than 10000 examples per shard.
            trim_output_features: Trim features to `sequence_length` at the end.

        Raises:
            NotImplementedError: If `self.sharding` is not a known mode.
            ValueError: If `shuffle` is True but `shuffle_buffer_size` is None.
        """
        source = self.source

        if self.sharding == "seqio":
            if source.supports_arbitrary_sharding:
                shard_data_source = True
            elif shard_info:
                # Shard at the source only if it has at least one file shard per requested shard
                shard_data_source = (
                    len(source.list_shards(split=split)) >= shard_info.num_shards
                )
                logging.info(
                    "Sharding at the %s: %d of %d",
                    "data source" if shard_data_source else "examples",
                    shard_info.index + 1,
                    shard_info.num_shards,
                )
            else:
                # No sharding requested
                shard_data_source = True
                shard_info = None

            # Only TFDS-backed sources expose `tfds_dataset`; other sources
            # (e.g. `seqio.TFExampleDataSource`) have no TFDS decoders to configure,
            # and accessing the attribute on them would raise AttributeError.
            tfds_loader = getattr(source, "tfds_dataset", None)
            if tfds_loader is not None and "image" in tfds_loader.info.features:
                if not self.decode_image:
                    tfds_loader._decoders = dict(image=tfds.decode.SkipDecoding())

            if shard_data_source:
                ds = source.get_dataset(
                    split=split, shuffle=shuffle, seed=seed, shard_info=shard_info)
            else:
                # Too few file shards: load everything and shard by example instead
                ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
                ds = ds.shard(shard_info.num_shards, shard_info.index)
        elif self.sharding == "tfds_split":
            assert isinstance(self.source, seqio.TfdsDataSource)
            loader: seqio.LazyTfdsLoader = self.source.tfds_dataset
            dataset, data_dir = loader.get_split_params(split)
            shard_split = loader._map_split(split)
            if shard_info and shard_info.num_shards > 1:
                # Give each shard its own evenly sized sub-split
                shard_split = tfds.even_splits(
                    shard_split, n=shard_info.num_shards, drop_remainder=False
                )[shard_info.index]
            read_config = loader.read_config
            read_config.shuffle_seed = seed
            read_config.skip_prefetch = True
            read_config.input_context = None

            # NOTE(review): this mode always skips image decoding, regardless of
            # `self.decode_image` -- confirm that is intended.
            if "image" in loader.info.features:
                decoders = dict(image=tfds.decode.SkipDecoding())
            else:
                decoders = None
            ds = tfds.load(
                dataset,
                split=shard_split,
                data_dir=data_dir,
                shuffle_files=shuffle,
                download=True,
                try_gcs=True,
                read_config=read_config,
                decoders=decoders
            )
        else:
            raise NotImplementedError(self.sharding)

        num_shards = shard_info.num_shards if shard_info else 1
        if try_in_mem_cache:
            num_examples = source.num_input_examples(split)
            if num_examples and num_examples < 10000 * num_shards:
                logging.info(f"Automatically caching small dataset in memory: {self.name}:{split}")
                ds = ds.cache()

        if num_epochs != 1:
            ds = ds.repeat(num_epochs)

        preprocessors = [
            seqio.add_kwargs_to_transform(
                _fn,
                sequence_length=sequence_length,
                output_features=None,
            ) for _fn in self.preprocessors
        ]

        with seqio.utils.map_seed_manager(seed):
            for fn in preprocessors[:self.shuffle_after]:
                ds = fn(ds)

            # Metadata is removed before shuffling, so it is not held in the shuffle buffer
            if not self.keep_metadata:
                ds = _strip_metadata(ds)

            if shuffle:
                if shuffle_buffer_size is None:
                    raise ValueError("Shuffle is true, but shuffle_buffer_size is None")
                ds = ds.shuffle(shuffle_buffer_size, seed=seed)

            for fn in preprocessors[self.shuffle_after:]:
                ds = fn(ds)

        if self.ignore_errors:
            ds = ds.ignore_errors(log_warning=True)

        # `sequence_length` doubles as the `output_features` argument; with the
        # default `sequence_length=None` there is nothing to trim, and passing
        # None as `output_features` would fail inside `seqio.trim_dataset`.
        if trim_output_features and sequence_length is not None:
            ds = seqio.trim_dataset(ds, sequence_length, sequence_length)

        return ds
|
|
|
|
|
def get_task(preprocessor, name, is_training, for_inference,
             include_metadata=None, style_override=None) -> TaskDatasetBuilder:
    """Get a builder for task `name` that is pre-processed by `preprocessor`

    Args:
        preprocessor: Object exposing
            `get_preprocessor(is_training, for_inference, style, include_metadata)`,
            or None to apply only the task's own preprocessors.
        name: Key of the task in `TASKS`.
        is_training: Forwarded to `preprocessor.get_preprocessor`.
        for_inference: Whether to build for inference; None means "use the
            task's `inference_only` flag".
        include_metadata: Whether to keep metadata; None means "same as
            `for_inference`".
        style_override: If truthy, used instead of the task's own style.

    Raises:
        KeyError: If `name` is not registered.
        ValueError: If an inference-only task is requested with
            `for_inference=False`.
    """
    task_spec = TASKS[name]
    if for_inference is None:
        for_inference = task_spec.inference_only
    elif task_spec.inference_only and not for_inference:
        raise ValueError(f"Inference=only task {task_spec.name} can only be used in inference mode")

    if include_metadata is None:
        include_metadata = for_inference

    if preprocessor is not None:
        style = style_override if style_override else task_spec.style
        model_preprocessors = [
            preprocessor.get_preprocessor(is_training, for_inference, style, include_metadata)
        ]
    else:
        model_preprocessors = []

    task_preprocessors = task_spec.preprocessors
    if for_inference and task_spec.inference_preprocessors is not None:
        task_preprocessors = task_spec.inference_preprocessors

    if isinstance(task_spec.source, seqio.TfdsDataSource):
        from seqio.utils import _TFDS_DATA_DIR_OVERRIDE
        if _TFDS_DATA_DIR_OVERRIDE:
            # A global seqio data-dir override is set: clear the loader's own data dir
            # (presumably so the override takes effect -- confirm against seqio)
            task_spec.source.tfds_dataset._data_dir = None

    # `shuffle_after=0` is a valid explicit request (shuffle before any
    # preprocessor), so test against None rather than truthiness. The default
    # is "after the task preprocessors that are actually used", which keeps the
    # model preprocessor after the shuffle in inference mode as well.
    shuffle_after = task_spec.shuffle_after
    if shuffle_after is None:
        shuffle_after = len(task_preprocessors)

    return TaskDatasetBuilder(
        task_spec.name,
        task_spec.source,
        task_preprocessors + model_preprocessors,
        keep_metadata=include_metadata,
        shuffle_after=shuffle_after,
        sharding="seqio",
        decode_image=task_spec.decode_image,
        ignore_errors=task_spec.ignore_errors,
    )
|
|
|
|
|
# COCO 2017 captions. Training re-keys the example and then applies
# `flatten_parts` on "text"; inference only re-keys.
add_task(
    "coco_caption_2017",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "image/filename": ["image/filename"],
            "image": ["image"],
            "text": ["captions", "text"]
        }),
        functools.partial(flatten_parts, parts=["text"]),
    ],
    inf_preprocessor=[
        functools.partial(rekey, key_map={
            "image/filename": ["image/filename"],
            "image": ["image"],
            "text": ["captions", "text"]
        })
    ],
    style="coco_captioning",
)


# COCO captions, Karpathy splits. At inference only `add_coco_url` is applied
# (the training preprocessors are replaced, not extended).
add_task(
    "coco_captioning_karpathy",
    source=seqio.TfdsDataSource(
        tfds_name="coco_captioning_karpathy:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(text="captions"),
        functools.partial(flatten_parts, parts=["text"]),
    ],
    inf_preprocessor=[add_coco_url],
    style="coco_captioning",
)


# Synthetic counting; the first 5120 training examples are held out as validation
add_task(
    "synth_counting",
    source=seqio.TfdsDataSource(
        tfds_name="synth_counting:0.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"}
    ),
    preprocessors=[synth_count_preprocessor],
    inf_preprocessor=[synth_count_inf_preprocessor],
    style="synth_counting",
)


# Khan Academy; the first 1024 training examples are held out as validation
add_task(
    "khan_academy",
    source=seqio.TfdsDataSource(
        tfds_name="khan_academy:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[1024:]", "validation": "train[:1024]"}
    ),
    preprocessors=[extract_khan_academy],
    style="khan_academy",
)
|
|
|
def _register_vaia_short_answer_tasks(name, src):
    """Register the six short-answer task variants for one vaia_qa source.

    Each variant is `<name><suffix>` and consists of some leading filters
    followed by `extract_vaia_qa_latex_image`.
    """
    base_filters = [remove_is_long, remove_has_multiple_parts]
    variants = [
        # (task suffix, leading filters, short answer first?, style)
        ("_short_answer", base_filters, False, "vaia_qa"),
        ("_short_answer_first", base_filters, True, "vaia_qa_short_answer_first"),
        ("_mc_only_short_answer", base_filters + [filter_mc], False, "vaia_qa_short_answer"),
        ("_mc_only_short_answer_first", base_filters + [filter_mc], True,
         "vaia_qa_short_answer_first"),
        ("_image_only_short_answer", [image_only] + base_filters, False, "vaia_qa_short_answer"),
        ("_image_only_short_answer_first", [image_only] + base_filters, True,
         "vaia_qa_short_answer_first"),
    ]
    for suffix, filters, answer_first, style in variants:
        extract_kwargs = dict(add_short_answer=True)
        if answer_first:
            extract_kwargs["set_short_answer_first"] = True
        add_task(
            f"{name}{suffix}",
            source=src,
            # Fresh list and fresh partial per task, as in the hand-written version
            preprocessors=list(filters) + [
                functools.partial(extract_vaia_qa_latex_image, **extract_kwargs)
            ],
            style=style,
        )


for name, src in [
    ("vaia_qa_latex_image_math_subset", seqio.TfdsDataSource(
        tfds_name="vaia_qa_latex_image_short_answer:0.1.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "validation"}
    )),
    ("vaia_qa_latex_image_all", seqio.TfdsDataSource(
        tfds_name="vaia_qa_latex_image_short_answer:0.1.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "validation"}
    )),
]:
    _register_vaia_short_answer_tasks(name, src)
|
|
|
# VQA online; the "test" split is an alias of the validation split
add_task(
    "vqa_online",
    source=seqio.TfdsDataSource(
        tfds_name="vqa_online:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "validation", "test": "validation"}
    ),
    preprocessors=[
        build_question_with_context,
        extract_vqa_online,
    ],
    style="vqa_online",
)

# Variant built on the GPT-parsed dataset, using its long question/answer fields
add_task(
    "vqa_online_gpt_longQ_longA",
    source=seqio.TfdsDataSource(
        tfds_name="vqa_online_gpt_parsed:1.1.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "validation", "test": "validation"}
    ),
    preprocessors=[
        rename(question="question_long", answer="answer_long"),
        extract_vqa_online,
    ],
    style="vqa_online",
)


# Famous birthdays; the first 5120 training examples are held out as validation
add_task(
    "famous_birthdays",
    source=seqio.TfdsDataSource(
        tfds_name="famous_birth_days:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"}
    ),
    preprocessors=[
        famous_birthdays_preprocessor,
        functools.partial(name_entity_augmentation, p_high_color=0.0),
    ],
    style="famous_birthdays",
)


# WikiArt with `name_entity_augmentation` applied before the task preprocessor
add_task(
    "wiki_art",
    source=seqio.TfdsDataSource(
        tfds_name="wiki_art:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"}
    ),
    preprocessors=[name_entity_augmentation, wiki_art_preprocessor],
    style="wiki_art",
)

# Same data and style as "wiki_art", without the augmentation step
add_task(
    "wiki_art_no_aug",
    source=seqio.TfdsDataSource(
        tfds_name="wiki_art:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"}
    ),
    preprocessors=[wiki_art_preprocessor],
    style="wiki_art",
)

# Atlas Obscura; the first 5120 training examples are held out as validation
add_task(
    "atlas_obscura",
    source=seqio.TfdsDataSource(
        tfds_name="atlas_obscura:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"}
    ),
    preprocessors=[
        atlas_obscura_preprocessor,
        mild_color_aug_preprocessor
    ],
    style="atlas_obscura",
)


# Clocks with augmentation. `shuffle_after=0` asks for the shuffle to happen
# before any of the preprocessors run.
add_task(
    "clocks",
    source=seqio.TfdsDataSource(
        tfds_name="clocks:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        clocks_preprocessor,
        clock_augmentation
    ],
    style="clocks",
    shuffle_after=0
)


# CountBench, using the dataset's own split names
add_task(
    "count_bench",
    source=seqio.TfdsDataSource(
        tfds_name="count_bench:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        count_bench_preprocessor,
    ],
    style="count_bench",
)


# Tulu v2 SFT mixture
add_task(
    "tulu_v2_sft",
    source=seqio.TfdsDataSource(
        tfds_name="allenai__tulu_v2_sft_mixture:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[tulu_preprocessor],
    style="tulu_v2",
)
|
|
|
|
|
|
|
# Every pointing task is registered twice: once with counting ("point_count")
# and once without ("pointing"). The task name doubles as the style.
for is_count in (True, False):
    task = "point_count" if is_count else "pointing"

    add_task(
        task,
        source=seqio.TfdsDataSource(
            tfds_name="pointing:1.0.1",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=dict(train="train", validation="validation"),
        ),
        preprocessors=[
            filter_points,
            functools.partial(pointing_preprocessor, with_count=is_count),
            split,
        ],
        style=task,
    )

    add_task(
        f"{task}_eval",
        source=seqio.TfdsDataSource(
            tfds_name="pointing:1.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        ),
        preprocessors=[
            filter_points,
            functools.partial(pointing_preprocessor, with_count=is_count),
            split,
        ],
        style=task,
    )

    # The first 2048 training examples are held out as validation
    add_task(
        f"{task}_high_freq",
        source=seqio.TfdsDataSource(
            tfds_name="count_qa:0.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={"train": "train[2048:]", "validation": "train[:2048]"},
        ),
        preprocessors=[
            filter_points,
            fix_count_qa,
            functools.partial(pointing_preprocessor, with_count=is_count),
            split,
        ],
        style=task,
    )

    add_task(
        f"fast_flickr_count_qa_{task}",
        source=seqio.TfdsDataSource(
            tfds_name="fast_flickr_count_qa:1.0.4",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        ),
        preprocessors=[
            functools.partial(count_qa_preprocessor, with_count=is_count),
        ],
        inf_preprocessor=[
            functools.partial(count_qa_preprocessor, with_count=is_count, for_inference=True),
        ],
        style=task,
    )
|
|
|
|
|
# Inference-only CountBench QA, using the "point_count" style
add_task(
    "countbench_qa",
    source=seqio.TfdsDataSource(
        tfds_name="countbench_qa:1.2.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    inf_only=True,
    preprocessors=[
        count_qa_preprocessor_inf,
    ],
    style="point_count",
)
|
|
|
|
|
# Inference-only pointing test set.
# The style was previously `task`, i.e. whatever value the loop variable of the
# preceding `for is_count in [True, False]` loop was left with ("pointing").
# It is spelled out here so the registration no longer depends on that loop.
add_task(
    "pointing_test",
    source=seqio.TfdsDataSource(
        tfds_name="pointing:1.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        pointing_inf_preprocessor
    ],
    style="pointing",
    inf_only=True,
)
|
|
|
|
|
# Point QA; the first 512 training examples are held out as validation
add_task(
    "point_qa",
    source=seqio.TfdsDataSource(
        tfds_name="point_qa:0.0.5",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[512:]",
            validation="train[:512]"
        )
    ),
    preprocessors=[extract_point_qa, split],
    style="point_qa",
)

# Same data and style as "clocks", without `clock_augmentation`
add_task(
    "clocks_no_aug",
    source=seqio.TfdsDataSource(
        tfds_name="clocks:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        clocks_preprocessor
    ],
    style="clocks",
)


# Inference-only clock benchmark, using the "clocks" style
add_task(
    "clock_bench",
    source=seqio.TfdsDataSource(
        tfds_name="clock_bench:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        clock_bench_preprocessor
    ],
    inf_only=True,
    style="clocks",
)

# Wiki data: validation is train[:5120], test is train[5120:10240], the rest is training data
add_task(
    "wiki_data",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_wiki:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
    ),
    preprocessors=[extract_wiki_data],
    style="wiki_data",
)


# Same source and splits as "wiki_data", with a different extractor plus mild colour augmentation
add_task(
    "wiki_data_name",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_wiki:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
    ),
    preprocessors=[
        extract_wiki_data_name,
        mild_color_aug_preprocessor
    ],
    style="wiki_data",
)
|
|
|
# Inference-only "describe" variant of the wiki data tasks.
# This registration used to appear twice in a row with identical arguments;
# the second call only overwrote the first, so a single call is kept.
add_task(
    "wiki_data_describe",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_wiki:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
    ),
    preprocessors=[extract_wiki_data_describe],
    inf_only=True,
    style="wiki_data",
)
|
|
|
|
|
# Sci-fi datasets: (task prefix, TFDS name, number of leading training examples
# held out as validation). All sources are constructed before the loop starts.
for name, src in [
    (task_prefix, seqio.TfdsDataSource(
        tfds_name=tfds_name,
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": f"train[{num_val}:]", "validation": f"train[:{num_val}]"},
    ))
    for task_prefix, tfds_name, num_val in (
        ("scifi_charts", "sci_fi_charts:1.0.6", 1024),
        ("scifi_table", "sci_fi_table:1.0.3", 1024),
        ("scifi_document", "sci_fi_document:1.0.3", 1024),
        ("scifi_diagram", "sci_fi_diagram:1.0.0", 1024),
        ("scifi_natural", "sci_fi_natural:1.0.1", 128),
        ("scifi_nutrition", "sci_fi_nutrition:1.0.0", 128),
    )
]:
    # Plain QA; inference additionally applies `flatten_parts` on question/answer
    add_task(
        f"{name}_qa",
        source=src,
        preprocessors=[remove_no_qa, _preprocess_scifi, extract_individual_vqa],
        inf_preprocessor=[
            remove_no_qa,
            _preprocess_scifi,
            functools.partial(flatten_parts, parts=["question", "answer"]),
            extract_individual_vqa,
        ],
        style=name,
    )
    # As above, with `split` appended for training
    add_task(
        f"{name}_qa_split",
        source=src,
        preprocessors=[remove_no_qa, _preprocess_scifi, extract_individual_vqa, split],
        inf_preprocessor=[
            remove_no_qa,
            _preprocess_scifi,
            functools.partial(flatten_parts, parts=["question", "answer"]),
            extract_individual_vqa,
        ],
        style=name,
    )
    # QA with `extract_scifi_qa_exp`
    add_task(
        f"{name}_qa_exp",
        source=src,
        preprocessors=[
            remove_no_qa,
            _preprocess_scifi,
            extract_scifi_qa_exp,
            extract_individual_vqa,
        ],
        inf_preprocessor=[
            remove_no_qa,
            _preprocess_scifi,
            extract_scifi_qa_exp,
            functools.partial(flatten_parts, parts=["question", "answer"]),
            extract_individual_vqa,
        ],
        style=f"{name}_qa_exp",
    )
    # As above, with `split` appended for training
    add_task(
        f"{name}_qa_exp_split",
        source=src,
        preprocessors=[
            remove_no_qa,
            _preprocess_scifi,
            extract_scifi_qa_exp,
            extract_individual_vqa,
            split,
        ],
        inf_preprocessor=[
            remove_no_qa,
            _preprocess_scifi,
            extract_scifi_qa_exp,
            functools.partial(flatten_parts, parts=["question", "answer"]),
            extract_individual_vqa,
        ],
        style=f"{name}_qa_exp",
    )
    # Uses `scifi_explanation_only`
    add_task(
        f"{name}_exp",
        source=src,
        preprocessors=[
            remove_no_qa,
            _preprocess_scifi,
            scifi_explanation_only,
            extract_individual_vqa,
            split,
        ],
        style=f"{name}_exp",
    )
    # Demo variant; the style is shared across all sci-fi sources
    add_task(
        f"{name}_demo",
        source=src,
        preprocessors=[
            remove_no_qa,
            _preprocess_scifi,
            extract_scifi_qa_demo,
            extract_individual_vqa,
            split,
        ],
        style="scifi_demo",
    )
|
|
|
|
|
# ChartQA with the "scifi_charts_qa_exp" style
add_task(
    "chart_qa_scifi",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        extract_individual_vqa,
    ],
    style="scifi_charts_qa_exp",
)


# ChartQA with the `chartqa_prompting` step
add_task(
    "chart_qa_prompting",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        chartqa_prompting,
        extract_individual_vqa,
    ],
    style="chart_qa",
)


# ChartQA with the `chartqa_explanation` step
add_task(
    "chart_qa_prompting_explanation",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        chartqa_explanation,
        extract_individual_vqa,
    ],
    style="chart_qa",
)


# Karpathy COCO captions without the `flatten_parts` step used by "coco_captioning_karpathy"
add_task(
    "coco_captioning_karpathy_multi",
    source=seqio.TfdsDataSource(
        tfds_name="coco_captioning_karpathy:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[rename(text="captions")],
    inf_preprocessor=[add_coco_url],
    style="coco_captioning",
)


# COCO 2017 captions followed by `join_captions`
add_task(
    "coco_caption_2017_grouped",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            rekey, key_map={
                "image/filename": ["image/filename"],
                "image": ["image"],
                "text": ["captions", "text"]
            }),
        join_captions
    ],
    style="coco_captioning_multiple",
)


# LLaVA pretraining data, read from a GCS bucket rather than MULTITASK_TFDS_DATA_DIR;
# the first 4096 training examples are held out as validation
add_task(
    "llava_pretrain",
    source=seqio.TfdsDataSource(
        tfds_name="llava_pretrain:1.0.0",
        tfds_data_dir="gs://mm-olmo-datasets/",
        splits=dict(
            train="train[4096:]",
            validation="train[:4096]"
        )
    ),
    preprocessors=[extract_llava],
    style="web_caption"
)


# Inference-only image set with no task preprocessors
add_task(
    "rohun_images",
    source=seqio.TfdsDataSource(
        tfds_name="rohun_images:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[],
    style="long_caption",
    inf_only=True
)


# Inference-only dense captioning eval; the dataset's "train" split is exposed as "validation"
add_task(
    "dense_caption_eval",
    source=seqio.TfdsDataSource(
        tfds_name="dense_captioning_eval:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(validation="train")
    ),
    preprocessors=[],
    style="long_caption",
    inf_only=True
)


# Debug variant of "dense_caption_eval" that keeps only the example with one specific URL
add_task(
    "dense_caption_eval_dbg",
    source=seqio.TfdsDataSource(
        tfds_name="dense_captioning_eval:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(validation="train")
    ),
    preprocessors=[
        lambda ds: ds.filter(lambda x: x["url"] == "https://explore-multimodal-datasets.s3.us-west-2.amazonaws.com/eval-set/v0/eval-set/a211be07e2c9c722ef75093026a608856bd07ad935ebdedea6f2944b1f2d2b0e.jpg")
    ],
    style="long_caption",
    inf_only=True
)


# Same source as "dense_caption_eval", passed through `select_dense_caption_sample`
add_task(
    "dense_caption_sample",
    source=seqio.TfdsDataSource(
        tfds_name="dense_captioning_eval:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            validation="train"
        )
    ),
    preprocessors=[select_dense_caption_sample],
    style="long_caption",
)


# Read from a GCS bucket; the first 5120 training examples are held out as validation
add_task(
    "cockatoo_1per_caption_287k",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_1per_caption_287k:1.0.5",
        tfds_data_dir="gs://mm-olmo-data/",
        splits=dict(
            train="train[5120:]",
            validation="train[:5120]"
        )
    ),
    preprocessors=[
        rename(text="caption"),
    ],
    style="long_caption"
)
|
|
|
|
|
def _filter_large_ratio(ds): |
|
return ds.filter( |
|
lambda x: tf.shape(x["image"])[0] > tf.shape(x["image"])[1]*2 |
|
) |
|
|
|
|
|
# Debug task: cockatoo 476k restricted by `_filter_large_ratio`, with both the
# caption and transcript styles
add_task(
    "cockatoo_dbg",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_476k:1.0.5",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[5120:]", "validation": "train[:5120]"},
    ),
    preprocessors=[_filter_large_ratio, extract_caption_and_transcript],
    style=["long_caption", "transcript"],
)
|
|
|
|
|
# Cockatoo caption/transcript datasets:
# (task suffix, TFDS name, train split, validation split).
# All sources are constructed before the loop starts.
for name, src in [
    (suffix, seqio.TfdsDataSource(
        tfds_name=tfds_name,
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": train_split, "validation": validation_split},
    ))
    for suffix, tfds_name, train_split, validation_split in (
        ("712k_sept6", "cockatoo_712k_sept6:1.0.5", "train[5120:]", "train[:5120]"),
        ("476k", "cockatoo_476k:1.0.5", "train[5120:]", "train[:5120]"),
        ("476k_gpt_captions", "cockatoo_476k_gpt_captions:1.0.5", "train[5120:]", "train[:5120]"),
        ("100k_of_476k_gpt_captions", "cockatoo_476k_gpt_captions:1.0.5",
         "train[5120:105120]", "train[:5120]"),
        ("200k_of_476k_gpt_captions", "cockatoo_476k_gpt_captions:1.0.5",
         "train[5120:205120]", "train[:5120]"),
        ("300k_of_476k_gpt_captions", "cockatoo_476k_gpt_captions:1.0.5",
         "train[5120:305120]", "train[:5120]"),
        ("400k_of_476k_gpt_captions", "cockatoo_476k_gpt_captions:1.0.5",
         "train[5120:405120]", "train[:5120]"),
        ("400k_of_476k", "cockatoo_476k:1.0.5", "train[5120:405120]", "train[:5120]"),
        ("300k_of_476k", "cockatoo_476k:1.0.5", "train[5120:305120]", "train[:5120]"),
        ("200k_of_476k", "cockatoo_476k:1.0.5", "train[5120:205120]", "train[:5120]"),
        ("100k_of_476k", "cockatoo_476k:1.0.5", "train[5120:105120]", "train[:5120]"),
        ("276k", "cockatoo:1.0.5", "train[5120:]", "train[:5120]"),
        ("180k", "cockatoo:1.0.3", "train[4096:]", "train[:4096]"),
        ("84k_claude_captions", "cockatoo_84k_claude_captions:1.0.0",
         "train[1000:]", "train[:1000]"),
    )
]:
    # Caption only
    add_task(
        f"cockatoo_{name}",
        source=src,
        preprocessors=[extract_caption],
        style="long_caption",
    )

    # Caption and transcript
    add_task(
        f"cockatoo_and_transcript_{name}",
        source=src,
        preprocessors=[extract_caption_and_transcript],
        style=["long_caption", "transcript"],
    )

    # NOTE(review): this list contains a `seqio.CacheDatasetPlaceholder()`, which
    # `TaskDatasetBuilder.get_dataset` would call like any other preprocessor --
    # confirm that this task can actually be built.
    add_task(
        f"cockatoo_and_transcript_stratified_{name}",
        source=src,
        preprocessors=[
            extract_caption_and_transcript,
            seqio.CacheDatasetPlaceholder(),
        ],
        style=["long_caption", "transcript"],
    )

    # Caption plus all transcripts
    add_task(
        f"cockatoo_and_all_transcripts_{name}",
        source=src,
        preprocessors=[extract_caption_and_all_transcripts],
        style=["long_caption", "transcript", "transcript", "transcript"],
    )

    # All transcripts, no caption
    add_task(
        f"cockatoo_all_transcripts_{name}",
        source=src,
        preprocessors=[extract_all_transcripts],
        style="transcript",
    )

    # Uses `extract_transcript`
    add_task(
        f"cockatoo_transcripts_{name}",
        source=src,
        preprocessors=[extract_transcript],
        style="transcript",
    )
|
|
|
|
|
# tf.Example schema shared by the LAION TFRecord sources below:
# two scalar string features, "image" and "text"
TFRECORD_IMAGE_TEXT_FEATURES = {
    'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
    'text':tf.io.FixedLenFeature(shape=(), dtype=tf.string),
}


# LAION-400M read directly from TFRecord files in a GCS bucket; only a "train" split is defined
add_task(
    "laion400m",
    source=seqio.TFExampleDataSource(
        split_to_filepattern={
            "train": os.path.join("gs://unified-io-2-us-east/", "pretrain-datasets", "laion400m", "1.0.0", "laion400m-train*"),
        },
        feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "image": ["image"],
            "text": ["text"]
        }),
    ],
    style="laion",
)


# LAION-2B (English) read directly from TFRecord files under MULTITASK_TFDS_DATA_DIR;
# only a "train" split is defined
add_task(
    "laion_2B",
    source=seqio.TFExampleDataSource(
        split_to_filepattern={
            "train": os.path.join(MULTITASK_TFDS_DATA_DIR, "laion2b_en", "1.0.0", "laion2b_en-train*"),
        },
        feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "image": ["image"],
            "text": ["text"]
        }),
    ],
    style="laion",
)
|
|
|
|
|
# Region captions from the "vg" dataset, passed through `region_captions_to_dense`
add_task(
    "region_caption_vg",
    source=seqio.TfdsDataSource(
        tfds_name="vg:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[region_captions_to_dense],
    style="region_captions",
)


# `max_words` is applied with max_words=400 before formatting
add_task(
    "pdfa_eng_wds",
    source=seqio.TfdsDataSource(
        tfds_name="pdfa_eng_wds:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(max_words, max_words=400),
        format_pdfa_eng_wds
    ],
    style="pdfa_eng_wds",
)


# No task-specific preprocessors
add_task(
    "idl_words",
    source=seqio.TfdsDataSource(
        tfds_name="idl_words:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[],
    style="idl_words",
)
|
|
|
|
|
|
|
# tf.Example schema for the Open Images v6 TFRecords: image bytes and id, plus
# variable-length detection/, vrd/, cap/ and seg/ features, each group with a scalar count
open_image_v6_keys_to_features = {
    'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
    'image_id': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
    'detection/label':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'detection/bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
    'detection/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    'vrd/sub_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'vrd/obj_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'vrd/sub_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
    'vrd/obj_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
    'vrd/relation': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'vrd/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    'cap/cap_caption': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'cap/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    'seg/masks': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'seg/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
    'seg/label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
    'seg/bbox': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
}


# Localized narratives read from the Open Images v6 TFRecords; only a "train" split is defined
add_task(
    "localized_narratives_v6",
    source=seqio.TFExampleDataSource(
        split_to_filepattern={
            "train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
        },
        feature_description=open_image_v6_keys_to_features,
    ),
    preprocessors=[extract_localized_narrative],
    style="localized_narratives",
)


# LVIS objects: `extract_lvis` followed by `region_captions_to_dense`
add_task(
    "lvis_objects",
    source=seqio.TfdsDataSource(
        tfds_name="lvis:1.2.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        extract_lvis,
        region_captions_to_dense,
    ],
    style="lvis_objects",
)
|
|
|
|
|
add_task( |
|
"open_images_with_objects", |
|
|
|
source=seqio.TFExampleDataSource( |
|
split_to_filepattern={ |
|
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"), |
|
}, |
|
feature_description=open_image_v6_keys_to_features, |
|
), |
|
preprocessors=[ |
|
extract_open_images_boxes, |
|
region_captions_to_dense, |
|
], |
|
style="visual_narratives_with_objects", |
|
) |
|
|
|
|
|
# The first 5120 training examples are held out as the validation split.
# Inference uses the same preprocessor with is_eval=True.
add_task(
    "cockatoo_with_acc_476k_gpt_captions",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_with_acc_476k_gpt_captions:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[5120:]",
            "validation": "train[:5120]",
        },
    ),
    preprocessors=[accuracy_conditioned_joint],
    inf_preprocessor=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
    style=None,
)


# Inference-only task; the TFDS "train" split is exposed as "validation".
add_task(
    "dense_caption_eval_with_acc",
    source=seqio.TfdsDataSource(
        tfds_name="dense_captioning_eval:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "train"},
    ),
    preprocessors=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
    style="long_caption",
    inf_only=True,
)
|
|
|
|
|
|
|
|
|
|
|
# "science_qa_img": multiple-choice QA; `image_only` runs first in the pipeline.
add_task(
    "science_qa_img",
    source=seqio.TfdsDataSource(
        tfds_name="science_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        image_only,
        rename(answer_idx="answer"),
        build_question_with_hint,
        format_multiple_choice_qa,
    ],
    style="science_qa",
)


# "tabwmp_da": backed by the "tab_mwp" TFDS dataset; its "dev" split serves as validation.
add_task(
    "tabwmp_da",
    source=seqio.TfdsDataSource(
        tfds_name="tab_mwp:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "dev",
            "test": "test",
        },
    ),
    preprocessors=[rename(text="answer")],
    style="tabwmp_da",
)
|
|
|
|
|
# FigureQA using the "train1" / "validation1" / "no_annot_test1" TFDS splits.
add_task(
    "figure_qa",
    source=seqio.TfdsDataSource(
        tfds_name="figure_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train1",
            "validation": "validation1",
            "test": "no_annot_test1",
        },
    ),
    preprocessors=[extract_figureqa, extract_individual_vqa],
    style="figure_qa",
)

# Same data as "figure_qa", with `convert_figureqa_answer` inserted in the pipeline.
add_task(
    "figure_qa_zero_shot",
    source=seqio.TfdsDataSource(
        tfds_name="figure_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train1",
            "validation": "validation1",
            "test": "no_annot_test1",
        },
    ),
    preprocessors=[
        extract_figureqa,
        convert_figureqa_answer,
        extract_individual_vqa,
    ],
    style="figure_qa",
)


# PlotQA; the inference pipeline additionally flattens questions/answer/question_id.
add_task(
    "plot_qa",
    source=seqio.TfdsDataSource(
        tfds_name="plot_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[extract_figureqa, extract_individual_vqa],
    inf_preprocessor=[
        extract_figureqa,
        functools.partial(flatten_parts, parts=["questions", "answer", "question_id"]),
        extract_individual_vqa,
    ],
    style="plot_qa",
)
|
|
|
|
|
# AI2 diagram tasks. All variants share the "ai2_diagram" style and the same
# `rename` step; they differ in TFDS dataset, splits, and formatting function.

# Original dataset: first 1024 training examples held out for validation.
add_task(
    "ai2_diagram",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[1024:]",
            "validation": "train[:1024]",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_multiple_choice_qa,
    ],
    style="ai2_diagram",
)


add_task(
    "ai2_diagram_v2",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d,
    ],
    style="ai2_diagram",
)


add_task(
    "ai2_diagram_v2_transparent",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_transparent:1.0.5",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d,
    ],
    style="ai2_diagram",
)


# Uses the "*_mix" splits of the mix_transparent dataset.
add_task(
    "ai2_diagram_v2_mix_transparent",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train_mix",
            "validation": "validation_mix",
            "test": "test_mix",
        },
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d,
    ],
    style="ai2_diagram",
)

# Same data as above, with `format_ai2d` called with variable_style=False.
add_task(
    "ai2_diagram_v2_mix_transparent_one_style",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train_mix",
            "validation": "validation_mix",
            "test": "test_mix",
        },
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        functools.partial(format_ai2d, variable_style=False),
    ],
    style="ai2_diagram",
)
|
|
|
|
|
# Referring-expression datasets: for each source, register a "<src>" task and a
# "<src>_pointing" task. Each entry pairs the dataset name with its test splits.
for src, test_sets in [
    ["refclef_unc", ["testA", "testB", "testC", "testAB", "testBC"]],
    ["refcoco_unc", ["testA", "testB"]],
    ["refcocoplus_unc", ["testA", "testB"]],
    ["refcocog_umd", ["test"]],
]:
    # Only sources with "coco" in their name get the COCO url preprocessor at inference.
    add_url = [add_coco_url] if "coco" in src else []

    # Test splits map to themselves; train/validation map to "train"/"val".
    splits = {
        **{x: x for x in test_sets},
        "train": "train",
        "validation": "val",
    }

    add_task(
        src,
        source=seqio.TfdsDataSource(
            tfds_name=f"{src}:1.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=splits,
        ),
        preprocessors=[refexp],
        inf_preprocessor=[
            *add_url,
            refexp_inf,
            # Two flatten passes: first over refexp and bbox together, then over refexp alone.
            functools.partial(flatten_parts, parts=["refexp", "metadata/bbox"]),
            functools.partial(flatten_parts, parts=["refexp"]),
        ],
        style="refexp",
        decode_image=True,
    )

    add_task(
        src + "_pointing",
        source=seqio.TfdsDataSource(
            tfds_name=f"{src}:1.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=splits,
        ),
        preprocessors=[refexp_pointing],
        inf_preprocessor=[
            *add_url,
            refexp_pointing_inf,
            functools.partial(
                flatten_parts,
                parts=["refexp", "metadata/bbox", "metadata/mask", "metadata/answer"],
            ),
            functools.partial(flatten_parts, parts=["refexp"]),
        ],
        style="refexp_pointing",
        decode_image=True,
    )
|
|
|
|
|
|
|
# "ai2_diagram_test" is registered with the same source, splits, preprocessors
# and style as "ai2_diagram".
add_task(
    "ai2_diagram_test",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[1024:]",
            "validation": "train[:1024]",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_multiple_choice_qa,
    ],
    style="ai2_diagram",
)


# GQA with is_balanced=True; training and inference use the same pipeline.
add_task(
    "gqa",
    source=seqio.TfdsDataSource(
        tfds_name="gqa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        functools.partial(format_gqa, is_balanced=True),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        functools.partial(format_gqa, is_balanced=True),
        extract_individual_vqa,
    ],
    style="gqa",
)


# Same as "gqa" but `format_gqa` is called with flatten=False.
add_task(
    "gqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="gqa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        functools.partial(format_gqa, is_balanced=True, flatten=False),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        functools.partial(format_gqa, is_balanced=True, flatten=False),
        extract_individual_vqa,
    ],
    style="gqa",
)
|
|
|
|
|
# TextVQA: fields are re-keyed before `extract_individual_vqa` runs.
add_task(
    "text_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="text_vqa:1.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            rekey,
            key_map={
                "image": ["image"],
                "questions": ["question"],
                "answers": ["answers"],
                "id": ["question_id"],
            },
        ),
        extract_individual_vqa,
    ],
    style="text_vqa",
)


# OK-VQA, backed by the "ok_vqa" TFDS dataset.
add_task(
    "okvqa",
    source=seqio.TfdsDataSource(
        tfds_name="ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(example_id="question_id"),
        add_coco_url,
        extract_individual_vqa,
    ],
    style="okvqa",
)
|
|
|
# A-OKVQA, direct-answer variant. The inference pipeline additionally applies
# `filter_difficult_direct_answer` and `add_coco_url`.
add_task(
    "a_okvqa_da",
    source=seqio.TfdsDataSource(
        tfds_name="a_ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(
            example_id="question_id",
            answers="direct_answers",
            **{"metadata/difficult_direct_answer": "difficult_direct_answer"},
        ),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        filter_difficult_direct_answer,
        rename(
            example_id="question_id",
            answers="direct_answers",
            **{"metadata/difficult_direct_answer": "difficult_direct_answer"},
        ),
        add_coco_url,
        extract_individual_vqa,
    ],
    style="a_okvqa_da",
)


# A-OKVQA, multiple-choice variant.
add_task(
    "a_okvqa_mc",
    source=seqio.TfdsDataSource(
        tfds_name="a_ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(
            example_id="question_id",
            **{"metadata/difficult_direct_answer": "difficult_direct_answer"},
            answer_idx="correct_choice_idx",
        ),
        add_coco_url,
        format_multiple_choice_qa,
    ],
    style="a_okvqa_mc",
)


# DVQA; validation reads the "val_easy" split, and inference adds `flatten_vqa`.
add_task(
    "dv_qa",
    source=seqio.TfdsDataSource(
        tfds_name="dv_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val_easy",
        },
    ),
    preprocessors=[
        extract_figureqa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        extract_figureqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="dv_qa",
)
|
|
|
|
|
@seqio.map_over_dataset
def add_image_question_example_id(ex):
    """Add a "metadata/example_id" derived from the example's question and image.

    The "question" and "image" string fields are joined with a blank line
    between them, and the result is hashed into one of 2**30 buckets.

    Returns:
        The same example dict, with "metadata/example_id" set.
    """
    question_and_image = tf.strings.join(
        [ex["question"], ex["image"]], separator="\n\n")
    ex["metadata/example_id"] = tf.strings.to_hash_bucket(question_and_image, 2**30)
    return ex
|
|
|
|
|
# ChartQA variants. All read "chart_qa:1.0.2" with the same train/val/test
# splits; they differ in preprocessing and style.

# Base task: adds a hashed example id before extracting the QA pair.
add_task(
    "chart_qa",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        add_image_question_example_id,
        extract_individual_vqa,
    ],
    style="chart_qa",
)


# Registered under the "scifi_charts_qa_exp" style.
add_task(
    "chart_qa_ex",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        extract_individual_vqa,
    ],
    style="scifi_charts_qa_exp",
)


# Applies `reweight_chartqa` with separate weights for human and augmented examples.
add_task(
    "chart_qa_weighted",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        extract_individual_vqa,
        functools.partial(
            reweight_chartqa,
            human=2*20901/(20901+7398),
            aug=2*7398/(20901+7398),
        ),
    ],
    style="chart_qa",
)


# Runs `filter_human` after adding the hashed example id.
add_task(
    "chart_qa_human",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(question="query", answer="label"),
        add_image_question_example_id,
        filter_human,
        extract_individual_vqa,
    ],
    style="chart_qa",
)


# Runs `filter_aug` before extracting the QA pair.
add_task(
    "chart_qa_aug",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(question="query", answer="label"),
        filter_aug,
        extract_individual_vqa,
    ],
    style="chart_qa",
)
|
|
|
|
|
# Document / OCR / scene-text / counting / infographic QA tasks.

add_task(
    "doc_qa",
    source=seqio.TfdsDataSource(
        tfds_name="doc_qa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[fix_doqa_url, extract_individual_vqa],
    style="doc_qa",
)


# Task name is "ocr_qa"; the TFDS dataset and style are both named "ocr_vqa".
add_task(
    "ocr_qa",
    source=seqio.TfdsDataSource(
        tfds_name="ocr_vqa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[extract_individual_vqa],
    inf_preprocessor=[flatten_vqa, extract_individual_vqa],
    style="ocr_vqa",
)


# First 1024 training examples are held out as validation.
add_task(
    "st_qa",
    source=seqio.TfdsDataSource(
        tfds_name="st_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[1024:]",
            "validation": "train[:1024]",
            "test": "test",
        },
    ),
    preprocessors=[extract_individual_vqa],
    inf_preprocessor=[extract_individual_vqa],
    style="st_qa",
)


# The TFDS "test" split is exposed as "validation".
add_task(
    "tally_qa",
    source=seqio.TfdsDataSource(
        tfds_name="tally_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "test",
        },
    ),
    preprocessors=[
        extract_tally_qa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        extract_tally_qa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="tally_qa",
)


add_task(
    "info_qa",
    source=seqio.TfdsDataSource(
        tfds_name="info_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[extract_individual_vqa],
    style="info_qa",
)
|
|
|
# Android control: one base task plus one task per mode.
add_task(
    "android_control",
    source=seqio.TfdsDataSource(
        tfds_name="android_control:2.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train",
            "validation": "val",
            "test": "test",
        },
    ),
    preprocessors=[extract_android_control],
    style="android_control",
)

# Each mode gets its own task name; the mode is bound into the preprocessor.
for mode in ["ll", "hl", "hl_ll", "hl_cot"]:
    add_task(
        f"android_control_{mode}",
        source=seqio.TfdsDataSource(
            tfds_name="android_control:2.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={
                "train": "train",
                "validation": "val",
                "test": "test",
            },
        ),
        preprocessors=[functools.partial(extract_andriod_control_inf, mode=mode)],
        style="android_control",
    )
|
|
|
|
|
# Re-keys COCO examples for VQA: questions, answers and ids are read from under
# the "vqa" key; the image and image url keep their names.
map_coco_vqa = functools.partial(
    rekey,
    key_map={
        "image": ["image"],
        "questions": ["vqa", "questions"],
        "answers": ["vqa", "answers"],
        "id": ["vqa", "id"],
        "metadata/image_url": ["metadata/image_url"],
    },
)


# VQA over the "coco_all" dataset; questions are flattened during training too.
add_task(
    "coco_2017_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="vqa2",
)
|
|
|
|
|
# "cockatoo_qa": first 5120 training examples are held out as validation; no style is set.
add_task(
    "cockatoo_qa",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[5120:]",
            "validation": "train[:5120]",
        },
    ),
    preprocessors=[rename(text="answer")],
    style=None,
)


# synthetic_qa_v3 variants: same dataset and a 2048-example validation
# hold-out; they differ in style and, for the multi-turn task, one extra filter.
add_task(
    "synthetic_qa_v3",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[2048:]",
            "validation": "train[:2048]",
        },
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="synthetic_qa",
)


add_task(
    "synthetic_qa_v3_style_tag",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[2048:]",
            "validation": "train[:2048]",
        },
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="llm_qa",
)


add_task(
    "synthetic_qa_v3_as_user_qa",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[2048:]",
            "validation": "train[:2048]",
        },
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="user_qa",
)


# Adds `filter_single_turn` between extraction and message prefixing.
add_task(
    "synthetic_qa_v3_multi_turn",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[2048:]",
            "validation": "train[:2048]",
        },
    ),
    preprocessors=[
        extract_cockatoo_qa_v2,
        filter_single_turn,
        prefix_how_many_messages,
    ],
    style="synthetic_qa",
)
|
|
|
|
|
# Shard indices 0..17 of the named-entities QA data (stored as 18 TFDS datasets).
NE_SHARDS = list(range(18))

# One task per shard; the first 1024 examples of each shard are held out as
# validation, and errors are ignored for these tasks (ignore_errors=True).
for i in NE_SHARDS:
    add_task(
        f"named_entity{i}",
        source=seqio.TfdsDataSource(
            tfds_name=f"named_entities_qa_{i}_of_18:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={
                "train": "train[1024:]",
                "validation": "train[:1024]",
            },
        ),
        preprocessors=[
            filter_named_entity,
            extract_named_entity,
            extract_individual_vqa,
        ],
        inf_preprocessor=[
            filter_named_entity,
            extract_named_entity,
            flatten_vqa,
            extract_individual_vqa,
        ],
        style="named_entity",
        ignore_errors=True,
    )
|
|
|
|
|
# "user_qa": first 2048 training examples are held out as validation.
add_task(
    "user_qa",
    source=seqio.TfdsDataSource(
        tfds_name="user_qa:0.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train[2048:]",
            "validation": "train[:2048]",
        },
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="user_qa",
)

# Inference-only task; `extract_individual_vqa` is called with test=True.
add_task(
    "user_questions_for_elo",
    source=seqio.TfdsDataSource(
        tfds_name="user_questions_for_elo:0.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[functools.partial(extract_individual_vqa, test=True)],
    style="demo",
    inf_only=True,
)
|
|
|
|
|
def _filter_by_id(ds, prediction_file, max_seq_len):
    """Keep only examples whose saved prediction is at least `max_seq_len` tokens long.

    Args:
        ds: Dataset whose elements expose an "example_id" field.
        prediction_file: Path to a JSON file containing a list of records, each
            with a "prediction" string and an "example_id".
        max_seq_len: Minimum token count (inclusive) a prediction must reach for
            its example to be kept. Tokens are counted with the
            "hf-Qwen/Qwen2-7B" tokenizer.

    Returns:
        `ds` restricted to examples whose "example_id" equals the id of one of
        the long predictions, or an empty dataset if no prediction is long enough.
    """
    # Local import: the visible module header does not import `json` explicitly.
    import json

    with open(prediction_file) as f:
        predictions = json.load(f)

    tokenizer = build_tokenizer("hf-Qwen/Qwen2-7B")
    long_ids = [
        pred["example_id"]
        for pred in predictions
        if len(tokenizer.encode(pred["prediction"])) >= max_seq_len
    ]
    logging.info("Filtering for %d ids", len(long_ids))

    if not long_ids:
        # tf.constant([]) defaults to a float dtype, which could not be compared
        # against the example ids below; nothing matches, so return no examples.
        return ds.take(0)

    long_ids_tensor = tf.constant(long_ids)
    return ds.filter(
        lambda ex: tf.reduce_any(ex["example_id"] == long_ids_tensor))
|
|
|
|
|
|
|
# NOTE: "user_questions_for_elo" is already registered earlier in this file
# (just before `_filter_by_id`) with exactly these arguments. The second,
# identical registration that used to sit here only overwrote the same TASKS
# entry, so it has been removed.

# Subset of "user_questions_for_elo" restricted to examples whose saved model
# prediction is at least 230 tokens long (see `_filter_by_id`).
add_task(
    "user_questions_for_elo_long",
    source=seqio.TfdsDataSource(
        tfds_name="user_questions_for_elo:0.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            _filter_by_id,
            prediction_file="/weka/oe-training-default/chrisc/cockatoo/models/uber-model-v11/70b-335-30k-3.2-resume8k-noopt/predictions-ck20000-user_questions_for_elo-test/predictions.json",
            max_seq_len=230,
        ),
        functools.partial(extract_individual_vqa, test=True),
    ],
    inf_only=True,
    style="demo",
)
|
|
|
|
|
# VQA v2 tasks over the COCO datasets. The "_multi" variants skip `flatten_vqa`
# during training but keep it for inference.

add_task(
    "coco_2014_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="coco_2014_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="vqa2",
)


add_task(
    "coco_2014_vqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="coco_2014_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="vqa2",
)


add_task(
    "coco_2017_vqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa,
    ],
    style="vqa2",
)


# Inference-only test task; uses its own re-keying (no "metadata/image_url")
# and calls `extract_individual_vqa` with test=True.
add_task(
    "vqa_v2_test",
    source=seqio.TfdsDataSource(
        tfds_name="coco_test_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            rekey,
            key_map={
                "image": ["image"],
                "questions": ["vqa", "questions"],
                "answers": ["vqa", "answers"],
                "id": ["vqa", "id"],
            },
        ),
        flatten_vqa,
        functools.partial(extract_individual_vqa, test=True),
    ],
    style="vqa2",
    inf_only=True,
)
|
|
|
|
|
|
|
|
|
|
|
# Inference-only evaluation tasks.

add_task(
    "seed_bench_test",
    source=seqio.TfdsDataSource(
        tfds_name="seed_bench:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[format_multiple_choice_qa],
    style="a_okvqa_mc",
    inf_only=True,
)


add_task(
    "pope_test",
    source=seqio.TfdsDataSource(
        tfds_name="pope:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[add_coco_url, extract_individual_vqa],
    style="vqa2",
    inf_only=True,
)


# Data source for the "mme" TFDS dataset, kept as a module-level constant.
MME_SOURCE = seqio.TfdsDataSource(
    tfds_name="mme:1.0.0",
    tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
)


# Questions/answers are flattened first, then renamed, then extracted.
add_task(
    "mme_test",
    source=MME_SOURCE,
    preprocessors=[
        functools.partial(flatten_parts, parts=["questions", "answers"]),
        rename(question="questions", answer="answers"),
        extract_individual_vqa,
    ],
    style="vqa2",
    inf_only=True,
)
|
|
|
# RealWorldQA inference-only tasks. The task-level style is None; a style is
# chosen per question type through the `types` / `styles` / `default_style`
# arguments bound into `format_multiple_style_qa`.

add_task(
    "real_world_qa_test",
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            format_multiple_style_qa,
            types=['multiple_choice', 'short_answer'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
    inf_only=True,
)

# Same as "real_world_qa_test" with strip_instruction=True.
add_task(
    "real_world_qa_no_instruction",
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            format_multiple_style_qa,
            strip_instruction=True,
            types=['multiple_choice', 'short_answer'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
    inf_only=True,
)

# Debug variant: every question type is mapped to the "user_qa" style.
add_task(
    "real_world_qa_dbg",
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            format_multiple_style_qa,
            types=['multiple_choice', 'short_answer'],
            styles=['user_qa', 'user_qa'],
            default_style="user_qa",
        ),
    ],
    style=None,
    inf_only=True,
)
|
|
|
|
|
# MMMU: the TFDS "dev" split is exposed as "train".
add_task(
    "mmmu",
    source=seqio.TfdsDataSource(
        tfds_name="mmmu:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "dev"},
    ),
    preprocessors=[
        rename(img_type="metadata/img_type"),
        functools.partial(
            extract_mmmu,
            types=['multiple-choice', 'open'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
)


# Inference-only MMMU over the validation and test splits, using
# `extract_mmmu` with its default arguments.
add_task(
    "mmmu_test",
    source=seqio.TfdsDataSource(
        tfds_name="mmmu:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "validation",
            "test": "test",
        },
    ),
    preprocessors=[
        rename(img_type="metadata/img_type"),
        extract_mmmu,
    ],
    style=None,
    inf_only=True,
)

# One inference-only task per style, using `extract_mmmu_cot`; the "dev" split
# is exposed as well.
for style in ["vaia_qa", "vaia_qa_short_answer_first", "vqa_online"]:
    add_task(
        f"mmmu_test_{style}",
        source=seqio.TfdsDataSource(
            tfds_name="mmmu:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={
                "validation": "validation",
                "test": "test",
                "dev": "dev",
            },
        ),
        preprocessors=[
            rename(img_type="metadata/img_type"),
            extract_mmmu_cot,
        ],
        style=style,
        inf_only=True,
    )
|
|
|
|
|
# MathVista inference-only tasks; the "testmini" split is exposed as "validation".

add_task(
    "math_vista_test",
    source=seqio.TfdsDataSource(
        tfds_name="math_vista:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "testmini",
            "test": "test",
        },
    ),
    preprocessors=[
        functools.partial(
            rekey,
            key_map={
                "id": ["id"],
                "query": ["query"],
                "image": ["image"],
                "choices": ["choices"],
                "answer": ["answer"],
                "metadata/question_type": ["question_type"],
                "metadata/answer_type": ["answer_type"],
                "metadata/precision": ["precision"],
                "metadata/split": ["metadata/split"],
            },
        ),
        functools.partial(extract_math_vista, styles=['a_okvqa_mc', 'vqa2']),
    ],
    style=None,
    inf_only=True,
)


# Same as "math_vista_test" with `reformat_math_vista` inserted after re-keying.
add_task(
    "math_vista_v2",
    source=seqio.TfdsDataSource(
        tfds_name="math_vista:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "testmini",
            "test": "test",
        },
    ),
    preprocessors=[
        functools.partial(
            rekey,
            key_map={
                "id": ["id"],
                "query": ["query"],
                "image": ["image"],
                "choices": ["choices"],
                "answer": ["answer"],
                "metadata/question_type": ["question_type"],
                "metadata/answer_type": ["answer_type"],
                "metadata/precision": ["precision"],
                "metadata/split": ["metadata/split"],
            },
        ),
        reformat_math_vista,
        functools.partial(extract_math_vista, styles=['a_okvqa_mc', 'vqa2']),
    ],
    style=None,
    inf_only=True,
)
|
|
|
|
|
# Data source for MMBench; its "dev" split is exposed as "validation".
MM_BENCH_SRC = seqio.TfdsDataSource(
    tfds_name="mmbench:1.0.0",
    tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    splits={
        "validation": "dev",
        "test": "test",
    },
)

add_task(
    "mmbench_test",
    source=MM_BENCH_SRC,
    preprocessors=[format_mmbench],
    style="a_okvqa_mc",
    inf_only=True,
)


# Choices, answer index and answer type are flattened before formatting.
add_task(
    "sugar_crepe_test",
    source=seqio.TfdsDataSource(
        tfds_name="sugar_crepe:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        functools.partial(
            flatten_parts,
            parts=["choices", "answer_idx", "metadata/answer_type"],
        ),
        format_multiple_choice_qa,
    ],
    style="a_okvqa_mc",
    inf_only=True,
)


# BLINK: "prompt" becomes the question and "image_concat" the image; the
# source "question" field is kept under "metadata/question".
add_task(
    "blink_test",
    source=seqio.TfdsDataSource(
        tfds_name="blink:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            rekey,
            key_map={
                "id": ["id"],
                "question": ["prompt"],
                "image": ["image_concat"],
                "choices": ["choices"],
                "answer_idx": ["answer_idx"],
                "metadata/subtask": ["metadata/subtask"],
                "metadata/question": ["question"],
            },
        ),
        format_multiple_choice_qa,
        output_options,
    ],
    style="a_okvqa_mc",
    inf_only=True,
)
|
|
|
# "oscarbench_qa": only a validation split is exposed (TFDS "val").
add_task(
    "oscarbench_qa",
    source=seqio.TfdsDataSource(
        tfds_name="oscarbench_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "val"},
    ),
    preprocessors=[oscar_preprocessor],
    style="oscarbench_qa",
)

# CharXiv tasks: same dataset and splits, differing in preprocessor and style.
add_task(
    "charxiv",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "validation",
            "test": "test",
        },
    ),
    preprocessors=[charxiv_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        charxiv_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="charxiv",
)

add_task(
    "charxiv_descriptive",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "validation",
            "test": "test",
        },
    ),
    preprocessors=[charxiv_descriptive_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        charxiv_descriptive_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="charxiv_descriptive",
)

# No separate inference pipeline is given for the reasoning variant.
add_task(
    "charxiv_reasoning",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "validation",
            "test": "test",
        },
    ),
    preprocessors=[charxiv_reasoning_preprocessor, extract_individual_vqa],
    style="charxiv_reasoning",
)
|
|
|
# Table VQA datasets: the task name, TFDS dataset name and style are all the
# same string. Validation is the first 125 examples of the test split.
for tablevqa_name in ["fintabnetqa", "vwtq", "vwtq_syn"]:
    add_task(
        tablevqa_name,
        source=seqio.TfdsDataSource(
            tfds_name=f"{tablevqa_name}:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={
                "validation": "test[:125]",
                "test": "test",
            },
        ),
        preprocessors=[tablevqa_preprocessor, extract_individual_vqa],
        style=tablevqa_name,
    )

# "vtabfact" uses the same split layout with its own preprocessor.
add_task(
    "vtabfact",
    source=seqio.TfdsDataSource(
        tfds_name="vtabfact:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "test[:125]",
            "test": "test",
        },
    ),
    preprocessors=[vtabfact_preprocessor, extract_individual_vqa],
    style="vtabfact",
)

# Inference-only; "validation" and "test" both read the TFDS "test" split.
add_task(
    "nutrition_fact",
    source=seqio.TfdsDataSource(
        tfds_name="nutrition_fact:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "validation": "test",
            "test": "test",
        },
    ),
    preprocessors=[nutrition_fact_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        nutrition_fact_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="nutrition_fact",
    inf_only=True,
)
|
|
|
# Register a "<task>_demo" alias for selected tasks: a copy of the existing
# TaskSpec with only `style` switched to "demo" (the copy keeps the original
# `name` field).
for k in [
    "chart_qa",
    "info_qa",
    "doc_qa",
    "text_vqa",
    "coco_2014_vqa",
    "ai2_diagram_v2_mix_transparent",
    "chart_qa_human",
]:
    TASKS[f"{k}_demo"] = dataclasses.replace(TASKS[k], style="demo")
|
|