fix win train loader
This commit is contained in:
parent
d63431a633
commit
d5651dfa99
|
@ -199,8 +199,12 @@ def train(config,
|
||||||
train_reader_cost = 0.0
|
train_reader_cost = 0.0
|
||||||
batch_sum = 0
|
batch_sum = 0
|
||||||
batch_start = time.time()
|
batch_start = time.time()
|
||||||
for idx, batch in enumerate(train_dataloader()):
|
max_iter = len(train_dataloader) - 1 if platform.system(
|
||||||
|
) == "Windows" else len(train_dataloader)
|
||||||
|
for idx, batch in enumerate(train_dataloader):
|
||||||
train_reader_cost += time.time() - batch_start
|
train_reader_cost += time.time() - batch_start
|
||||||
|
if idx >= max_iter:
|
||||||
|
break
|
||||||
lr = optimizer.get_lr()
|
lr = optimizer.get_lr()
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
if use_srn:
|
if use_srn:
|
||||||
|
|
Loading…
Reference in New Issue