Heating controller with neural thermal model written in Python
Jacek Kowalski
2018-06-24 425bf71fc0b24b547006686d83404c54b983de0b
commit | author | age
425bf7 1 import unittest
JK 2
3 from lib.SlidingWindow import SlidingWindow
4
5
class SlidingWindowTests(unittest.TestCase):
    """Unit tests for ``lib.SlidingWindow.SlidingWindow``.

    Unless a test says otherwise, the observation fed in at time step ``t`` is
    ``{'time': t, 'a': 2 * t, 'b': 3 * t}``. That lets every expected value be
    traced back to its time step: divide by 2 for field ``a`` or by 3 for
    field ``b``.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _observation(t):
        """Return the synthetic observation for time step ``t`` (a = 2t, b = 3t)."""
        return {'time': t, 'a': t * 2, 'b': t * 3}

    @classmethod
    def _feed(cls, window, count):
        """Add observations for time steps 1..count, in order.

        Returns the list of values returned by each ``add_observation`` call.
        """
        return [window.add_observation(cls._observation(t)) for t in range(1, count + 1)]

    @staticmethod
    def _model_only_window():
        """Window with 3 past model steps of ('a', 'b') and no past/future extras."""
        return SlidingWindow(model_past_values=3, model_past_fields=('a', 'b'),
                             past_values=0, future_values=0)

    @staticmethod
    def _full_window(**extra_options):
        """Window with 3 past model steps of ('a', 'b'), 1 past 'a' and 2 future 'b'.

        ``extra_options`` are passed straight through to ``SlidingWindow``.
        """
        return SlidingWindow(model_past_values=3, model_past_fields=('a', 'b'),
                             past_values=1, past_fields=('a',),
                             future_values=2, future_fields=('b',),
                             **extra_options)

    # ------------------------------------------------------- model-only window

    def test_model_only_init(self):
        window = self._model_only_window()
        results = self._feed(window, 6)

        # Steps 1..4 are expected to report "not ready", steps 5..6 "ready".
        self.assertFalse(any(results[:4]))
        self.assertTrue(all(results[4:]))

    def test_model_only_values(self):
        window = self._model_only_window()
        self._feed(window, 4)

        # 'a' then 'b' for time steps 1, 2, 3.
        self.assertListEqual(window.get_model_values(), [2, 4, 6, 3, 6, 9])

    def test_model_only_values_next(self):
        window = self._model_only_window()
        self._feed(window, 5)

        # A fifth observation is expected to leave the model values unchanged.
        self.assertListEqual(window.get_model_values(), [2, 4, 6, 3, 6, 9])

    def test_model_target_values(self):
        window = SlidingWindow(model_past_values=2, past_values=0, future_values=0)
        # This test uses 'temp_in' (3t) instead of 'b'.
        for t in (1, 2, 3):
            window.add_observation({'time': t, 'a': t * 2, 'temp_in': t * 3})
        self.assertTrue(window.add_observation({'time': 4, 'a': 8, 'temp_in': 12}))

        # temp_in at time step 3.
        self.assertEqual(window.get_model_target(), 9)

    # ------------------------------------------------------------- full window

    def test_all_init(self):
        window = SlidingWindow(model_past_values=3, model_past_fields=('a', 'b'),
                               past_values=1, future_values=2)
        results = self._feed(window, 6)

        # Steps 1..5 are expected to report "not ready", step 6 "ready".
        self.assertFalse(any(results[:5]))
        self.assertTrue(all(results[5:]))

    def test_all_values(self):
        window = self._full_window()
        self._feed(window, 6)

        self.assertListEqual(window.get_model_values(), [2, 4, 6, 3, 6, 9])
        self.assertListEqual(window._get_past_values(), [6])         # 'a' at t=3
        self.assertListEqual(window._get_future_values(), [12, 15])  # 'b' at t=4, 5

    def test_all_values_next(self):
        window = self._full_window()
        self._feed(window, 7)

        self.assertListEqual(window.get_model_values(), [4, 6, 8, 6, 9, 12])

        # (previous, current, next) values, expected at time steps 3, 4, 5.
        expectations = (
            ('a', (6, 8, 10)),
            ('b', (9, 12, 15)),
        )
        for field, (previous, current, upcoming) in expectations:
            self.assertEqual(window.get_previous_value(field), previous)
            self.assertEqual(window.get_current_value(field), current)
            self.assertEqual(window.get_next_value(field), upcoming)

        self.assertListEqual(window._get_past_values(), [8])         # 'a' at t=4
        self.assertListEqual(window._get_future_values(), [15, 18])  # 'b' at t=5, 6

    # -------------------------------------------------------- value overrides

    def test_replace_value_current(self):
        window = self._full_window()
        self._feed(window, 6)

        window.set_current_value('a', -1)

        # The override is expected in both the model vector and the past values.
        self.assertListEqual(window.get_model_values(), [2, 4, -1, 3, 6, 9])
        self.assertListEqual(window._get_past_values(), [-1])
        self.assertListEqual(window._get_future_values(), [12, 15])

    def test_replace_value_next(self):
        window = self._full_window()
        self._feed(window, 6)

        window.set_next_value('b', -1)

        # Only the first future 'b' value is expected to change.
        self.assertListEqual(window.get_model_values(), [2, 4, 6, 3, 6, 9])
        self.assertListEqual(window._get_past_values(), [6])
        self.assertListEqual(window._get_future_values(), [-1, 15])

    def test_replace_value_next_next(self):
        window = self._full_window()
        self._feed(window, 6)

        window.set_current_value('a', -1)
        window.set_next_value('b', -2)
        window.add_observation(self._observation(7))

        # The overrides are expected to survive the window advancing one step.
        self.assertEqual(window.get_previous_value('a'), -1)
        self.assertEqual(window.get_current_value('b'), -2)

    # ------------------------------------------------------------ next()

    def test_next(self):
        window = self._full_window(model_future_values=1, model_future_fields=('b',))
        self._feed(window, 6)

        # The trailing entry is the single future model value of 'b'.
        self.assertListEqual(window.get_model_values(), [2, 4, 6, 3, 6, 9, 12])

        window.next()

        self.assertListEqual(window.get_model_values(), [4, 6, 8, 6, 9, 12, 15])
185         self.assertListEqual(
186             sliding_window.get_model_values(),
187             [4, 6, 8, 6, 9, 12, 15]
188         )
189