Merge pull request #83 from iclementine/develop
fix a bug when using a method other than forward with DataParallel
This commit is contained in:
commit
353212ebde
|
@ -119,7 +119,7 @@ class Experiment(ExperimentBase):
|
||||||
mel, wav, audio_starts = batch
|
mel, wav, audio_starts = batch
|
||||||
|
|
||||||
y = self.model(wav, mel, audio_starts)
|
y = self.model(wav, mel, audio_starts)
|
||||||
loss = self.model.loss(y, wav)
|
loss = self.model_core.loss(y, wav)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
iteration_time = time.time() - start
|
iteration_time = time.time() - start
|
||||||
|
@ -141,7 +141,7 @@ class Experiment(ExperimentBase):
|
||||||
valid_losses = []
|
valid_losses = []
|
||||||
mel, wav, audio_starts = next(valid_iterator)
|
mel, wav, audio_starts = next(valid_iterator)
|
||||||
y = self.model(wav, mel, audio_starts)
|
y = self.model(wav, mel, audio_starts)
|
||||||
loss = self.model.loss(y, wav)
|
loss = self.model_core.loss(y, wav)
|
||||||
valid_losses.append(float(loss))
|
valid_losses.append(float(loss))
|
||||||
valid_loss = np.mean(valid_losses)
|
valid_loss = np.mean(valid_losses)
|
||||||
self.visualizer.add_scalar(
|
self.visualizer.add_scalar(
|
||||||
|
|
Loading…
Reference in New Issue