123 lines
4.3 KiB
Python
123 lines
4.3 KiB
Python
import unittest
|
|
import numpy as np
|
|
import paddle
|
|
from paddle import io
|
|
from parakeet import data
|
|
|
|
class MyDataset(io.Dataset):
|
|
def __init__(self, size):
|
|
self._data = np.random.randn(size, 6)
|
|
|
|
def __getitem__(self, i):
|
|
return self._data[i]
|
|
|
|
def __len__(self):
|
|
return self._data.shape[0]
|
|
|
|
|
|
class TestTransformDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(20)
|
|
dataset = data.TransformDataset(dataset, lambda x: np.abs(x))
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("TransformDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestChainDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset1 = MyDataset(20)
|
|
dataset2 = MyDataset(40)
|
|
dataset = data.ChainDataset(dataset1, dataset2)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("ChainDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestTupleDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset1 = MyDataset(20)
|
|
dataset2 = MyDataset(20)
|
|
dataset = data.TupleDataset(dataset1, dataset2)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("TupleDataset")
|
|
for field1, field2 in dataloader:
|
|
print(type(field1), field1.dtype, field1.shape)
|
|
print(type(field2), field2.dtype, field2.shape)
|
|
|
|
|
|
class TestDictDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset1 = MyDataset(20)
|
|
dataset2 = MyDataset(20)
|
|
dataset = data.DictDataset(field1=dataset1, field2=dataset2)
|
|
def collate_fn(examples):
|
|
examples_tuples = []
|
|
for example in examples:
|
|
examples_tuples.append(example.values())
|
|
return paddle.fluid.dataloader.dataloader_iter.default_collate_fn(examples_tuples)
|
|
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=collate_fn)
|
|
print("DictDataset")
|
|
for field1, field2 in dataloader:
|
|
print(type(field1), field1.dtype, field1.shape)
|
|
print(type(field2), field2.dtype, field2.shape)
|
|
|
|
|
|
class TestSliceDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(40)
|
|
dataset = data.SliceDataset(dataset, 0, 20)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("SliceDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestSplit(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(40)
|
|
train, valid = data.split(dataset, 10)
|
|
dataloader1 = io.DataLoader(train, batch_size=4, shuffle=True, num_workers=1)
|
|
dataloader2 = io.DataLoader(valid, batch_size=4, shuffle=True, num_workers=1)
|
|
print("First Dataset")
|
|
for batch, in dataloader1:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
print("Second Dataset")
|
|
for batch, in dataloader2:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestSubsetDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(40)
|
|
indices = np.random.choice(np.arange(40), [20], replace=False).tolist()
|
|
dataset = data.SubsetDataset(dataset, indices)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("SubsetDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestFilterDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(40)
|
|
dataset = data.FilterDataset(dataset, lambda x: np.mean(x)> 0.3)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("FilterDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|
|
|
|
|
|
class TestCacheDataset(unittest.TestCase):
|
|
def test(self):
|
|
dataset = MyDataset(40)
|
|
dataset = data.CacheDataset(dataset)
|
|
dataloader = io.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1)
|
|
print("CacheDataset")
|
|
for batch, in dataloader:
|
|
print(type(batch), batch.dtype, batch.shape)
|