1- __author__ = 'Tom Schaul, tom@idsia.ch'
2- __version__ = '$Id$'
1+ # -*- coding: utf-8 -*-
32
43from pybrain .utilities import Named , abstractMethod
54
5+ __author__ = 'Tom Schaul, tom@idsia.ch'
6+
67
78class Trainer (Named ):
89 """ A trainer determines how to change the adaptive parameters of a module.
@@ -19,8 +20,11 @@ def setData(self, dataset):
1920 """Associate the given dataset with the trainer."""
2021 self .ds = dataset
2122 if dataset :
22- assert dataset .indim == self .module .indim
23- assert dataset .outdim == self .module .outdim
23+ if dataset .indim != self .module .indim :
24+ raise Exception ("{} not equals to {}" .format (str (dataset .indim ), str (self .module .indim )))
25+
26+ if dataset .outdim != self .module .outdim :
27+ raise Exception ("{} not equals to {}" .format (str (dataset .outdim ), str (self .module .outdim )))
2428
2529 def trainOnDataset (self , dataset , * args , ** kwargs ):
2630 """Set the dataset and train.
@@ -36,8 +40,6 @@ def trainEpochs(self, epochs=1, *args, **kwargs):
3640 for dummy in range (epochs ):
3741 self .train (* args , ** kwargs )
3842
39- def train (self ):
43+ def train (self , * args , ** kwargs ):
4044 """Train on the current dataset, for a single epoch."""
4145 abstractMethod ()
42-
43-
0 commit comments