# ParakeetRebeccaRosario/parakeet/datasets/data_table.py
# (152 lines, 4.6 KiB, Python)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
from pathlib import Path
from multiprocessing import Manager
import numpy as np
from paddle.io import Dataset
class DataTable(Dataset):
    """Dataset to load and convert data for general purpose.

    Each example is produced from one meta datum (a dict) by selecting the
    requested fields and applying the converter registered for each field,
    if any. Fields without a converter are passed through unchanged.

    Parameters
    ----------
    data : List[Dict[str, Any]]
        Metadata, a list of meta datum, each of which is composed of
        several fields
    fields : List[str], optional
        Fields to use, if not specified, all the fields in the data are
        used, by default None
    converters : Dict[str, Callable], optional
        Converters used to process each field, by default None
    use_cache : bool, optional
        Whether to use cache, by default False

    Raises
    ------
    ValueError
        If the data contains no examples.
    ValueError
        If there is some field that does not exist in data.
    ValueError
        If there is some field in converters that does not exist in fields.
    """

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: Optional[List[str]]=None,
                 converters: Optional[Dict[str, Callable]]=None,
                 use_cache: bool=False):
        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O``, which would turn an empty dataset
        # into an opaque IndexError on ``data[0]`` below.
        if len(data) == 0:
            raise ValueError("This dataset has no examples")

        # metadata
        self.data = data

        # Peek at the first example to discover the available fields. Only
        # the first example is inspected; the remaining examples are not
        # checked here.
        first_example = self.data[0]
        fields_in_data = list(first_example.keys())

        # check all the requested fields exist
        if fields is None:
            # Store a list snapshot rather than the live ``dict_keys`` view:
            # a view cannot be pickled and would track later mutation of
            # the first example.
            self.fields = fields_in_data
        else:
            for field in fields:
                if field not in first_example:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data is {fields_in_data}")
            self.fields = fields

        # check converters
        if converters is None:
            self.converters = {}
        else:
            for field in converters.keys():
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a non existing field ({field})")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
        """Create the cache: a manager-backed list with one slot per example.

        Every slot starts as ``None``, meaning "not converted yet".
        """
        self.manager = Manager()
        # Build the list pre-filled in one call instead of creating it empty
        # and extending it afterwards.
        self.caches = self.manager.list([None] * len(self))

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta-datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converters to each fields requested.

        Parameters
        ----------
        meta_datum : Dict[str, Any]
            Meta datum

        Returns
        -------
        Dict[str, Any]
            Converted example, containing only the requested fields
        """
        example = {}
        for field in self.fields:
            converter = self.converters.get(field, None)
            meta_datum_field = meta_datum[field]
            if converter is not None:
                converted_field = converter(meta_datum_field)
            else:
                # No converter registered: pass the raw value through.
                converted_field = meta_datum_field
            example[field] = converted_field
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        When caching is enabled, a previously converted example is returned
        from the cache; otherwise the example is converted and, if caching
        is enabled, stored for later lookups.

        Parameters
        ----------
        idx : int
            Index of the example to get

        Returns
        -------
        Dict[str, Any]
            A converted example
        """
        if self.use_cache:
            # Read the slot once: every access to the managed list goes
            # through the manager proxy, so avoid a second lookup on a hit.
            cached = self.caches[idx]
            if cached is not None:
                return cached

        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)

        if self.use_cache:
            self.caches[idx] = example
        return example

    def __len__(self) -> int:
        """Returns the size of the dataset.

        Returns
        -------
        int
            The length of the dataset
        """
        return len(self.data)