diff --git a/pipe.py b/pipe.py index 42ba7e9..909b9ef 100644 --- a/pipe.py +++ b/pipe.py @@ -38,6 +38,9 @@ def __init__(self, function, *args, **kwargs): ) functools.update_wrapper(self, function) + def __or__(self, other): + return Pipe(lambda iterable: self.function(iterable) | other) + def __ror__(self, other): return self.function(other) diff --git a/tests/test_pipe.py b/tests/test_pipe.py index 8c51980..c4f3cf9 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -39,3 +39,8 @@ def test_enumerate(): data = [4, "abc", {"key": "value"}] expected = [(5, 4), (6, "abc"), (7, {"key": "value"})] assert list(data | pipe.enumerate(start=5)) == expected + + +def test_composition(): + p = pipe.where(lambda x: not x % 2) | pipe.take(5) + assert list(range(100) | p) == [0, 2, 4, 6, 8]