ParakeetEricRoss/tests/test_dataset.py

123 lines
4.3 KiB
Python
Raw Normal View History

2020-10-10 15:51:54 +08:00
import unittest
import numpy as np
import paddle
from paddle import io
from parakeet import data
class MyDataset(io.Dataset):
def __init__(self, size):
self._data = np.random.randn(size, 6)
def __getitem__(self, i):
return self._data[i]
def __len__(self):
return self._data.shape[0]
class TestTransformDataset(unittest.TestCase):
def test(self):
dataset = MyDataset(20)
dataset = data.TransformDataset(dataset, lambda x: np.abs(x))
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("TransformDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)
class TestChainDataset(unittest.TestCase):
def test(self):
dataset1 = MyDataset(20)
dataset2 = MyDataset(40)
dataset = data.ChainDataset(dataset1, dataset2)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("ChainDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)
class TestTupleDataset(unittest.TestCase):
def test(self):
dataset1 = MyDataset(20)
dataset2 = MyDataset(20)
dataset = data.TupleDataset(dataset1, dataset2)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("TupleDataset")
for field1, field2 in dataloader:
print(type(field1), field1.dtype, field1.shape)
print(type(field2), field2.dtype, field2.shape)
class TestDictDataset(unittest.TestCase):
def test(self):
dataset1 = MyDataset(20)
dataset2 = MyDataset(20)
dataset = data.DictDataset(field1=dataset1, field2=dataset2)
def collate_fn(examples):
examples_tuples = []
for example in examples:
examples_tuples.append(example.values())
return paddle.fluid.dataloader.dataloader_iter.default_collate_fn(examples_tuples)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=collate_fn)
print("DictDataset")
for field1, field2 in dataloader:
print(type(field1), field1.dtype, field1.shape)
print(type(field2), field2.dtype, field2.shape)
class TestSliceDataset(unittest.TestCase):
def test(self):
dataset = MyDataset(40)
dataset = data.SliceDataset(dataset, 0, 20)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("SliceDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)
class TestSplit(unittest.TestCase):
def test(self):
dataset = MyDataset(40)
train, valid = data.split(dataset, 10)
dataloader1 = io.DataLoader(train, batch_size=4, shuffle=True, num_workers=1)
dataloader2 = io.DataLoader(valid, batch_size=4, shuffle=True, num_workers=1)
print("First Dataset")
for batch, in dataloader1:
print(type(batch), batch.dtype, batch.shape)
print("Second Dataset")
for batch, in dataloader2:
print(type(batch), batch.dtype, batch.shape)
class TestSubsetDataset(unittest.TestCase):
def test(self):
dataset = MyDataset(40)
indices = np.random.choice(np.arange(40), [20], replace=False).tolist()
dataset = data.SubsetDataset(dataset, indices)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("SubsetDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)
class TestFilterDataset(unittest.TestCase):
def test(self):
dataset = MyDataset(40)
dataset = data.FilterDataset(dataset, lambda x: np.mean(x)> 0.3)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("FilterDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)
class TestCacheDataset(unittest.TestCase):
def test(self):
dataset = MyDataset(40)
dataset = data.CacheDataset(dataset)
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
print("CacheDataset")
for batch, in dataloader:
print(type(batch), batch.dtype, batch.shape)