# MolmoE-1B-0924 — tasks.py
# Module that can be imported to register all tasks
import dataclasses
import functools
import logging
import os
from collections import OrderedDict
from typing import List, Dict, Any
import seqio
from seqio import dataset_providers
import tensorflow_datasets as tfds
from .data_utils import _strip_metadata, build_tokenizer
from .preprocesssors import *
from .preprocesssors import _preprocess_scifi
@dataclasses.dataclass
class TaskSpec:
    """Specification of a registered task: its data source, preprocessing and options."""
    name: str
    source: seqio.DataSourceInterface
    preprocessors: List
    style: str
    # Preprocessors to use instead of `preprocessors` when running inference (None -> same)
    inference_preprocessors: List = None
    inference_only: bool = False
    # Decode the image at the data source instead of after shuffling
    decode_image: bool = False
    # Index into `preprocessors` after which shuffling happens (None -> after all of them)
    shuffle_after: Optional[int] = None
    ignore_errors: bool = False
# Default TFDS data dir for the multitask datasets registered below
MULTITASK_TFDS_DATA_DIR = "/weka/oe-training-default/mm-olmo/tensorflow_datasets"

# Global registry mapping task name -> TaskSpec, populated via `add_task`
TASKS: Dict[str, TaskSpec] = {}
def add_task(
    name,
    source: seqio.DataSourceInterface,
    preprocessors: List,
    style: str,
    inf_preprocessor=None,
    inf_only=False,
    decode_image=False,
    shuffle_after=None,
    ignore_errors=False
):
    """Register a task under `name` in the global `TASKS` registry."""
    spec = TaskSpec(
        name=name,
        source=source,
        preprocessors=preprocessors,
        style=style,
        inference_preprocessors=inf_preprocessor,
        inference_only=inf_only,
        decode_image=decode_image,
        shuffle_after=shuffle_after,
        ignore_errors=ignore_errors,
    )
    TASKS[name] = spec
@seqio.map_over_dataset
def add_image_size(ex):
    """Decode the image if it is still a serialized string and record its size.

    Sets `metadata/image_size` to `[width, height]` and returns the example.
    """
    if ex["image"].dtype == tf.string:
        ex["image"] = tf.image.decode_image(ex['image'], channels=3, expand_animations=False)
    img_h = tf.shape(ex["image"])[0]
    img_w = tf.shape(ex["image"])[1]
    ex["metadata/image_size"] = [img_w, img_h]
    # FIX: a function mapped over the dataset must return the example; the
    # original implicitly returned None, which breaks `seqio.map_over_dataset`.
    return ex
@dataclasses.dataclass
class TaskDatasetBuilder:
    """tf.data.Dataset builder for task after shuffling, sharding, and initial model pre-processing
    have been applied"""
    # This class is a simplified and customized version of seqio.Task
    #
    # The main differences are:
    # 1: Does not prefetch by default, which wastes a small amount of RAM if we are using the
    #    dataset in a mixture which can just have its own top-level prefetch
    # 2: Reduce threshold for memory caching which is way too high for image datasets by default
    # 3: Can customize when shuffling occurs to help minimizes RAM usage, in general shuffling
    #    should happen before building image crops and tokenization so the shuffle and
    #    dataset checkpoint take less memory
    # 4: Don't decoding images until after shuffling for the same reason
    # 5: Support splitting with tfds.map_split so we never have to fall back to example sharding
    #    not default at the moment since its not well tested
    # 6: Removes caching/output feature spec stuff from seqio that we don't need
    name: str
    source: Any
    preprocessors: List
    # Whether to keep `metadata/*` fields in the examples
    keep_metadata: bool
    # Number of preprocessors applied before the shuffle step
    shuffle_after: int
    sharding: str = "tfds_split"
    decode_image: bool = False
    ignore_errors: bool = False

    def get_dataset(
        self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
        sequence_length: Optional[Mapping[str, int]] = None,
        split: str = tfds.Split.TRAIN,
        shuffle: bool = True,
        shuffle_buffer_size: Optional[int] = 1000,
        seed: Optional[int] = None,
        shard_info: Optional[seqio.ShardInfo] = None,
        num_epochs: Optional[int] = 1,
        try_in_mem_cache: bool = True,
        trim_output_features: bool = True
    ) -> tf.data.Dataset:
        """Build the tf.data.Dataset for `split`, with sharding/shuffling/preprocessing applied."""
        source = self.source
        if self.sharding == "seqio":
            # seqio-style sharding: shard at the data-source level when it has
            # enough file shards, otherwise fall back to example sharding.
            if source.supports_arbitrary_sharding:
                shard_data_source = True
            elif shard_info:
                # Whether we should shard at source or on the examples from the source.
                shard_data_source = (
                    len(source.list_shards(split=split)) >= shard_info.num_shards
                )
                logging.info(
                    "Sharding at the %s: %d of %d",
                    "data source" if shard_data_source else "examples",
                    shard_info.index + 1,
                    shard_info.num_shards,
                )
            else:
                # Call get_dataset on the source without a shard_info.
                shard_data_source = True
                shard_info = None
            # Skip image decoding here so shuffling operates on the raw bytes
            if "image" in source.tfds_dataset.info.features:
                if not self.decode_image:
                    source.tfds_dataset._decoders = dict(image=tfds.decode.SkipDecoding())
            if shard_data_source:
                ds = source.get_dataset(
                    split=split, shuffle=shuffle, seed=seed, shard_info=shard_info)
            else:
                ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
                ds = ds.shard(shard_info.num_shards, shard_info.index)
        elif self.sharding == "tfds_split":
            # Shard with `tfds.even_splits`, which is seems to be recommended for mult-host training
            # https://github.com/tensorflow/datasets/blob/master/docs/splits.md#tfdseven_splits--multi-host-training
            assert isinstance(self.source, seqio.TfdsDataSource)
            loader: seqio.LazyTfdsLoader = self.source.tfds_dataset
            dataset, data_dir = loader.get_split_params(split)
            shard_split = loader._map_split(split)
            if shard_info and shard_info.num_shards > 1:
                shard_split = tfds.even_splits(shard_split, n=shard_info.num_shards, drop_remainder=False)[shard_info.index]
            else:
                shard_split = shard_split
            read_config = loader.read_config
            read_config.shuffle_seed = seed
            read_config.skip_prefetch = True
            read_config.input_context = None
            # Don't decode images until after shuffling to save RAM
            if "image" in loader.info.features:
                decoders = dict(image=tfds.decode.SkipDecoding())
            else:
                decoders = None
            ds = tfds.load(
                dataset,
                split=shard_split,
                data_dir=data_dir,
                shuffle_files=shuffle,
                download=True,
                try_gcs=True,
                read_config=read_config,
                decoders=decoders
            )
        else:
            raise NotImplementedError(self.sharding)
        num_shards = shard_info.num_shards if shard_info else 1
        # Cache small datasets in memory (threshold is per-shard examples)
        if try_in_mem_cache and (
            source.num_input_examples(split)
            and source.num_input_examples(split)
            < 10000 * num_shards
        ):
            logging.info(f"Automatically caching small dataset in memory: {self.name}:{split}")
            ds = ds.cache()
        # We repeat before calling any (potentially) stochastic
        # preprocessors in order to take new samples each epoch.
        if num_epochs != 1:
            ds = ds.repeat(num_epochs)
        preprocessors = [
            seqio.add_kwargs_to_transform(
                _fn,
                sequence_length=sequence_length,
                output_features=None,
            ) for _fn in self.preprocessors
        ]
        # NOTE(review): nesting under the seed manager reconstructed — confirm
        # which statements are inside the `with` in the original file.
        with seqio.utils.map_seed_manager(seed):
            # Preprocessors that run before the shuffle (cheap transforms)
            for fn in preprocessors[:self.shuffle_after]:
                ds = fn(ds)
            # Strip metadata before shuffling if possible so its doesn't waste space
            if not self.keep_metadata:
                ds = _strip_metadata(ds)
            if shuffle:
                if shuffle_buffer_size is None:
                    raise ValueError("Shuffle is true, but shuffle_buffer_size is None")
                else:
                    ds = ds.shuffle(shuffle_buffer_size, seed=seed)
            # Remaining (typically expensive) preprocessors run post-shuffle
            for fn in preprocessors[self.shuffle_after:]:
                ds = fn(ds)
        if self.ignore_errors:
            ds = ds.ignore_errors(log_warning=True)
        if trim_output_features:
            ds = seqio.trim_dataset(ds, sequence_length, sequence_length)
        return ds
def get_task(preprocessor, name, is_training, for_inference,
             include_metadata=None, style_override=None) -> TaskDatasetBuilder:
    """Get a builder for task `name` that is pre-processed by `preprocessor`

    Args:
        preprocessor: object providing `get_preprocessor(...)`, or None for no
            model-specific preprocessing
        name: registered task name (key into `TASKS`)
        is_training: whether the data is used for training
        for_inference: build the inference variant; None -> use the task default
        include_metadata: keep metadata fields; None -> True iff `for_inference`
        style_override: optional style to use instead of the task's own
    """
    task_spec = TASKS[name]
    if for_inference is None:
        for_inference = task_spec.inference_only
    elif task_spec.inference_only and not for_inference:
        # FIX: corrected message typo ("Inference=only" -> "Inference-only")
        raise ValueError(f"Inference-only task {task_spec.name} can only be used in inference mode")
    if include_metadata is None:
        include_metadata = for_inference
    if preprocessor is not None:
        style = style_override if style_override else task_spec.style
        preprocessor = preprocessor.get_preprocessor(
            is_training, for_inference, style, include_metadata)
        preprocessor = [preprocessor]
    else:
        preprocessor = []
    task_preprocessors = task_spec.preprocessors
    if for_inference and task_spec.inference_preprocessors is not None:
        task_preprocessors = task_spec.inference_preprocessors
    if isinstance(task_spec.source, seqio.TfdsDataSource):
        from seqio.utils import _TFDS_DATA_DIR_OVERRIDE
        if _TFDS_DATA_DIR_OVERRIDE:
            # Stop annoying override warnings flooding the log
            task_spec.source.tfds_dataset._data_dir = None
    return TaskDatasetBuilder(
        task_spec.name,
        task_spec.source,
        task_preprocessors + preprocessor,
        keep_metadata=include_metadata,
        # FIX: use `is not None` so an explicit shuffle_after=0 (e.g. the
        # "clocks" task) is honored; the old truthiness test silently fell
        # back to shuffling after all preprocessors.
        shuffle_after=(task_spec.shuffle_after if task_spec.shuffle_after is not None
                       else len(task_spec.preprocessors)),
        sharding="seqio",
        decode_image=task_spec.decode_image,
        ignore_errors=task_spec.ignore_errors,
    )
add_task(
"coco_caption_2017",
source=seqio.TfdsDataSource(
tfds_name="coco_all:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
functools.partial(rekey, key_map={
"image/filename": ["image/filename"],
"image": ["image"],
"text": ["captions", "text"]
}),
functools.partial(flatten_parts, parts=["text"]),
],
inf_preprocessor=[
functools.partial(rekey, key_map={
"image/filename": ["image/filename"],
"image": ["image"],
"text": ["captions", "text"]
})
],
style="coco_captioning",
)
add_task(
"coco_captioning_karpathy",
source=seqio.TfdsDataSource(
tfds_name="coco_captioning_karpathy:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "val", "test": "test"}
),
preprocessors=[
rename(text="captions"),
functools.partial(flatten_parts, parts=["text"]),
],
inf_preprocessor=[add_coco_url],
style="coco_captioning",
)
add_task(
"synth_counting",
source=seqio.TfdsDataSource(
tfds_name="synth_counting:0.0.3",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[5120:]", "validation": "train[:5120]"}
),
preprocessors=[synth_count_preprocessor],
inf_preprocessor=[synth_count_inf_preprocessor],
style="synth_counting",
)
add_task(
"khan_academy",
source=seqio.TfdsDataSource(
tfds_name="khan_academy:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[1024:]", "validation": "train[:1024]"}
),
preprocessors=[extract_khan_academy],
style="khan_academy",
)
for name, src in [
("vaia_qa_latex_image_math_subset", seqio.TfdsDataSource(
tfds_name=f"vaia_qa_latex_image_short_answer:0.1.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "validation"}
)),
("vaia_qa_latex_image_all", seqio.TfdsDataSource(
tfds_name=f"vaia_qa_latex_image_short_answer:0.1.3",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "validation"}
)),
]:
add_task(
f"{name}_short_answer",
source=src,
preprocessors=[
remove_is_long,
remove_has_multiple_parts,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
],
style="vaia_qa",
)
add_task(
f"{name}_short_answer_first",
source=src,
preprocessors=[
remove_is_long,
remove_has_multiple_parts,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
],
style="vaia_qa_short_answer_first",
)
add_task(
f"{name}_mc_only_short_answer",
source=src,
preprocessors=[
remove_is_long,
remove_has_multiple_parts,
filter_mc,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
],
style="vaia_qa_short_answer",
)
add_task(
f"{name}_mc_only_short_answer_first",
source=src,
preprocessors=[
remove_is_long,
remove_has_multiple_parts,
filter_mc,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
],
style="vaia_qa_short_answer_first",
)
add_task(
f"{name}_image_only_short_answer",
source=src,
preprocessors=[
image_only,
remove_is_long,
remove_has_multiple_parts,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True),
],
style="vaia_qa_short_answer",
)
add_task(
f"{name}_image_only_short_answer_first",
source=src,
preprocessors=[
image_only,
remove_is_long,
remove_has_multiple_parts,
functools.partial(extract_vaia_qa_latex_image, add_short_answer=True, set_short_answer_first=True),
],
style="vaia_qa_short_answer_first",
)
add_task(
"vqa_online",
source=seqio.TfdsDataSource(
tfds_name="vqa_online:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "validation", "test": "validation"}
),
preprocessors=[
build_question_with_context,
extract_vqa_online,
],
style="vqa_online",
)
add_task(
"vqa_online_gpt_longQ_longA",
source=seqio.TfdsDataSource(
tfds_name="vqa_online_gpt_parsed:1.1.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "validation", "test": "validation"}
),
preprocessors=[
rename(question="question_long", answer="answer_long"),
extract_vqa_online,
],
style="vqa_online",
)
add_task(
"famous_birthdays",
source=seqio.TfdsDataSource(
tfds_name="famous_birth_days:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[5120:]", "validation": "train[:5120]"}
),
preprocessors=[
famous_birthdays_preprocessor,
functools.partial(name_entity_augmentation, p_high_color=0.0),
],
style="famous_birthdays",
)
add_task(
"wiki_art",
source=seqio.TfdsDataSource(
tfds_name="wiki_art:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[5120:]", "validation": "train[:5120]"}
),
preprocessors=[name_entity_augmentation, wiki_art_preprocessor],
style="wiki_art",
)
add_task(
"wiki_art_no_aug",
source=seqio.TfdsDataSource(
tfds_name="wiki_art:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[5120:]", "validation": "train[:5120]"}
),
preprocessors=[wiki_art_preprocessor],
style="wiki_art",
)
add_task(
"atlas_obscura",
source=seqio.TfdsDataSource(
tfds_name="atlas_obscura:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[5120:]", "validation": "train[:5120]"}
),
preprocessors=[
atlas_obscura_preprocessor,
mild_color_aug_preprocessor
],
style="atlas_obscura",
)
add_task(
"clocks",
source=seqio.TfdsDataSource(
tfds_name="clocks:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
clocks_preprocessor,
clock_augmentation
],
style="clocks",
shuffle_after=0
)
add_task(
"count_bench",
source=seqio.TfdsDataSource(
tfds_name="count_bench:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
count_bench_preprocessor,
],
style="count_bench",
)
add_task(
"tulu_v2_sft",
source=seqio.TfdsDataSource(
tfds_name="allenai__tulu_v2_sft_mixture:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[tulu_preprocessor],
style="tulu_v2",
)
# Pointing / Point+Count datasets
for is_count in [True, False]:
if is_count:
task = "point_count"
else:
task = "pointing"
add_task(
task,
source=seqio.TfdsDataSource(
tfds_name="pointing:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "validation"}
),
preprocessors=[
filter_points,
functools.partial(pointing_preprocessor, with_count=is_count),
split
],
style=task,
)
add_task(
task + "_eval", # pointing validation set
source=seqio.TfdsDataSource(
tfds_name="pointing:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
filter_points,
functools.partial(pointing_preprocessor, with_count=is_count),
split
],
style=task,
)
add_task(
task + "_high_freq",
source=seqio.TfdsDataSource(
tfds_name="count_qa:0.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[2048:]",
validation="train[:2048]"
)
),
preprocessors=[
filter_points,
fix_count_qa, # Fix a tfrecord bug TODO fix the underlying records
functools.partial(pointing_preprocessor, with_count=is_count),
split,
],
style=task,
)
add_task(
"fast_flickr_count_qa_" + task,
source=seqio.TfdsDataSource(
tfds_name="fast_flickr_count_qa:1.0.4",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
functools.partial(count_qa_preprocessor, with_count=is_count),
],
inf_preprocessor=[
functools.partial(count_qa_preprocessor, with_count=is_count, for_inference=True),
],
style=task,
)
add_task(
"countbench_qa",
source=seqio.TfdsDataSource(
tfds_name="countbench_qa:1.2.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
inf_only=True,
preprocessors=[
count_qa_preprocessor_inf,
],
style="point_count",
)
add_task(
    "pointing_test",  # pointing set with segmentation ground truths
    source=seqio.TfdsDataSource(
        tfds_name="pointing:1.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        pointing_inf_preprocessor
    ],
    # FIX: was `style=task`, which depended on the value leaked from the
    # pointing/point_count loop above ("pointing" after its last iteration).
    # Made explicit; also dropped the pointless f-string prefix on the name.
    style="pointing",
    inf_only=True,
)
add_task(
"point_qa",
source=seqio.TfdsDataSource(
tfds_name="point_qa:0.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[512:]",
validation="train[:512]"
)
),
preprocessors=[extract_point_qa, split],
style="point_qa",
)
add_task(
"clocks_no_aug",
source=seqio.TfdsDataSource(
tfds_name="clocks:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
clocks_preprocessor
],
style="clocks",
)
add_task(
"clock_bench",
source=seqio.TfdsDataSource(
tfds_name="clock_bench:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
clock_bench_preprocessor
],
inf_only=True,
style="clocks",
)
add_task(
"wiki_data",
source=seqio.TfdsDataSource(
tfds_name="cockatoo_wiki:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
),
preprocessors=[extract_wiki_data],
style="wiki_data",
)
add_task(
"wiki_data_name",
source=seqio.TfdsDataSource(
tfds_name="cockatoo_wiki:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
),
preprocessors=[
extract_wiki_data_name,
mild_color_aug_preprocessor
],
style="wiki_data",
)
add_task(
"wiki_data_describe",
source=seqio.TfdsDataSource(
tfds_name="cockatoo_wiki:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[10240:]", "validation": "train[:5120]", "test": "train[5120:10240]"}
),
preprocessors=[extract_wiki_data_describe],
inf_only=True,
style="wiki_data",
)
# FIX: removed a second, byte-identical registration of "wiki_data_describe"
# that appeared here — it simply overwrote the entry registered above with
# the same spec and had no effect.
for name, src in [
("scifi_charts", seqio.TfdsDataSource(
tfds_name="sci_fi_charts:1.0.6",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[1024:]", "validation": "train[:1024]"}
)),
("scifi_table", seqio.TfdsDataSource(
tfds_name="sci_fi_table:1.0.3",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[1024:]", "validation": "train[:1024]"}
)),
("scifi_document", seqio.TfdsDataSource(
tfds_name="sci_fi_document:1.0.3",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[1024:]", "validation": "train[:1024]"}
)),
("scifi_diagram", seqio.TfdsDataSource(
tfds_name="sci_fi_diagram:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[1024:]", "validation": "train[:1024]"}
)),
("scifi_natural", seqio.TfdsDataSource(
tfds_name="sci_fi_natural:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[128:]", "validation": "train[:128]"}
)),
("scifi_nutrition", seqio.TfdsDataSource(
tfds_name="sci_fi_nutrition:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train[128:]", "validation": "train[:128]"}
))
]:
add_task(
name + "_qa",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
extract_individual_vqa,
],
inf_preprocessor=[
remove_no_qa, _preprocess_scifi,
functools.partial(flatten_parts, parts=["question", "answer"]),
extract_individual_vqa,
],
style=name,
)
add_task(
name + "_qa_split",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
extract_individual_vqa,
split
],
inf_preprocessor=[
remove_no_qa, _preprocess_scifi,
functools.partial(flatten_parts, parts=["question", "answer"]),
extract_individual_vqa,
],
style=name,
)
add_task(
name + "_qa_exp",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
extract_scifi_qa_exp,
extract_individual_vqa,
],
inf_preprocessor=[
remove_no_qa, _preprocess_scifi,
extract_scifi_qa_exp,
functools.partial(flatten_parts, parts=["question", "answer"]),
extract_individual_vqa,
],
style=name + "_qa_exp",
)
add_task(
name + "_qa_exp_split",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
extract_scifi_qa_exp,
extract_individual_vqa,
split,
],
inf_preprocessor=[
remove_no_qa, _preprocess_scifi,
extract_scifi_qa_exp,
functools.partial(flatten_parts, parts=["question", "answer"]),
extract_individual_vqa,
],
style=name + "_qa_exp",
)
add_task(
name + "_exp",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
scifi_explanation_only,
extract_individual_vqa,
split
],
style=name + "_exp"
)
add_task(
name + "_demo",
source=src,
preprocessors=[
remove_no_qa,
_preprocess_scifi,
extract_scifi_qa_demo,
extract_individual_vqa,
split
],
style="scifi_demo"
)
add_task(
"chart_qa_scifi",
source=seqio.TfdsDataSource(
tfds_name="chart_qa:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "val", "test": "test"}
),
preprocessors=[
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
extract_individual_vqa,
],
style="scifi_charts_qa_exp",
)
add_task(
"chart_qa_prompting",
source=seqio.TfdsDataSource(
tfds_name="chart_qa:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "val", "test": "test"}
),
preprocessors=[
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
chartqa_prompting,
extract_individual_vqa,
],
style="chart_qa",
)
add_task(
"chart_qa_prompting_explanation",
source=seqio.TfdsDataSource(
tfds_name="chart_qa:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "val", "test": "test"}
),
preprocessors=[
rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
chartqa_explanation,
extract_individual_vqa,
],
style="chart_qa",
)
add_task(
"coco_captioning_karpathy_multi",
source=seqio.TfdsDataSource(
tfds_name="coco_captioning_karpathy:1.0.2",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits={"train": "train", "validation": "val", "test": "test"}
),
preprocessors=[rename(text="captions")],
inf_preprocessor=[add_coco_url],
style="coco_captioning",
)
add_task(
"coco_caption_2017_grouped",
source=seqio.TfdsDataSource(
tfds_name="coco_all:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
functools.partial(
rekey, key_map={
"image/filename": ["image/filename"],
"image": ["image"],
"text": ["captions", "text"]
}),
join_captions
],
style="coco_captioning_multiple",
)
add_task(
"llava_pretrain",
source=seqio.TfdsDataSource(
tfds_name="llava_pretrain:1.0.0",
tfds_data_dir="gs://mm-olmo-datasets/",
splits=dict(
train="train[4096:]",
validation="train[:4096]"
)
),
preprocessors=[extract_llava],
style="web_caption"
)
add_task(
"rohun_images",
source=seqio.TfdsDataSource(
tfds_name="rohun_images:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[],
style="long_caption",
inf_only=True
)
add_task(
"dense_caption_eval",
source=seqio.TfdsDataSource(
tfds_name="dense_captioning_eval:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(validation="train")
),
preprocessors=[],
style="long_caption",
inf_only=True
)
add_task(
"dense_caption_eval_dbg",
source=seqio.TfdsDataSource(
tfds_name="dense_captioning_eval:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(validation="train")
),
preprocessors=[
lambda ds: ds.filter(lambda x: x["url"] == "https://explore-multimodal-datasets.s3.us-west-2.amazonaws.com/eval-set/v0/eval-set/a211be07e2c9c722ef75093026a608856bd07ad935ebdedea6f2944b1f2d2b0e.jpg")
],
style="long_caption",
inf_only=True
)
add_task(
"dense_caption_sample",
source=seqio.TfdsDataSource(
tfds_name="dense_captioning_eval:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
validation="train"
)
),
preprocessors=[select_dense_caption_sample],
style="long_caption",
)
add_task(
"cockatoo_1per_caption_287k",
source=seqio.TfdsDataSource(
tfds_name="cockatoo_1per_caption_287k:1.0.5",
tfds_data_dir="gs://mm-olmo-data/",
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
),
preprocessors=[
rename(text="caption"),
],
style="long_caption"
)
def _filter_large_ratio(ds):
    """Keep only examples whose image height exceeds twice its width."""
    def _is_tall(ex):
        shape = tf.shape(ex["image"])
        return shape[0] > shape[1] * 2
    return ds.filter(_is_tall)
add_task(
f"cockatoo_dbg",
source= seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
)
,
preprocessors=[
_filter_large_ratio,
extract_caption_and_transcript
],
style=["long_caption", "transcript"]
)
for name, src in [
("712k_sept6", seqio.TfdsDataSource(
tfds_name="cockatoo_712k_sept6:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
)),
("476k", seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
)),
("476k_gpt_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
)),
("100k_of_476k_gpt_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:105120]",
validation="train[:5120]"
)
)),
("200k_of_476k_gpt_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:205120]",
validation="train[:5120]"
)
)),
("300k_of_476k_gpt_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:305120]",
validation="train[:5120]"
)
)),
("400k_of_476k_gpt_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_476k_gpt_captions:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:405120]",
validation="train[:5120]"
)
)),
("400k_of_476k", seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:405120]",
validation="train[:5120]"
)
)),
("300k_of_476k", seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:305120]",
validation="train[:5120]"
)
)),
("200k_of_476k", seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:205120]",
validation="train[:5120]"
)
)),
("100k_of_476k", seqio.TfdsDataSource(
tfds_name="cockatoo_476k:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:105120]",
validation="train[:5120]"
)
)),
("276k", seqio.TfdsDataSource(
tfds_name="cockatoo:1.0.5",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
)),
("180k", seqio.TfdsDataSource(
tfds_name="cockatoo:1.0.3",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[4096:]",
validation="train[:4096]"
)
)),
("84k_claude_captions", seqio.TfdsDataSource(
tfds_name="cockatoo_84k_claude_captions:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[1000:]",
validation="train[:1000]"
)
)),
]:
add_task(
f"cockatoo_{name}",
source=src,
preprocessors=[extract_caption],
style="long_caption"
)
add_task(
f"cockatoo_and_transcript_{name}",
source=src,
preprocessors=[extract_caption_and_transcript],
style=["long_caption", "transcript"]
)
add_task(
f"cockatoo_and_transcript_stratified_{name}",
source=src,
preprocessors=[
extract_caption_and_transcript,
# put this here to hack seqio into repeating the dataset after
# `extract_caption_and_transcript` which will properly stratify the transcripts
seqio.CacheDatasetPlaceholder(),
],
style=["long_caption", "transcript"]
)
add_task(
f"cockatoo_and_all_transcripts_{name}",
source=src,
preprocessors=[extract_caption_and_all_transcripts],
style=["long_caption", "transcript", "transcript", "transcript"]
)
add_task(
f"cockatoo_all_transcripts_{name}",
source=src,
preprocessors=[extract_all_transcripts],
style="transcript"
)
add_task(
f"cockatoo_transcripts_{name}",
source=src,
preprocessors=[extract_transcript],
style="transcript"
)
TFRECORD_IMAGE_TEXT_FEATURES = {
'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
'text':tf.io.FixedLenFeature(shape=(), dtype=tf.string),
}
add_task(
"laion400m",
source=seqio.TFExampleDataSource(
split_to_filepattern={
"train": os.path.join("gs://unified-io-2-us-east/", "pretrain-datasets", "laion400m", "1.0.0", "laion400m-train*"),
},
feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
),
preprocessors=[
functools.partial(rekey, key_map={
"image": ["image"],
"text": ["text"]
}),
],
style="laion",
)
add_task(
"laion_2B",
source=seqio.TFExampleDataSource(
split_to_filepattern={
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "laion2b_en", "1.0.0", "laion2b_en-train*"),
},
feature_description=TFRECORD_IMAGE_TEXT_FEATURES,
),
preprocessors=[
functools.partial(rekey, key_map={
"image": ["image"],
"text": ["text"]
}),
],
style="laion",
)
add_task(
"region_caption_vg",
source=seqio.TfdsDataSource(
tfds_name="vg:1.0.1",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[region_captions_to_dense],
style="region_captions",
)
add_task(
"pdfa_eng_wds",
source=seqio.TfdsDataSource(
tfds_name="pdfa_eng_wds:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
functools.partial(max_words, max_words=400),
format_pdfa_eng_wds
],
style="pdfa_eng_wds",
)
add_task(
"idl_words",
source=seqio.TfdsDataSource(
tfds_name="idl_words:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[],
style="idl_words",
)
open_image_v6_keys_to_features = {
'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
'image_id': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
'detection/label':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'detection/bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
'detection/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'vrd/sub_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'vrd/obj_label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'vrd/sub_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
'vrd/obj_bbox':tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
'vrd/relation': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'vrd/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'cap/cap_caption': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'cap/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'seg/masks': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'seg/num':tf.io.FixedLenFeature(shape=(), dtype=tf.int64),
'seg/label': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.string, allow_missing=True),
'seg/bbox': tf.io.FixedLenSequenceFeature(shape=(), dtype=tf.float32, allow_missing=True),
}
add_task(
"localized_narratives_v6",
source=seqio.TFExampleDataSource(
split_to_filepattern={
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
},
feature_description=open_image_v6_keys_to_features,
),
preprocessors=[extract_localized_narrative],
style="localized_narratives",
)
add_task(
"lvis_objects",
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
source=seqio.TfdsDataSource(
tfds_name="lvis:1.2.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
),
preprocessors=[
extract_lvis,
region_captions_to_dense,
],
style="lvis_objects",
)
add_task(
"open_images_with_objects",
# A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
source=seqio.TFExampleDataSource(
split_to_filepattern={
"train": os.path.join(MULTITASK_TFDS_DATA_DIR, "open_image_v6", "1.0.0", "open_image_v6-train*"),
},
feature_description=open_image_v6_keys_to_features,
),
preprocessors=[
extract_open_images_boxes,
region_captions_to_dense,
],
style="visual_narratives_with_objects",
)
add_task(
"cockatoo_with_acc_476k_gpt_captions",
source=seqio.TfdsDataSource(
tfds_name="cockatoo_with_acc_476k_gpt_captions:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(
train="train[5120:]",
validation="train[:5120]"
)
),
preprocessors=[accuracy_conditioned_joint],
inf_preprocessor=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
style=None
)
add_task(
"dense_caption_eval_with_acc",
source=seqio.TfdsDataSource(
tfds_name="dense_captioning_eval:1.0.0",
tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
splits=dict(validation="train")
),
preprocessors=[functools.partial(accuracy_conditioned_joint, is_eval=True)],
style="long_caption",
inf_only=True
)
# ************************
# VQA Datasets
# ************************
# ScienceQA restricted to examples that have an image.
add_task(
    "science_qa_img",
    source=seqio.TfdsDataSource(
        tfds_name="science_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        image_only,  # drop text-only examples
        rename(answer_idx="answer"),
        build_question_with_hint,  # fold the provided hint into the question
        format_multiple_choice_qa
    ],
    style="science_qa",
)
# TabMWP (math word problems over tables), direct-answer format.
add_task(
    "tabwmp_da",
    source=seqio.TfdsDataSource(
        tfds_name="tab_mwp:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "dev", "test": "test"}
    ),
    preprocessors=[
        rename(text="answer")
    ],
    style="tabwmp_da",
)
# FigureQA; the "no_annot_test1" test split presumably ships without
# answer annotations — confirm before scoring on it.
add_task(
    "figure_qa",
    source=seqio.TfdsDataSource(
        tfds_name="figure_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
    ),
    preprocessors=[extract_figureqa, extract_individual_vqa],
    style="figure_qa",
)
# Same dataset as "figure_qa" but with answers converted for zero-shot use.
add_task(
    "figure_qa_zero_shot",
    source=seqio.TfdsDataSource(
        tfds_name="figure_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train1", "validation": "validation1", "test": "no_annot_test1"}
    ),
    preprocessors=[extract_figureqa, convert_figureqa_answer, extract_individual_vqa],
    style="figure_qa",
)
# PlotQA; at inference time each (question, answer, id) triple is
# flattened into its own example.
add_task(
    "plot_qa",
    source=seqio.TfdsDataSource(
        tfds_name="plot_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[extract_figureqa, extract_individual_vqa],
    inf_preprocessor=[
        extract_figureqa,
        functools.partial(flatten_parts, parts=["questions", "answer", "question_id"]),
        extract_individual_vqa
    ],
    style="plot_qa",
)
# AI2D diagrams; 1024 train examples are held out for validation.
add_task(
    "ai2_diagram",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_multiple_choice_qa
    ],
    style="ai2_diagram",
)
# AI2D v2: same style as "ai2_diagram" but uses the ai2d-specific formatter.
add_task(
    "ai2_diagram_v2",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d
    ],
    style="ai2_diagram",
)
# AI2D v2 variant with transparent answer-label boxes rendered on the image.
add_task(
    "ai2_diagram_v2_transparent",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_transparent:1.0.5",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d
    ],
    style="ai2_diagram",
)
# ai2_diagram_v2 mixed with additional abc label questions with transparent box.
# Shares the same image split as ai2_diagram_v2.
add_task(
    "ai2_diagram_v2_mix_transparent",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train_mix",
            "validation": "validation_mix",
            "test": "test_mix",  # test should only use either transparent or opaque
            # "test": "test_opaque",
        }
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_ai2d
    ],
    style="ai2_diagram",
)
# Same mixed dataset, but formatted with a single fixed prompt style.
add_task(
    "ai2_diagram_v2_mix_transparent_one_style",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram_v2_mix_transparent:1.0.6",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={
            "train": "train_mix",
            "validation": "validation_mix",
            "test": "test_mix",  # test should only use either transparent or opaque
            # "test": "test_opaque",
        }
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        functools.partial(format_ai2d, variable_style=False),
    ],
    style="ai2_diagram",
)
# Register a referring-expression task and a pointing variant for each
# RefExp dataset; each dataset has its own set of named test splits.
for src, test_sets in [
    ["refclef_unc", ["testA", "testB", "testC", "testAB", "testBC"]],
    ["refcoco_unc", ["testA", "testB"]],
    ["refcocoplus_unc", ["testA", "testB"]],
    ["refcocog_umd", ["test"]],
]:
    # COCO-derived datasets get a preprocessor that attaches the image URL.
    if "coco" in src:
        add_url = [add_coco_url]
    else:
        add_url = []
    splits = {x: x for x in test_sets}
    splits.update({"train": "train", "validation": "val"})
    add_task(
        src,
        source=seqio.TfdsDataSource(
            tfds_name=f"{src}:1.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=splits
        ),
        preprocessors=[refexp],
        inf_preprocessor=add_url + [
            refexp_inf,
            # Flatten objects
            functools.partial(flatten_parts, parts=["refexp", "metadata/bbox"]),
            # Flatten expressions
            functools.partial(flatten_parts, parts=["refexp"])
        ],
        style="refexp",
        decode_image=True,
    )
    # Pointing variant: the target is a point/mask rather than a box caption.
    add_task(
        src + "_pointing",
        source=seqio.TfdsDataSource(
            tfds_name=f"{src}:1.0.2",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=splits
        ),
        preprocessors=[refexp_pointing],
        inf_preprocessor=add_url + [
            refexp_pointing_inf,
            functools.partial(flatten_parts, parts=["refexp", "metadata/bbox", "metadata/mask", "metadata/answer"]),
            functools.partial(flatten_parts, parts=["refexp"])
        ],
        decode_image=True,
        style="refexp_pointing",
    )
# FIXME: this registration is identical to "ai2_diagram" above (same source,
# splits, preprocessors, and style) — it looks like a leftover debugging copy
# and should probably be removed.
add_task(
    "ai2_diagram_test",
    source=seqio.TfdsDataSource(
        tfds_name="ai2_diagram:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
    ),
    preprocessors=[
        rename(choices="answer_texts", answer_idx="correct_answer"),
        format_multiple_choice_qa
    ],
    style="ai2_diagram",
)
# GQA (balanced subset), one question per example.
# NOTE(review): preprocessors and inf_preprocessor are identical here;
# presumably one list was meant to differ — confirm.
add_task(
    "gqa",
    source=seqio.TfdsDataSource(
        tfds_name="gqa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        functools.partial(format_gqa, is_balanced=True),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        functools.partial(format_gqa, is_balanced=True),
        extract_individual_vqa,
    ],
    style="gqa",
)
# GQA keeping all questions of an image together (flatten=False).
add_task(
    "gqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="gqa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        functools.partial(format_gqa, is_balanced=True, flatten=False),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        functools.partial(format_gqa, is_balanced=True, flatten=False),
        extract_individual_vqa,
    ],
    style="gqa",
)
# TextVQA: rekey the TFDS features into the generic VQA schema.
add_task(
    "text_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="text_vqa:1.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            rekey, key_map={
                "image": ["image"],
                "questions": ["question"],
                "answers": ["answers"],
                "id": ["question_id"]
            }),
        extract_individual_vqa,
    ],
    style="text_vqa",
)
# OK-VQA: knowledge-based VQA on COCO images.
add_task(
    "okvqa",
    source=seqio.TfdsDataSource(
        tfds_name="ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        rename(example_id="question_id"),
        add_coco_url,
        extract_individual_vqa,
    ],
    style="okvqa",
)
# A-OKVQA, direct-answer variant.  At inference time, examples flagged as
# "difficult direct answer" are filtered out first.
# NOTE(review): the train preprocessors omit add_coco_url while the
# inference ones include it — confirm this asymmetry is intended.
add_task(
    "a_okvqa_da",
    source=seqio.TfdsDataSource(
        tfds_name="a_ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(**{
            "example_id": "question_id",
            "answers": "direct_answers",
            "metadata/difficult_direct_answer": "difficult_direct_answer"
        }),
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        filter_difficult_direct_answer,
        rename(**{
            "example_id": "question_id",
            "answers": "direct_answers",
            "metadata/difficult_direct_answer": "difficult_direct_answer"
        }),
        add_coco_url,
        extract_individual_vqa,
    ],
    style="a_okvqa_da",
)
# A-OKVQA, multiple-choice variant.
add_task(
    "a_okvqa_mc",
    source=seqio.TfdsDataSource(
        tfds_name="a_ok_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(**{
            "example_id": "question_id",
            "metadata/difficult_direct_answer": "difficult_direct_answer",
            "answer_idx": "correct_choice_idx"
        }),
        add_coco_url,
        format_multiple_choice_qa,
    ],
    style="a_okvqa_mc",
)
# DVQA bar-chart QA; only the "val_easy" split is used for validation.
# Inference additionally flattens the per-image question list.
add_task(
    "dv_qa",
    source=seqio.TfdsDataSource(
        tfds_name="dv_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val_easy"}
    ),
    preprocessors=[
        extract_figureqa,
        extract_individual_vqa,
    ],
    inf_preprocessor=[
        extract_figureqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="dv_qa",
)
@seqio.map_over_dataset
def add_image_question_example_id(ex):
    """Attach a stable example id derived from the question text and image bytes.

    The question and image are joined and hashed into 2**30 buckets, so the
    same (question, image) pair always maps to the same id.
    """
    combined = tf.strings.join([ex["question"], ex["image"]], separator="\n\n")
    ex["metadata/example_id"] = tf.strings.to_hash_bucket(combined, 2 ** 30)
    return ex
# ChartQA (human + augmented questions), with a hashed example id so
# evaluation can de-duplicate repeated (question, image) pairs.
add_task(
    "chart_qa",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        add_image_question_example_id,
        extract_individual_vqa,
    ],
    style="chart_qa",
)
# ChartQA with an explanation-style prompt.
add_task(
    "chart_qa_ex",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        extract_individual_vqa,
    ],
    style="scifi_charts_qa_exp",
)
# ChartQA reweighted so human and augmented questions contribute equally;
# 20901/7398 are presumably the train counts of aug/human questions — confirm.
add_task(
    "chart_qa_weighted",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label", **{"metadata/is_human": "is_human"}),
        extract_individual_vqa,
        functools.partial(reweight_chartqa, human=2*20901/(20901+7398), aug=2*7398/(20901+7398)),
    ],
    style="chart_qa",
)
# ChartQA restricted to human-written questions.
add_task(
    "chart_qa_human",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label"),
        add_image_question_example_id,
        filter_human,
        extract_individual_vqa,
    ],
    style="chart_qa",
)
# ChartQA restricted to machine-augmented questions.
add_task(
    "chart_qa_aug",
    source=seqio.TfdsDataSource(
        tfds_name="chart_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[
        rename(question="query", answer="label"),
        filter_aug,
        extract_individual_vqa,
    ],
    style="chart_qa",
)
# DocVQA; fix_doqa_url repairs the stored image URLs before extraction.
add_task(
    "doc_qa",
    source=seqio.TfdsDataSource(
        tfds_name="doc_qa:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[fix_doqa_url, extract_individual_vqa],
    style="doc_qa",
)
# OCR-VQA; at inference the per-image question list is flattened.
add_task(
    "ocr_qa",
    source=seqio.TfdsDataSource(
        tfds_name="ocr_vqa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[extract_individual_vqa],
    inf_preprocessor=[flatten_vqa, extract_individual_vqa],
    style="ocr_vqa",
)
# ST-VQA; 1024 train examples are held out for validation.
add_task(
    "st_qa",
    source=seqio.TfdsDataSource(
        tfds_name="st_vqa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train[1024:]", "validation": "train[:1024]", "test": "test"}
    ),
    preprocessors=[extract_individual_vqa],
    inf_preprocessor=[extract_individual_vqa],
    style="st_qa",
)
# TallyQA counting questions; the TFDS test split doubles as validation.
add_task(
    "tally_qa",
    source=seqio.TfdsDataSource(
        tfds_name="tally_qa:1.0.2",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "test"}
    ),
    preprocessors=[
        extract_tally_qa,
        extract_individual_vqa
    ],
    inf_preprocessor=[
        extract_tally_qa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="tally_qa",
)
# InfographicVQA.
add_task(
    "info_qa",
    source=seqio.TfdsDataSource(
        tfds_name="info_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[extract_individual_vqa],
    style="info_qa",
)
# AndroidControl device-control episodes, default extraction.
add_task(
    "android_control",
    source=seqio.TfdsDataSource(
        tfds_name="android_control:2.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "train", "validation": "val", "test": "test"}
    ),
    preprocessors=[extract_android_control],
    style="android_control",
)
# Per-mode AndroidControl variants (low-level, high-level, combined, CoT).
for mode in ["ll", "hl", "hl_ll", "hl_cot"]:
    add_task(
        f"android_control_{mode}",
        source=seqio.TfdsDataSource(
            tfds_name="android_control:2.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={"train": "train", "validation": "val", "test": "test"}
        ),
        # "andriod" (sic) — the helper's name is misspelled where it is defined,
        # so it must be spelled that way here too.
        preprocessors=[functools.partial(extract_andriod_control_inf, mode=mode)],
        style="android_control",
    )
# Shared rekey mapping that pulls the VQA fields out of the nested "vqa"
# feature of the COCO TFDS datasets into the generic flat schema.
map_coco_vqa = functools.partial(rekey, key_map={
    "image": ["image"],
    "questions": ["vqa", "questions"],
    "answers": ["vqa", "answers"],
    "id": ["vqa", "id"],
    "metadata/image_url": ["metadata/image_url"],
})
# VQA v2 on COCO 2017, one question per example.
add_task(
    "coco_2017_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="vqa2",
)
# Cockatoo chat-style QA; first 5120 train examples held out for validation.
add_task(
    "cockatoo_qa",
    source=seqio.TfdsDataSource(
        tfds_name="cockatoo_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[5120:]",
            validation="train[:5120]"
        )
    ),
    preprocessors=[rename(text="answer")],
    style=None,
)
# Synthetic QA v3, registered under several style tags; all variants share
# the same source and preprocessing, differing only in `style`.
add_task(
    "synthetic_qa_v3",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[2048:]",
            validation="train[:2048]"
        )
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="synthetic_qa",
)
add_task(
    "synthetic_qa_v3_style_tag",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[2048:]",
            validation="train[:2048]"
        )
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="llm_qa",
)
add_task(
    "synthetic_qa_v3_as_user_qa",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[2048:]",
            validation="train[:2048]"
        )
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="user_qa",
)
# Multi-turn variant: single-turn conversations are filtered out.
add_task(
    "synthetic_qa_v3_multi_turn",
    source=seqio.TfdsDataSource(
        tfds_name="synthetic_qa_v3:0.0.4",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[2048:]",
            validation="train[:2048]"
        )
    ),
    preprocessors=[extract_cockatoo_qa_v2, filter_single_turn, prefix_how_many_messages],
    style="synthetic_qa",
)
# The named-entity QA data is split across 18 TFDS shards; register one
# task per shard.  range() replaces the original hand-written 0..17 literal,
# which was easy to get wrong when the shard count changes.
NE_SHARDS = list(range(18))
for i in NE_SHARDS:
    add_task(
        f"named_entity{i}",
        source=seqio.TfdsDataSource(
            tfds_name=f"named_entities_qa_{i}_of_18:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits=dict(
                train="train[1024:]",
                validation="train[:1024]"
            )
        ),
        preprocessors=[filter_named_entity, extract_named_entity, extract_individual_vqa],
        inf_preprocessor=[
            filter_named_entity,
            extract_named_entity,
            flatten_vqa,
            extract_individual_vqa
        ],
        style="named_entity",
        ignore_errors=True
    )
# Human-written user QA; first 2048 train examples held out for validation.
add_task(
    "user_qa",
    source=seqio.TfdsDataSource(
        tfds_name="user_qa:0.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits=dict(
            train="train[2048:]",
            validation="train[:2048]"
        )
    ),
    preprocessors=[extract_cockatoo_qa_v2, prefix_how_many_messages],
    style="user_qa",
)
# Inference-only user questions used for Elo-style model comparisons.
add_task(
    "user_questions_for_elo",
    source=seqio.TfdsDataSource(
        tfds_name="user_questions_for_elo:0.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[functools.partial(extract_individual_vqa, test=True)],
    inf_only=True,
    style="demo",
)
def _filter_by_id(ds, prediction_file, max_seq_len):
    """Keep only examples whose saved prediction is at least `max_seq_len` tokens.

    Loads a predictions JSON file, counts tokens of each prediction with the
    Qwen2 tokenizer, and filters `ds` down to the example ids whose prediction
    reached `max_seq_len` tokens.
    """
    with open(prediction_file) as f:
        predictions = json.load(f)
    tokenizer = build_tokenizer("hf-Qwen/Qwen2-7B")
    token_counts = [len(tokenizer.encode(p["prediction"])) for p in predictions]
    is_long = [
        p["example_id"]
        for p, n_tokens in zip(predictions, token_counts)
        if n_tokens >= max_seq_len
    ]
    is_long = tf.constant(is_long)
    logging.info(f"Filtering for {len(is_long)} ids")
    return ds.filter(lambda ex: tf.reduce_any(ex["example_id"] == is_long))
# NOTE(review): an exact duplicate of the "user_questions_for_elo"
# registration above (same source, preprocessors, style, and inf_only flag)
# used to live here.  Re-registering only overwrote the TASKS entry with an
# identical TaskSpec, so the redundant call has been removed.
# Elo questions restricted to ones whose previous model answer was long.
# NOTE(review): the prediction_file below is a hard-coded absolute path to a
# specific checkpoint's predictions on shared storage — this task only works
# where that file exists.
add_task(
    "user_questions_for_elo_long",
    source=seqio.TfdsDataSource(
        tfds_name="user_questions_for_elo:0.0.3",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(_filter_by_id, prediction_file="/weka/oe-training-default/chrisc/cockatoo/models/uber-model-v11/70b-335-30k-3.2-resume8k-noopt/predictions-ck20000-user_questions_for_elo-test/predictions.json", max_seq_len=230),
        functools.partial(extract_individual_vqa, test=True)
    ],
    inf_only=True,
    style="demo",
)
# VQA v2 on COCO 2014, one question per example (train and inference use
# the same flattening pipeline).
add_task(
    "coco_2014_vqa",
    source=seqio.TfdsDataSource(
        tfds_name="coco_2014_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="vqa2",
)
# Multi-question training variant: questions for an image stay grouped at
# train time but are flattened for inference.
add_task(
    "coco_2014_vqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="coco_2014_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        extract_individual_vqa
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="vqa2",
)
add_task(
    "coco_2017_vqa_multi",
    source=seqio.TfdsDataSource(
        tfds_name="coco_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        map_coco_vqa,
        extract_individual_vqa
    ],
    inf_preprocessor=[
        add_coco_url,
        map_coco_vqa,
        flatten_vqa,
        extract_individual_vqa
    ],
    style="vqa2",
)
# VQA v2 test server split; answers are unavailable, hence test=True.
add_task(
    "vqa_v2_test",
    source=seqio.TfdsDataSource(
        tfds_name="coco_test_all:1.0.1",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "image": ["image"],
            "questions": ["vqa", "questions"],
            "answers": ["vqa", "answers"],
            "id": ["vqa", "id"],
        }),
        flatten_vqa,
        functools.partial(extract_individual_vqa, test=True)
    ],
    style="vqa2",
    inf_only=True
)
# ************************
# Eval-only Datasets
# ************************
# SEED-Bench multiple-choice benchmark.
add_task(
    "seed_bench_test",
    source=seqio.TfdsDataSource(
        tfds_name="seed_bench:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        format_multiple_choice_qa,
    ],
    style="a_okvqa_mc",
    inf_only=True
)
# POPE object-hallucination probe over COCO images.
add_task(
    "pope_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="pope:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        extract_individual_vqa
    ],
    style="vqa2",
    inf_only=True
)
# MME benchmark; each example carries parallel question/answer lists that
# are flattened into individual QA pairs.
MME_SOURCE = seqio.TfdsDataSource(
    tfds_name="mme:1.0.0",
    tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
)
add_task(
    "mme_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=MME_SOURCE,
    preprocessors=[
        functools.partial(flatten_parts, parts=["questions", "answers"]),
        rename(question="questions", answer="answers"),
        extract_individual_vqa,
    ],
    style="vqa2",
    inf_only=True
)
# RealWorldQA; per-example question type selects the answer style.
add_task(
    "real_world_qa_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            format_multiple_style_qa,
            types=['multiple_choice', 'short_answer'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
    inf_only=True
)
# RealWorldQA with the dataset's embedded answer instructions stripped.
add_task(
    "real_world_qa_no_instruction",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        # A single partial carrying all arguments; the original wrapped one
        # partial inside another, which binds the same arguments but adds a
        # needless layer of indirection.
        functools.partial(
            format_multiple_style_qa,
            strip_instruction=True,
            types=['multiple_choice', 'short_answer'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
    inf_only=True
)
# RealWorldQA debugging variant: everything rendered in the "user_qa" style.
add_task(
    "real_world_qa_dbg",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="real_world_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(
            format_multiple_style_qa,
            types=['multiple_choice', 'short_answer'],
            styles=['user_qa', 'user_qa'],
            default_style="user_qa",
        ),
    ],
    style=None,
    inf_only=True
)
# MMMU few-shot dev split used for training-time exposure; question type
# (multiple-choice vs open) selects the answer style.
add_task(
    "mmmu",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="mmmu:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"train": "dev"},
    ),
    preprocessors=[
        rename(img_type="metadata/img_type"),
        functools.partial(
            extract_mmmu,
            types=['multiple-choice', 'open'],
            styles=['a_okvqa_mc', 'vqa2'],
            default_style="a_okvqa_mc",
        ),
    ],
    style=None,
)
# MMMU validation/test evaluation task.
add_task(
    "mmmu_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="mmmu:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "validation", "test": "test"},
    ),
    preprocessors=[
        rename(img_type="metadata/img_type"),
        extract_mmmu,
    ],
    style=None,
    inf_only=True
)
# MMMU chain-of-thought evaluation variants, one task per answer style.
for cot_style in ("vaia_qa", "vaia_qa_short_answer_first", "vqa_online"):
    add_task(
        f"mmmu_test_{cot_style}",
        # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
        source=seqio.TfdsDataSource(
            tfds_name="mmmu:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={"validation": "validation", "test": "test", "dev": "dev"},
        ),
        preprocessors=[
            rename(img_type="metadata/img_type"),
            extract_mmmu_cot,
        ],
        style=cot_style,
        inf_only=True,
    )
# MathVista; "testmini" serves as the validation split.
add_task(
    "math_vista_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="math_vista:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "testmini", "test": "test"},
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "id": ["id"],
            "query": ["query"],
            "image": ["image"],
            "choices": ["choices"],
            "answer": ["answer"],
            "metadata/question_type": ["question_type"],
            "metadata/answer_type": ["answer_type"],
            "metadata/precision": ["precision"],
            "metadata/split": ["metadata/split"],
        }),
        functools.partial(extract_math_vista, styles=['a_okvqa_mc', 'vqa2']),
    ],
    style=None,
    inf_only=True
)
# MathVista v2: identical to the task above except the query is reformatted
# before extraction.
add_task(
    "math_vista_v2",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="math_vista:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "testmini", "test": "test"},
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "id": ["id"],
            "query": ["query"],
            "image": ["image"],
            "choices": ["choices"],
            "answer": ["answer"],
            "metadata/question_type": ["question_type"],
            "metadata/answer_type": ["answer_type"],
            "metadata/precision": ["precision"],
            "metadata/split": ["metadata/split"],
        }),
        reformat_math_vista,
        functools.partial(
            extract_math_vista,
            styles=['a_okvqa_mc', 'vqa2'],
        ),
    ],
    style=None,
    inf_only=True
)
# MMBench; TFDS "dev" split is exposed as validation.
MM_BENCH_SRC = seqio.TfdsDataSource(
    tfds_name="mmbench:1.0.0",
    tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    splits={"validation": "dev", "test": "test"},
)
add_task(
    "mmbench_test",
    source=MM_BENCH_SRC,
    preprocessors=[format_mmbench],
    style="a_okvqa_mc",
    inf_only=True
)
# SugarCrepe compositionality probe on COCO; each example is flattened into
# one instance per candidate caption set.
add_task(
    "sugar_crepe_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="sugar_crepe:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        add_coco_url,
        functools.partial(flatten_parts, parts=["choices", "answer_idx", "metadata/answer_type"]),
        format_multiple_choice_qa,
    ],
    style="a_okvqa_mc",
    inf_only=True
)
# BLINK multi-image benchmark; images come pre-concatenated ("image_concat").
add_task(
    "blink_test",
    # A TfdsTask takes in a TFDS name instead of a tf.data.Dataset function.
    source=seqio.TfdsDataSource(
        tfds_name="blink:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
    ),
    preprocessors=[
        functools.partial(rekey, key_map={
            "id": ["id"],
            "question": ["prompt"],
            "image": ["image_concat"],
            "choices": ["choices"],
            "answer_idx": ["answer_idx"],
            "metadata/subtask": ["metadata/subtask"],
            "metadata/question": ["question"],
        }),
        format_multiple_choice_qa,
        output_options,
    ],
    style="a_okvqa_mc",
    inf_only=True
)
# OSCaR benchmark, validation only.
add_task(
    "oscarbench_qa",
    source=seqio.TfdsDataSource(
        tfds_name="oscarbench_qa:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "val"}
    ),
    preprocessors=[oscar_preprocessor],
    style="oscarbench_qa"
)
# CharXiv chart QA (combined descriptive + reasoning questions).
add_task(
    "charxiv",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "validation", "test": "test"}
    ),
    preprocessors=[charxiv_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        charxiv_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="charxiv",
)
# CharXiv restricted to descriptive questions.
add_task(
    "charxiv_descriptive",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "validation", "test": "test"}
    ),
    preprocessors=[charxiv_descriptive_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        charxiv_descriptive_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="charxiv_descriptive",
)
# CharXiv restricted to reasoning questions (one question per example, so
# no flattening inference pipeline is needed).
add_task(
    "charxiv_reasoning",
    source=seqio.TfdsDataSource(
        tfds_name="charxiv:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "validation", "test": "test"}
    ),
    preprocessors=[charxiv_reasoning_preprocessor, extract_individual_vqa],
    style="charxiv_reasoning",
)
# TableVQA-Bench subsets; the first 125 test examples double as validation.
for tablevqa_name in ["fintabnetqa", "vwtq", "vwtq_syn"]:
    add_task(
        tablevqa_name,
        source=seqio.TfdsDataSource(
            tfds_name=f"{tablevqa_name}:1.0.0",
            tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
            splits={"validation": "test[:125]", "test": "test"}
        ),
        preprocessors=[tablevqa_preprocessor, extract_individual_vqa],
        style=tablevqa_name,
    )
# Visual TabFact entailment.
add_task(
    "vtabfact",
    source=seqio.TfdsDataSource(
        tfds_name="vtabfact:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "test[:125]", "test": "test"}
    ),
    preprocessors=[vtabfact_preprocessor, extract_individual_vqa],
    style="vtabfact",
)
# Nutrition-facts reading benchmark.
# NOTE(review): inf_only=True means the train-time `preprocessors` list is
# presumably never used — confirm whether it can be dropped.
add_task(
    "nutrition_fact",
    source=seqio.TfdsDataSource(
        tfds_name="nutrition_fact:1.0.0",
        tfds_data_dir=MULTITASK_TFDS_DATA_DIR,
        splits={"validation": "test", "test": "test"}
    ),
    preprocessors=[nutrition_fact_preprocessor, extract_individual_vqa],
    inf_preprocessor=[
        nutrition_fact_preprocessor,
        functools.partial(flatten_parts, parts=["question", "answer"]),
        extract_individual_vqa,
    ],
    style="nutrition_fact",
    inf_only=True
)
# Register "demo"-styled clones of a handful of existing QA tasks; each clone
# copies the original TaskSpec with only its style replaced.
_DEMO_SOURCE_TASKS = (
    "chart_qa", "info_qa", "doc_qa", "text_vqa", "coco_2014_vqa",
    "ai2_diagram_v2_mix_transparent", "chart_qa_human",
)
for _demo_name in _DEMO_SOURCE_TASKS:
    TASKS[f"{_demo_name}_demo"] = dataclasses.replace(TASKS[_demo_name], style="demo")