Spaces:
Running
Running
| import logging | |
| from dataclasses import dataclass | |
| import pyarrow as pa | |
| import datasets | |
| logger = logging.getLogger(__name__) | |
| FEATURES = datasets.Features( | |
| { | |
| "text": datasets.Value("string"), | |
| } | |
| ) | |
| class PubChemConfig(datasets.BuilderConfig): | |
| """BuilderConfig for text files.""" | |
| encoding: str = "utf-8" | |
| chunksize: int = 10 << 20 # 10MB | |
| class PubChem(datasets.ArrowBasedBuilder): | |
| BUILDER_CONFIG_CLASS = PubChemConfig | |
| def _info(self): | |
| return datasets.DatasetInfo(features=FEATURES) | |
| def _split_generators(self, dl_manager): | |
| """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]]. | |
| If str or List[str], then the dataset returns only the 'train' split. | |
| If dict, then keys should be from the `datasets.Split` enum. | |
| """ | |
| if not self.config.data_files: | |
| raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") | |
| data_files = dl_manager.download_and_extract(self.config.data_files) | |
| if isinstance(data_files, (str, list, tuple)): | |
| files = data_files | |
| if isinstance(files, str): | |
| files = [files] | |
| return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] | |
| splits = [] | |
| for split_name, files in data_files.items(): | |
| if isinstance(files, str): | |
| files = [files] | |
| splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) | |
| return splits | |
| def _generate_tables(self, files): | |
| for file_idx, file in enumerate(files): | |
| batch_idx = 0 | |
| with open(file, "r", encoding=self.config.encoding) as f: | |
| while True: | |
| batch = f.read(self.config.chunksize) | |
| if not batch: | |
| break | |
| batch += f.readline() # finish current line | |
| batch = batch.splitlines() | |
| batch = [word.split()[-1] for word in batch] | |
| pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()})) | |
| # Uncomment for debugging (will print the Arrow table size and elements) | |
| #logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") | |
| #logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) | |
| yield (file_idx, batch_idx), pa_table | |
| batch_idx += 1 | |