# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
from pathlib import Path
from multiprocessing import Manager

import numpy as np
from paddle.io import Dataset


class DataTable(Dataset):
    """Dataset to load and convert data for general purpose.

    Parameters
    ----------
    data : List[Dict[str, Any]]
        Metadata, a list of meta datum, each of which is composed of
        several fields.
    fields : List[str], optional
        Fields to use. If not specified, all the fields of the first
        example in the data are used, by default None.
    converters : Dict[str, Callable], optional
        Converters used to process each field, keyed by field name,
        by default None.
    use_cache : bool, optional
        Whether to cache converted examples, by default False.

    Raises
    ------
    AssertionError
        If ``data`` is empty.
    ValueError
        If there is some field that does not exist in data.
    ValueError
        If there is some field in converters that does not exist in fields.
    """

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: Optional[List[str]]=None,
                 converters: Optional[Dict[str, Callable]]=None,
                 use_cache: bool=False):
        # metadata
        self.data = data
        assert len(data) > 0, "This dataset has no examples"

        # Peek at the first example to get the existing fields. Only this
        # example is inspected; the others are presumably expected to share
        # the same keys (not checked here).
        first_example = self.data[0]
        # Materialize as a list: a ``dict_keys`` view is not picklable
        # (which breaks multiprocess data loading) and would silently track
        # later mutations of the first example.
        fields_in_data = list(first_example.keys())

        # check all the requested fields exist
        if fields is None:
            self.fields = fields_in_data
        else:
            for field in fields:
                if field not in fields_in_data:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data is {fields_in_data}")
            self.fields = list(fields)

        # check converters
        if converters is None:
            self.converters = {}
        else:
            for field in converters.keys():
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a non existing field ({field})")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
        """Create a Manager-backed list with one ``None`` slot per example.

        The list lives in the manager's server process and is accessed
        through a proxy; presumably this is so that data-loader worker
        processes can share converted examples.
        """
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [None] * len(self)

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta-datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converters to each field requested.

        Fields without a converter are passed through unchanged.

        Parameters
        ----------
        meta_datum : Dict[str, Any]
            Meta datum

        Returns
        -------
        Dict[str, Any]
            Converted example, containing exactly the fields in
            ``self.fields``.
        """
        example = {}
        for field in self.fields:
            value = meta_datum[field]
            converter = self.converters.get(field)
            example[field] = value if converter is None else converter(value)
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Parameters
        ----------
        idx : int
            Index of the example to get

        Returns
        -------
        Dict[str, Any]
            A converted example
        """
        if self.use_cache:
            # Read the proxy once: every proxy access is a round-trip to
            # the manager process.
            cached = self.caches[idx]
            if cached is not None:
                return cached

        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)

        if self.use_cache:
            self.caches[idx] = example

        return example

    def __len__(self) -> int:
        """Returns the size of the dataset.

        Returns
        -------
        int
            The length of the dataset
        """
        return len(self.data)