Skip to content

Commit 98332be

Browse files
author
Timur Gilmullin
committed
#14: refactor trainer.py
1 parent 54a142e commit 98332be

1 file changed

Lines changed: 9 additions & 7 deletions

File tree

pybrain/supervised/trainers/trainer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
__author__ = 'Tom Schaul, tom@idsia.ch'
2-
__version__ = '$Id$'
1+
# -*- coding: utf-8 -*-
32

43
from pybrain.utilities import Named, abstractMethod
54

5+
__author__ = 'Tom Schaul, tom@idsia.ch'
6+
67

78
class Trainer(Named):
89
""" A trainer determines how to change the adaptive parameters of a module.
@@ -19,8 +20,11 @@ def setData(self, dataset):
1920
"""Associate the given dataset with the trainer."""
2021
self.ds = dataset
2122
if dataset:
22-
assert dataset.indim == self.module.indim
23-
assert dataset.outdim == self.module.outdim
23+
if dataset.indim != self.module.indim:
24+
raise Exception("{} not equals to {}".format(str(dataset.indim), str(self.module.indim)))
25+
26+
if dataset.outdim != self.module.outdim:
27+
raise Exception("{} not equals to {}".format(str(dataset.outdim), str(self.module.outdim)))
2428

2529
def trainOnDataset(self, dataset, *args, **kwargs):
2630
"""Set the dataset and train.
@@ -36,8 +40,6 @@ def trainEpochs(self, epochs=1, *args, **kwargs):
3640
for dummy in range(epochs):
3741
self.train(*args, **kwargs)
3842

39-
def train(self):
43+
def train(self, *args, **kwargs):
4044
"""Train on the current dataset, for a single epoch."""
4145
abstractMethod()
42-
43-

0 commit comments

Comments
 (0)