Skip to content

Commit aabd4d7

Browse files
author
Timur Gilmullin
committed
#14: refactor utilities.py
1 parent 3ddd4b0 commit aabd4d7

1 file changed

Lines changed: 37 additions & 25 deletions

File tree

pybrain/utilities.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from __future__ import with_statement
1+
# -*- coding: utf-8 -*-
22

3-
__author__ = 'Tom Schaul, tom@idsia.ch; Justin Bayer, bayerj@in.tum.de'
3+
from __future__ import with_statement
44

55
import gc
66
import pickle
@@ -15,13 +15,17 @@
1515

1616
from scipy import where, array, exp, zeros, size, mat, median
1717

18+
__author__ = 'Tom Schaul, tom@idsia.ch; Justin Bayer, bayerj@in.tum.de'
19+
20+
1821
# file extension for load/save protocol mapping
1922
known_extensions = {
2023
'mat': 'matlab',
2124
'txt': 'ascii',
2225
'svm': 'libsvm',
2326
'pkl': 'pickle',
24-
'nc' : 'netcdf' }
27+
'nc': 'netcdf'
28+
}
2529

2630

2731
def abstractMethod():
@@ -149,42 +153,48 @@ class Serializable(object):
149153
supported.
150154
"""
151155

152-
def saveToFileLike(self, flo, format=None, **kwargs):
156+
def saveToFileLike(self, flo, formatObj=None, **kwargs):
153157
"""Save the object to a given file like object in the given format.
154158
"""
155-
format = 'pickle' if format is None else format
156-
save = getattr(self, "save_%s" % format, None)
159+
formatObj = 'pickle' if formatObj is None else formatObj
160+
save = getattr(self, "save_%s" % formatObj, None)
161+
157162
if save is None:
158-
raise ValueError("Unknown format '%s'." % format)
163+
raise ValueError("Unknown format '%s'." % formatObj)
164+
159165
save(flo, **kwargs)
160166

161167
@classmethod
162-
def loadFromFileLike(cls, flo, format=None):
168+
def loadFromFileLike(cls, flo, formatObj=None):
163169
"""Load the object to a given file like object with the given protocol.
164170
"""
165-
format = 'pickle' if format is None else format
166-
load = getattr(cls, "load_%s" % format, None)
171+
formatObj = 'pickle' if formatObj is None else formatObj
172+
load = getattr(cls, "load_%s" % formatObj, None)
173+
167174
if load is None:
168-
raise ValueError("Unknown format '%s'." % format)
175+
raise ValueError("Unknown format '%s'." % formatObj)
176+
169177
return load(flo)
170178

171-
def saveToFile(self, filename, format=None, **kwargs):
179+
def saveToFile(self, filename, formatObj=None, **kwargs):
172180
"""Save the object to file given by filename."""
173-
if format is None:
181+
if formatObj is None:
174182
# try to derive protocol from file extension
175-
format = formatFromExtension(filename)
176-
with file(filename, 'wb') as fp:
177-
self.saveToFileLike(fp, format, **kwargs)
183+
formatObj = formatFromExtension(filename)
184+
185+
with open(filename, 'wb') as fp:
186+
self.saveToFileLike(fp, formatObj, **kwargs)
178187

179188
@classmethod
180-
def loadFromFile(cls, filename, format=None):
189+
def loadFromFile(cls, filename, formatObj=None):
181190
"""Return an instance of the class that is saved in the file with the
182191
given filename in the specified format."""
183-
if format is None:
192+
if formatObj is None:
184193
# try to derive protocol from file extension
185-
format = formatFromExtension(filename)
186-
with file(filename, 'rbU') as fp:
187-
obj = cls.loadFromFileLike(fp, format)
194+
formatObj = formatFromExtension(filename)
195+
196+
with open(filename, 'rbU') as fp:
197+
obj = cls.loadFromFileLike(fp, formatObj)
188198
obj.filename = filename
189199
return obj
190200

@@ -485,8 +495,11 @@ def crossproduct(ss, row=None, level=0):
485495
if row is None:
486496
row = []
487497
if len(ss) > 1:
488-
return reduce(operator.add,
489-
[crossproduct(ss[1:], row + [i], level + 1) for i in ss[0]])
498+
import functools
499+
return functools.reduce(
500+
operator.add,
501+
[crossproduct(ss[1:], row + [i], level + 1) for i in ss[0]]
502+
)
490503
else:
491504
return [row + [i] for i in ss[0]]
492505

@@ -539,7 +552,7 @@ def permuteToBlocks2d(arr, blockheight, blockwidth):
539552
_height, width = arr.shape
540553
arr = arr.flatten()
541554
new = zeros(size(arr))
542-
for i in xrange(size(arr)):
555+
for i in range(size(arr)):
543556
blockx = (i % width) / blockwidth
544557
blocky = i / width / blockheight
545558
blockoffset = blocky * width / blockwidth + blockx
@@ -785,4 +798,3 @@ def weightedUtest(g1, w1, g2, w2):
785798
z = (u1 - mu) / sigu
786799
conf = norm.cdf(z)
787800
return conf
788-

0 commit comments

Comments
 (0)