-
Notifications
You must be signed in to change notification settings - Fork 13
/
test_curry.py
98 lines (71 loc) · 2.59 KB
/
test_curry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import math
import unittest
from functools import reduce
from curry import curry
def func(a, b, *, c, d):
    """Echo back the arguments this fixture was called with.

    The two positional arguments are returned as a tuple and the two
    keyword-only arguments as a dict, so tests can check exactly what a
    curried call eventually forwarded.
    """
    positional = (a, b)
    keyword_only = {'c': c, 'd': d}
    return positional, keyword_only
def add(a, b):
    """Return ``a + b``.

    The left operand is kept first, so non-commutative types such as
    strings and lists concatenate in argument order.
    """
    total = a + b
    return total
def add_default(a, b=10):
    """Return ``a + b``; ``b`` falls back to 10 when it is not supplied."""
    result = a + b
    return result
class TestCurry(unittest.TestCase):
    """Tests for ``curry``: incremental application, arity and defaults."""

    def test_basic_examples(self):
        """Positional and keyword-only args can each be given in separate calls."""
        f = curry(func)
        self.assertEqual(f(1)(2)(c=10)(d=20),
                         ((1, 2), dict(c=10, d=20)))

    def test_builtins(self):
        """Built-in callables such as ``map`` and ``math.pow`` can be curried."""
        f = curry(map)(lambda x: x + 1)
        self.assertEqual(list(f([1, 2, 3, 4])),
                         [2, 3, 4, 5])
        f = curry(math.pow)(2)
        self.assertAlmostEqual(f(4), 16)

    def test_attributes(self):
        """A partially applied curry still reports the wrapped ``__name__``."""
        f = curry(func)(1)(2)(c=10)
        self.assertEqual(f.__name__, 'func')

    def test_argument_check(self):
        """Currying a non-callable (here the int ``1``) raises ``TypeError``."""
        with self.assertRaises(TypeError):
            curry(1)

    def test_args_dont_persist(self):
        """Partials branched from the same curried object stay independent."""
        curried_func = curry(func)
        f = curried_func(1)(2)(c=10)
        g = curried_func('a')('b')(c='c')
        self.assertEqual(f(d=20),
                         ((1, 2), dict(c=10, d=20)))
        self.assertEqual(g(d='d'),
                         (('a', 'b'), dict(c='c', d='d')))

    def test_args_dont_persist_after_first(self):
        """Applying a positional argument must not mutate the partial itself.

        If ``curried(2)`` leaked ``b=2`` into ``curried``, the later
        ``curried(3)`` would presumably complete the call and return the
        lambda's ``None``; a still-pending partial is expected instead.
        """
        factory = curry(lambda a, b, c: None)
        curried = factory(1)
        curried(2)  # result discarded on purpose: only a leak would matter
        given_c = curried(3)
        self.assertIsNotNone(given_c)

    def test_kwargs_dont_persist(self):
        """Applying a keyword argument must not mutate the partial itself.

        Without ``use_defaults`` all three keywords are expected, so
        ``curried(c=None)`` should only complete (returning ``None``) if
        ``b`` had leaked in from the earlier ``curried(b=None)`` call.
        """
        factory = curry(lambda a=None, b=None, c=None: None)
        curried = factory(a=None)
        curried(b=None)  # result discarded on purpose: only a leak would matter
        given_c = curried(c=None)
        self.assertIsNotNone(given_c)

    def test_mutable_args(self):
        """List arguments are forwarded intact to the wrapped function."""
        def concat(a, b):
            ret = []
            ret.extend(a)
            ret.extend(b)
            return ret

        concat = curry(concat)
        self.assertEqual([1, 2, 3, 4], concat([1, 2])([3, 4]))

    def test_positional_kwargs(self):
        """A defaulted parameter can still be filled positionally.

        Without ``use_defaults`` the default for ``b`` does not complete the
        call after ``(1)``, so ``(2)`` supplies ``b``.
        """
        # Named distinctly so it does not shadow the module-level add_default
        # fixture that test_defaults relies on.
        curried_add_default = curry(lambda a, b=10: a + b)
        self.assertEqual(3, curried_add_default(1)(2))

    def test_specify_arity(self):
        """``n=`` fixes the arity of a variadic function."""
        def sum_(*xs):
            return reduce(add, xs, 0)

        add_arity_2 = curry(sum_, n=2)
        self.assertEqual(add_arity_2(10, 10), 20)
        self.assertEqual(curry(sum_, n=2, use_defaults=True)(10, 10), 20)
        add_ten = add_arity_2(10)
        self.assertEqual(add_ten(10), 20)

    def test_defaults(self):
        """With ``use_defaults=True`` the call completes once ``a`` is given."""
        curried_add = curry(add_default, use_defaults=True)
        self.assertEqual(curried_add(20), 30)
        self.assertEqual(curried_add(20, 20), 40)
        # NOTE(review): the argument is the *set* {20, 40}, so add_default
        # would evaluate ``set + 10`` and raise TypeError. Confirm whether this
        # was meant to test too many arguments (e.g. curried_add(20, 40, 60)).
        with self.assertRaises(TypeError):
            curried_add({20, 40})
if __name__ == '__main__':
    # Run the test suite when this module is executed directly as a script.
    unittest.main()