Merge pull request #83 from iclementine/develop

fix a bug when using a method other than forward with DataParallel
This commit is contained in:
Feiyu Chan 2021-01-11 17:28:45 +08:00 committed by GitHub
commit 353212ebde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -119,7 +119,7 @@ class Experiment(ExperimentBase):
mel, wav, audio_starts = batch mel, wav, audio_starts = batch
y = self.model(wav, mel, audio_starts) y = self.model(wav, mel, audio_starts)
loss = self.model.loss(y, wav) loss = self.model_core.loss(y, wav)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
iteration_time = time.time() - start iteration_time = time.time() - start
@ -141,7 +141,7 @@ class Experiment(ExperimentBase):
valid_losses = [] valid_losses = []
mel, wav, audio_starts = next(valid_iterator) mel, wav, audio_starts = next(valid_iterator)
y = self.model(wav, mel, audio_starts) y = self.model(wav, mel, audio_starts)
loss = self.model.loss(y, wav) loss = self.model_core.loss(y, wav)
valid_losses.append(float(loss)) valid_losses.append(float(loss))
valid_loss = np.mean(valid_losses) valid_loss = np.mean(valid_losses)
self.visualizer.add_scalar( self.visualizer.add_scalar(