Skip to content

Commit 3ddd4b0

Browse files
author
Timur Gilmullin
committed
#14: refactor validation.py
1 parent e63c5ed commit 3ddd4b0

1 file changed

Lines changed: 8 additions & 10 deletions

File tree

pybrain/tools/validation.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
__author__ = 'Michael Isik'
2-
1+
# -*- coding: utf-8 -*-
32

43
from numpy.random import permutation
54
from numpy import array, array_split, apply_along_axis, concatenate, ones, dot, delete, append, zeros, argmax
@@ -8,6 +7,7 @@
87
from pybrain.datasets.sequential import SequentialDataSet
98
from pybrain.datasets.supervised import SupervisedDataSet
109

10+
__author__ = 'Michael Isik'
1111

1212

1313
class Validator(object):
@@ -57,7 +57,10 @@ def MSE(cls, output, target, importance=None):
5757
# assert equal shapes
5858
output = array(output)
5959
target = array(target)
60-
assert output.shape == target.shape
60+
61+
if output.shape != target.shape:
62+
raise Exception("output.shape != target.shape")
63+
6164
if importance is not None:
6265
assert importance.shape == target.shape
6366
importance = importance.flatten()
@@ -215,7 +218,7 @@ def _calculateModuleOutputSequential(cls, module, dataset):
215218
outputs = []
216219
for seq in dataset._provideSequences():
217220
module.reset()
218-
for i in xrange(len(seq)):
221+
for i in range(len(seq)):
219222
output = module.activate(seq[i][0])
220223
outputs.append(output.copy())
221224
outputs = array(outputs)
@@ -370,7 +373,7 @@ def testOnSequenceData(module, dataset):
370373
# one-of-many values
371374
class_output = []
372375
class_target = []
373-
for j in xrange(len(output)):
376+
for j in range(len(output)):
374377
# sum up the output values of one sequence
375378
summed_output += output[j]
376379
# print(j, output[j], " --> ", summed_output)
@@ -391,8 +394,3 @@ def testOnSequenceData(module, dataset):
391394
# print(class_target)
392395
# print(class_output)
393396
return Validator.classificationPerformance(class_output, class_target)
394-
395-
396-
397-
398-

0 commit comments

Comments
 (0)