| """ | |
| helper util to calculate dataset lengths | |
| """ | |
| import numpy as np | |
def get_dataset_lengths(dataset):
    """Return a NumPy array with one sequence length per row of ``dataset``.

    The lengths are taken from the first available source, in this order:

    1. a ``"length"`` column, converted to an array as-is;
    2. a ``"position_ids"`` column, using the last position id of each row
       plus one;
    3. the ``"input_ids"`` column, using ``len()`` of each row.

    Args:
        dataset: Object exposing ``dataset.data.column_names`` and
            ``dataset.data.column(name)``.
            NOTE(review): presumably a dataset backed by an Arrow table --
            confirm against callers.

    Returns:
        numpy.ndarray: One length per row.
    """
    column_names = dataset.data.column_names

    if "length" in column_names:
        return np.array(dataset.data.column("length"))

    if "position_ids" in column_names:
        position_ids = dataset.data.column("position_ids")
        # NOTE(review): assumes each row's position ids start at 0 and
        # increase by one, so last id + 1 is the row length -- confirm.
        return np.array([x[-1] + 1 for x in position_ids])

    input_ids = np.array(dataset.data.column("input_ids"), dtype=object)

    if input_ids.ndim > 1:
        # Rows of identical length collapse into an N-d array; calling len()
        # on its scalar elements would fail, but every row's length is
        # simply the size of the second axis.
        return np.full(input_ids.shape[0], input_ids.shape[1], dtype=int)

    # Explicit otypes keeps np.vectorize from raising on an empty column,
    # where it cannot infer the output dtype from a first element.
    return np.vectorize(len, otypes=[int])(input_ids)