# xref: /aosp_15_r20/external/autotest/frontend/afe/rpc_interface_unittest.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1#!/usr/bin/env python3
2# pylint: disable=missing-docstring
3
4from __future__ import absolute_import
5
6import datetime
7import unittest
8from unittest.mock import patch
9from unittest.mock import MagicMock
10
11import six
12from autotest_lib.client.common_lib import (control_data, error, global_config,
13                                            priorities)
14from autotest_lib.client.common_lib.cros import dev_server
15from autotest_lib.frontend import setup_django_environment
16from autotest_lib.frontend.afe import (frontend_test_utils, model_logic,
17                                       models, rpc_interface, rpc_utils)
18from autotest_lib.server import utils as server_utils
19from autotest_lib.server.cros import provision
20from autotest_lib.server.cros.dynamic_suite import (constants,
21                                                    control_file_getter,
22                                                    frontend_wrappers,
23                                                    suite_common)
24
25import common
26
# Control-file type names used when creating jobs in these tests.
CLIENT, SERVER = (control_data.CONTROL_TYPE_NAMES.CLIENT,
                  control_data.CONTROL_TYPE_NAMES.SERVER)

# Short alias for the HostQueueEntry status values.
_hqe_status = models.HostQueueEntry.Status
31
32
class ShardHeartbeatTest(unittest.TestCase):
    """Reusable shard-heartbeat scenarios.

    This class defines no test_* methods itself. Subclasses build the
    shard/host/label fixtures and call the ``_test...Helper`` methods, which
    drive rpc_interface.shard_heartbeat and check its reply.
    """

    # Priority used for every job created by _createJobForLabel.
    _PRIORITY = priorities.Priority.DEFAULT


    def _do_heartbeat_and_assert_response(self, shard_hostname='shard1',
                                          upload_jobs=(), upload_hqes=(),
                                          known_jobs=(), known_hosts=(),
                                          **kwargs):
        """Send one heartbeat from a shard and verify the reply.

        @param shard_hostname: Hostname of the shard sending the heartbeat.
        @param upload_jobs: Jobs uploaded by the shard.
        @param upload_hqes: Host queue entries uploaded by the shard.
        @param known_jobs: Job objects the shard already has; only their ids
                are sent.
        @param known_hosts: Host objects the shard already has; their ids and
                statuses are sent.
        @param kwargs: Expected reply contents, forwarded to
                _assert_shard_heartbeat_response (jobs, hosts, hqes,
                incorrect_host_ids).

        @returns: shard_hostname.
        """
        known_job_ids = [job.id for job in known_jobs]
        known_host_ids = [host.id for host in known_hosts]
        known_host_statuses = [host.status for host in known_hosts]

        retval = rpc_interface.shard_heartbeat(
            shard_hostname=shard_hostname,
            jobs=upload_jobs, hqes=upload_hqes,
            known_job_ids=known_job_ids, known_host_ids=known_host_ids,
            known_host_statuses=known_host_statuses)

        self._assert_shard_heartbeat_response(shard_hostname, retval,
                                              **kwargs)

        return shard_hostname


    def _assert_shard_heartbeat_response(self, shard_hostname, retval,
                                         jobs=(), hosts=(), hqes=(),
                                         incorrect_host_ids=None):
        """Compare a shard_heartbeat reply against the expected objects.

        Every comparison is an ordered list equality.

        @param shard_hostname: Hostname every returned job's shard must have.
        @param retval: Dict returned by rpc_interface.shard_heartbeat.
        @param jobs: Job objects expected in retval['jobs'].
        @param hosts: Host objects expected in retval['hosts'].
        @param hqes: HQE objects expected across the 'hostqueueentry_set'
                lists of all returned jobs.
        @param incorrect_host_ids: Expected retval['incorrect_host_ids'];
                None means an empty list.
        """
        # None instead of a mutable [] default, which would be shared
        # between calls.
        if incorrect_host_ids is None:
            incorrect_host_ids = []

        retval_hosts, retval_jobs = retval['hosts'], retval['jobs']
        retval_incorrect_hosts = retval['incorrect_host_ids']

        expected_jobs = [(job.id, job.name, shard_hostname) for job in jobs]
        returned_jobs = [(job['id'], job['name'], job['shard']['hostname'])
                         for job in retval_jobs]
        self.assertEqual(returned_jobs, expected_jobs)

        expected_hosts = [(host.id, host.hostname) for host in hosts]
        returned_hosts = [(host['id'], host['hostname'])
                          for host in retval_hosts]
        self.assertEqual(returned_hosts, expected_hosts)

        # HQEs are nested under their jobs in the reply; flatten them in job
        # order before comparing.
        expected_hqes = [hqe.id for hqe in hqes]
        returned_hqes = [hqe['id']
                         for job in retval_jobs
                         for hqe in job['hostqueueentry_set']]
        self.assertEqual(returned_hqes, expected_hqes)

        self.assertEqual(retval_incorrect_hosts, incorrect_host_ids)


    def _createJobForLabel(self, label):
        """Create a client job that depends on, and is scheduled by, a label.

        @param label: Label used both as the meta-host and as a dependency.

        @returns: The created models.Job.
        """
        job_id = rpc_interface.create_job(name='stub', priority=self._PRIORITY,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          meta_hosts=[label.name],
                                          dependencies=(label.name,))
        return models.Job.objects.get(id=job_id)


    def _testShardHeartbeatFetchHostlessJobHelper(self, host1):
        """Create a hostless job and ensure it's not assigned to a shard."""
        # An extra, non-platform label; its object is not needed afterwards.
        models.Label.objects.create(name='bluetooth', platform=False)

        self._create_job(hostless=True)

        # Hostless jobs should be executed by the global scheduler.
        self._do_heartbeat_and_assert_response(hosts=[host1])


    def _testShardHeartbeatIncorrectHostsHelper(self, host1):
        """Ensure that hosts that don't belong to shard are determined."""
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)

        # host2 should not belong to shard1. Ensure that if shard1 thinks host2
        # is a known host, then it is returned as invalid.
        self._do_heartbeat_and_assert_response(known_hosts=[host1, host2],
                                               incorrect_host_ids=[host2.id])


    def _testShardHeartbeatLabelRemovalRaceHelper(self, shard1, host1, label1):
        """Ensure correctness if label removed during heartbeat."""
        host2 = models.Host.objects.create(hostname='test_host2', leased=False)
        host2.labels.add(label1)
        self.assertEqual(host2.shard, None)

        # In the middle of the assign_to_shard call, remove label1 from shard1.
        with patch.object(
                models.Host,
                '_assign_to_shard_nothing_helper',
                side_effect=lambda: rpc_interface.remove_board_from_shard(
                        shard1.hostname, label1.name)):
            self._do_heartbeat_and_assert_response(
                    known_hosts=[host1],
                    hosts=[],
                    incorrect_host_ids=[host1.id])
            # Re-read host2 from the DB to observe the post-heartbeat state.
            host2 = models.Host.smart_get(host2.id)

        self.assertEqual(host2.shard, None)


    def _testShardRetrieveJobsHelper(self, shard1, host1, label1, shard2,
                                     host2, label2):
        """Create jobs and retrieve them."""
        # should never be returned by heartbeat
        leased_host = models.Host.objects.create(hostname='leased_host',
                                                 leased=True)

        leased_host.labels.add(label1)

        job1 = self._createJobForLabel(label1)

        job2 = self._createJobForLabel(label2)

        job_completed = self._createJobForLabel(label1)
        # Job is already being run, so don't sync it
        job_completed.hostqueueentry_set.update(complete=True)
        job_completed.hostqueueentry_set.create(complete=False)

        job_active = self._createJobForLabel(label1)
        # Job is already started, so don't sync it
        job_active.hostqueueentry_set.update(active=True)
        job_active.hostqueueentry_set.create(complete=False, active=False)

        self._do_heartbeat_and_assert_response(
            jobs=[job1], hosts=[host1], hqes=job1.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            shard_hostname=shard2.hostname,
            jobs=[job2], hosts=[host2], hqes=job2.hostqueueentry_set.all())

        host3 = models.Host.objects.create(hostname='test_host3', leased=False)
        host3.labels.add(label1)

        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1], hosts=[host3])


    def _testResendJobsAfterFailedHeartbeatHelper(self, shard1, host1, label1):
        """Create jobs, retrieve them, fail on client, fetch them again."""
        job1 = self._createJobForLabel(label1)

        self._do_heartbeat_and_assert_response(
            jobs=[job1],
            hqes=job1.hostqueueentry_set.all(), hosts=[host1])

        # The shard did not persist job1, so it reports no known jobs; make
        # sure job1 is sent again.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job1], hqes=job1.hostqueueentry_set.all(), hosts=[])

        # Now it worked, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_jobs=[job1], known_hosts=[host1])

        job1 = models.Job.objects.get(pk=job1.id)
        job1.hostqueueentry_set.all().update(complete=True)

        # Job is completed, make sure it's not sent again
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1])

        job2 = self._createJobForLabel(label1)

        # job2's creation was later, it should be returned now.
        self._do_heartbeat_and_assert_response(
            known_hosts=[host1],
            jobs=[job2], hqes=job2.hostqueueentry_set.all())

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1])

        job2 = models.Job.objects.get(pk=job2.pk)
        job2.hostqueueentry_set.update(aborted=True)
        # Setting a job to a complete status will set the shard_id to None in
        # scheduler_models. We have to emulate that here, because we use Django
        # models in tests.
        job2.shard = None
        job2.save()

        self._do_heartbeat_and_assert_response(
            known_jobs=[job2], known_hosts=[host1],
            jobs=[job2],
            hqes=job2.hostqueueentry_set.all())

        models.Test.objects.create(name='platform_BootPerfServer:shard',
                                   test_type=1)
        with patch.object(server_utils, 'read_file'):
            rpc_interface.delete_shard(hostname=shard1.hostname)

            self.assertRaises(models.Shard.DoesNotExist,
                              models.Shard.objects.get,
                              pk=shard1.id)

            job1 = models.Job.objects.get(pk=job1.id)
            label1 = models.Label.objects.get(pk=label1.id)

            self.assertIsNone(job1.shard)
            self.assertEqual(len(label1.shard_set.all()), 0)


    def _testResendHostsAfterFailedHeartbeatHelper(self, host1):
        """Check that main accepts resending updated records after failure."""
        # Send the host
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Send it again because previous one didn't persist correctly
        self._do_heartbeat_and_assert_response(hosts=[host1])

        # Now it worked, make sure it isn't sent again
        self._do_heartbeat_and_assert_response(known_hosts=[host1])
247
248
class RpcInterfaceTestWithStaticAttribute(unittest.TestCase,
                                          frontend_test_utils.FrontendTestMixin
                                          ):
    """RPC tests run with RESPECT_STATIC_ATTRIBUTES enabled."""

    def setUp(self):
        super(RpcInterfaceTestWithStaticAttribute, self).setUp()
        self._frontend_common_setup()
        # Save each module's flag separately so tearDown restores the exact
        # original value of both, even if they differed.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_ATTRIBUTES
        self.old_models_respect_static_config = (
                models.RESPECT_STATIC_ATTRIBUTES)
        rpc_interface.RESPECT_STATIC_ATTRIBUTES = True
        models.RESPECT_STATIC_ATTRIBUTES = True


    def tearDown(self):
        try:
            self._frontend_common_teardown()
            global_config.global_config.reset_config_values()
        finally:
            # Always restore the module-level flags so a teardown failure
            # cannot leak static-attribute behavior into other tests.
            rpc_interface.RESPECT_STATIC_ATTRIBUTES = (
                    self.old_respect_static_config)
            models.RESPECT_STATIC_ATTRIBUTES = (
                    self.old_models_respect_static_config)


    def _fake_host_with_static_attributes(self):
        """Create 'test_host' with two regular and two static attributes.

        'test_attribute1' exists both as a regular and a static attribute, so
        the static value is expected to win.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        host1.set_attribute('test_attribute1', 'test_value1')
        host1.set_attribute('test_attribute2', 'test_value2')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        self._set_static_attribute(host1, 'static_attribute1', 'static_value2')
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_attributes()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect the value of static attributes.
        self.assertEqual(host['attributes'],
                         {'test_attribute1': 'static_value1',
                          'test_attribute2': 'test_value2',
                          'static_attribute1': 'static_value2'})

    def test_get_host_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        host2.set_attribute('test_attribute2', 'test_value2')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        hosts = [attr['host'] for attr in attributes]
        values = [attr['value'] for attr in attributes]
        self.assertEqual(set(hosts), {'test_host1', 'test_host2'})
        self.assertEqual(set(values), {'test_value1', 'static_value1'})


    def test_get_hosts_by_attribute_without_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEqual(set(hosts), {'test_host1', 'test_host2'})


    def test_get_hosts_by_attribute_with_static(self):
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host1, 'test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')
        self._set_static_attribute(host2, 'test_attribute1', 'static_value1')
        host3 = models.Host.objects.create(hostname='test_host3')
        self._set_static_attribute(host3, 'test_attribute1', 'test_value1')
        host4 = models.Host.objects.create(hostname='test_host4')
        host4.set_attribute('test_attribute1', 'test_value1')
        host5 = models.Host.objects.create(hostname='test_host5')
        host5.set_attribute('test_attribute1', 'temp_value1')
        self._set_static_attribute(host5, 'test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        # host1: matched, it has the same value for test_attribute1.
        # host2: not matched, it has a new value in
        #        afe_static_host_attributes for test_attribute1.
        # host3: matched, it has a corresponding entry in
        #        afe_host_attributes for test_attribute1.
        # host4: matched, test_attribute1 is not replaced by static
        #        attribute.
        # host5: matched, it has an updated & matched value for
        #        test_attribute1 in afe_static_host_attributes.
        self.assertEqual(set(hosts),
                         {'test_host1', 'test_host3',
                          'test_host4', 'test_host5'})
351
352
class RpcInterfaceTestWithStaticLabel(ShardHeartbeatTest,
                                      frontend_test_utils.FrontendTestMixin):
    """RPC and shard-heartbeat tests run with RESPECT_STATIC_LABELS enabled."""

    # Labels that _createShardAndHostWithStaticLabel additionally mirrors as
    # static labels (plus a ReplacedLabel row) on the created host.
    _STATIC_LABELS = ['board:lumpy']

    def setUp(self):
        super(RpcInterfaceTestWithStaticLabel, self).setUp()
        self._frontend_common_setup()
        # Save each module's flag separately so tearDown restores the exact
        # original value of both, even if they differed.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_LABELS
        self.old_models_respect_static_config = models.RESPECT_STATIC_LABELS
        rpc_interface.RESPECT_STATIC_LABELS = True
        models.RESPECT_STATIC_LABELS = True


    def tearDown(self):
        try:
            self._frontend_common_teardown()
            global_config.global_config.reset_config_values()
        finally:
            # Always restore the module-level flags so a teardown failure
            # cannot leak static-label behavior into other tests.
            rpc_interface.RESPECT_STATIC_LABELS = (
                    self.old_respect_static_config)
            models.RESPECT_STATIC_LABELS = (
                    self.old_models_respect_static_config)


    def _fake_host_with_static_labels(self):
        """Create 'test_host' whose platform label is replaced by a static one.

        The host gets a plain label 'non_static_label1', a non-platform
        Label 'static_platform' that is marked as replaced, and a static
        platform label with the same name.
        """
        host1 = models.Host.objects.create(hostname='test_host')
        label1 = models.Label.objects.create(
                name='non_static_label1', platform=False)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.labels.add(label1)
        host1.save()
        return host1


    def test_get_hosts(self):
        host1 = self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect all labels in afe_hosts_labels.
        self.assertEqual(host['labels'],
                         ['non_static_label1', 'static_platform'])
        # Respect static labels.
        self.assertEqual(host['platform'], 'static_platform')


    def test_get_hosts_multiple_labels(self):
        self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(
                multiple_labels=['non_static_label1', 'static_platform'])
        host = hosts[0]
        self.assertEqual(host['hostname'], 'test_host')


    def test_delete_static_label(self):
        label1 = models.Label.smart_get('static')

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        # RetryingAFE is patched out so no real AFE is contacted.
        with patch.object(frontend_wrappers, 'RetryingAFE'):
            self.assertRaises(error.UnmodifiableLabelException,
                              rpc_interface.delete_label, label1.id)


    def test_modify_static_label(self):
        label1 = models.Label.smart_get('static')
        self.assertEqual(label1.invalid, 0)

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        # RetryingAFE is patched out so no real AFE is contacted.
        with patch.object(frontend_wrappers, 'RetryingAFE'):
            self.assertRaises(error.UnmodifiableLabelException,
                              rpc_interface.modify_label,
                              label1.id,
                              invalid=1)

        self.assertEqual(models.Label.smart_get('static').invalid, 0)


    def test_multiple_platforms_add_non_static_to_static(self):
        """Test non-static platform to a host with static platform."""
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        # Second platform label the test tries (and must fail) to add.
        models.Label.objects.create(name='platform2', platform=True)
        host1 = models.Host.objects.create(hostname='test_host')
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_multiple_platforms_add_static_to_non_static(self):
        """Test static platform to a host with non-static platform."""
        platform1 = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=platform1.id)
        # Static counterpart of 'static_platform'; the object is not needed.
        models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        platform2 = models.Label.objects.create(
                name='platform2', platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(platform2)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts,
                          id='static_platform',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['static_platform'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)


    def test_label_remove_hosts(self):
        """Test remove a label of hosts."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.static_labels.add(static_label)
        host1.save()

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.label_remove_hosts,
                          id='static', hosts=['test_host'])


    def test_host_remove_labels(self):
        """Test remove labels of a given host."""
        label = models.Label.smart_get('static')
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.labels.add(label1)
        host1.labels.add(label2)
        host1.static_labels.add(static_label)
        host1.save()

        rpc_interface.host_remove_labels(
                'test_host', ['static', 'label1'])
        labels = rpc_interface.get_labels(host__hostname__in=['test_host'])
        # Only non_static label 'label1' is removed.
        self.assertEqual(len(labels), 2)
        self.assertEqual(labels[0].get('name'), 'label2')


    def test_remove_board_from_shard(self):
        """test remove a board (static label) from shard."""
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        shard = models.Shard.objects.create(hostname='test_shard')
        shard.labels.add(label)

        host = models.Host.objects.create(hostname='test_host',
                                          leased=False,
                                          shard=shard)
        host.static_labels.add(static_label)
        host.save()

        rpc_interface.remove_board_from_shard(shard.hostname, label.name)
        host1 = models.Host.smart_get(host.id)
        shard1 = models.Shard.smart_get(shard.id)
        self.assertEqual(host1.shard, None)
        self.assertCountEqual(shard1.labels.all(), [])


    def test_check_job_dependencies_success(self):
        """Test check_job_dependencies successfully."""
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        rpc_utils.check_job_dependencies([host1], ['static'])


    def test_check_job_dependencies_fail(self):
        """Test check_job_dependencies with raising ValidationError."""
        label = models.Label.smart_get('static')
        # A static label named 'static' exists but is not attached to the
        # host; only the non-static label is.
        models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        self.assertRaises(model_logic.ValidationError,
                          rpc_utils.check_job_dependencies,
                          [host1],
                          ['static'])

    def test_check_job_metahost_dependencies_success(self):
        """Test check_job_metahost_dependencies successfully."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        rpc_utils.check_job_metahost_dependencies(
                [label1, label], [label2.name])
        rpc_utils.check_job_metahost_dependencies(
                [label1], [label2.name, static_label.name])


    def test_check_job_metahost_dependencies_fail(self):
        """Test check_job_metahost_dependencies with raising errors."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        label = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1, label], [label2.name])
        self.assertRaises(error.NoEligibleHostException,
                          rpc_utils.check_job_metahost_dependencies,
                          [label1], [label2.name, static_label.name])


    def _createShardAndHostWithStaticLabel(self,
                                           shard_hostname='shard1',
                                           host_hostname='test_host1',
                                           label_name='board:lumpy'):
        """Create a shard owning a label and a host in that shard.

        If label_name is listed in _STATIC_LABELS, the label is also marked
        as replaced and a same-named static label is added to the host.

        @returns: (shard, host, label) tuple.
        """
        label = models.Label.objects.create(name=label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(hostname=host_hostname, leased=False,
                                          shard=shard)
        host.labels.add(label)
        if label_name in self._STATIC_LABELS:
            models.ReplacedLabel.objects.create(label_id=label.id)
            static_label = models.StaticLabel.objects.create(name=label_name)
            host.static_labels.add(static_label)

        return shard, host, label


    def testShardHeartbeatFetchHostlessJob(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatFetchHostlessJobHelper(host1)


    def testShardHeartbeatIncorrectHosts(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatIncorrectHostsHelper(host1)


    def testShardHeartbeatLabelRemovalRace(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)


    def testShardRetrieveJobs(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        shard2, host2, label2 = self._createShardAndHostWithStaticLabel(
            'shard2', 'test_host2', 'board:grumpy')
        self._testShardRetrieveJobsHelper(shard1, host1, label1,
                                          shard2, host2, label2)


    def testResendJobsAfterFailedHeartbeat(self):
        shard1, host1, label1 = self._createShardAndHostWithStaticLabel()
        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)


    def testResendHostsAfterFailedHeartbeat(self):
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testResendHostsAfterFailedHeartbeatHelper(host1)
676
677
class RpcInterfaceTest(unittest.TestCase,
                       frontend_test_utils.FrontendTestMixin):
    """Tests for the AFE RPC interface against the frontend test fixtures.

    setUp() populates the database through FrontendTestMixin's
    _frontend_common_setup(). The tests below rely on fixture objects such as
    hosts 'host1'/'host2', labels 'label1'/'myplatform' and the ACL 'my_acl'.
    """

    def setUp(self):
        self._frontend_common_setup()


    def tearDown(self):
        self._frontend_common_teardown()
        # Undo global_config overrides made by individual tests (for example
        # the SHARD/shard_hostname override in _modify_host_helper).
        global_config.global_config.reset_config_values()


    def test_validation(self):
        """Invalid RPC arguments surface as ValidationError."""
        # omit a required field
        self.assertRaises(model_logic.ValidationError, rpc_interface.add_label,
                          name=None)
        # violate uniqueness constraint
        self.assertRaises(model_logic.ValidationError, rpc_interface.add_host,
                          hostname='host1')


    def test_multiple_platforms(self):
        """Adding a second platform label to a host raises ValidationError."""
        models.Label.objects.create(name='platform2', platform=True)
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['host1', 'host2'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='host1', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['host1', 'host2'], platform=True)
        self.assertEqual(len(platforms), 1)
        self.assertEqual(platforms[0]['name'], 'myplatform')


    def _check_hostnames(self, hosts, expected_hostnames):
        """Assert the host dicts' hostnames equal expected_hostnames as a set.

        @param hosts: Iterable of host dicts as returned by the RPCs.
        @param expected_hostnames: Iterable of hostname strings.
        """
        self.assertEqual(set(host['hostname'] for host in hosts),
                         set(expected_hostnames))


    def test_ping_db(self):
        self.assertEqual(rpc_interface.ping_db(), [True])


    def test_get_hosts_by_attribute(self):
        """Hosts sharing an attribute value are all returned by hostname."""
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        hosts = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))


    def test_get_host_attribute(self):
        """get_host_attribute returns host/value pairs for matching hosts."""
        host1 = models.Host.objects.create(hostname='test_host1')
        host1.set_attribute('test_attribute1', 'test_value1')
        host2 = models.Host.objects.create(hostname='test_host2')
        host2.set_attribute('test_attribute1', 'test_value1')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        hosts = [attr['host'] for attr in attributes]
        values = [attr['value'] for attr in attributes]
        self.assertEqual(set(hosts),
                         set(['test_host1', 'test_host2']))
        self.assertEqual(set(values), set(['test_value1']))


    def test_get_hosts(self):
        """get_hosts lists all fixture hosts and reports per-host details."""
        hosts = rpc_interface.get_hosts()
        self._check_hostnames(hosts, [host.hostname for host in self.hosts])

        hosts = rpc_interface.get_hosts(hostname='host1')
        self._check_hostnames(hosts, ['host1'])
        host = hosts[0]
        self.assertEqual(sorted(host['labels']), ['label1', 'myplatform'])
        self.assertEqual(host['platform'], 'myplatform')
        self.assertEqual(host['acls'], ['my_acl'])
        self.assertEqual(host['attributes'], {})


    def test_get_hosts_multiple_labels(self):
        """multiple_labels matches only hosts carrying every listed label."""
        hosts = rpc_interface.get_hosts(
                multiple_labels=['myplatform', 'label1'])
        self._check_hostnames(hosts, ['host1'])


    def test_job_keyvals(self):
        """Keyvals passed to create_job are returned by get_jobs."""
        keyval_dict = {'mykey': 'myvalue'}
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          keyvals=keyval_dict)
        jobs = rpc_interface.get_jobs(id=job_id)
        self.assertEqual(len(jobs), 1)
        self.assertEqual(jobs[0]['keyvals'], keyval_dict)


    def test_get_jobs_summary(self):
        """Aborted-and-failed HQEs are counted as Failed in the summary."""
        job = self._create_job(hosts=range(1, 4))
        entries = list(job.hostqueueentry_set.all())
        entries[1].status = _hqe_status.FAILED
        entries[1].save()
        entries[2].status = _hqe_status.FAILED
        entries[2].aborted = True
        entries[2].save()

        # Mock up tko_rpc_interface.get_status_counts.
        with patch.object(rpc_interface.tko_rpc_interface,
                          'get_status_counts',
                          return_value=None):
            job_summaries = rpc_interface.get_jobs_summary(id=job.id)
            self.assertEqual(len(job_summaries), 1)
            summary = job_summaries[0]
            self.assertEqual(summary['status_counts'], {
                    'Queued': 1,
                    'Failed': 2
            })


    def _check_job_ids(self, actual_job_dicts, expected_jobs):
        """Assert the job dicts' ids equal the expected Job objects' ids.

        @param actual_job_dicts: Iterable of job dicts as returned by the RPCs.
        @param expected_jobs: Iterable of models.Job objects.
        """
        self.assertEqual(
                set(job_dict['id'] for job_dict in actual_job_dicts),
                set(job.id for job in expected_jobs))


    def test_get_jobs_status_filters(self):
        """not_yet_run/running/finished partition jobs by their HQE states."""
        HqeStatus = models.HostQueueEntry.Status
        def create_two_host_job():
            return self._create_job(hosts=[1, 2])
        def set_hqe_statuses(job, first_status, second_status):
            entries = job.hostqueueentry_set.all()
            entries[0].update_object(status=first_status)
            entries[1].update_object(status=second_status)

        queued = create_two_host_job()

        queued_and_running = create_two_host_job()
        set_hqe_statuses(queued_and_running, HqeStatus.QUEUED,
                           HqeStatus.RUNNING)

        running_and_complete = create_two_host_job()
        set_hqe_statuses(running_and_complete, HqeStatus.RUNNING,
                           HqeStatus.COMPLETED)

        complete = create_two_host_job()
        set_hqe_statuses(complete, HqeStatus.COMPLETED, HqeStatus.COMPLETED)

        started_but_inactive = create_two_host_job()
        set_hqe_statuses(started_but_inactive, HqeStatus.QUEUED,
                           HqeStatus.COMPLETED)

        parsing = create_two_host_job()
        set_hqe_statuses(parsing, HqeStatus.PARSING, HqeStatus.PARSING)

        self._check_job_ids(rpc_interface.get_jobs(not_yet_run=True), [queued])
        self._check_job_ids(rpc_interface.get_jobs(running=True),
                            [queued_and_running, running_and_complete,
                             started_but_inactive, parsing])
        self._check_job_ids(rpc_interface.get_jobs(finished=True), [complete])


    def test_get_jobs_type_filters(self):
        """suite/sub/standalone filters are exclusive and select by parent."""
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          suite=True, sub=True)
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          suite=True, standalone=True)
        self.assertRaises(AssertionError, rpc_interface.get_jobs,
                          standalone=True, sub=True)

        parent_job = self._create_job(hosts=[1])
        child_job = self._create_job(hosts=[1, 2],
                                     parent_job_id=parent_job.id)
        standalone_job = self._create_job(hosts=[1])

        self._check_job_ids(rpc_interface.get_jobs(suite=True), [parent_job])
        self._check_job_ids(rpc_interface.get_jobs(sub=True), [child_job])
        self._check_job_ids(rpc_interface.get_jobs(standalone=True),
                            [standalone_job])


    def _create_job_helper(self, **kwargs):
        """Create a default-priority SERVER job named 'test' via the RPC.

        @param kwargs: Extra keyword arguments forwarded to create_job.
        @returns The value returned by rpc_interface.create_job.
        """
        return rpc_interface.create_job(name='test',
                                        priority=priorities.Priority.DEFAULT,
                                        control_file='control file',
                                        control_type=SERVER, **kwargs)


    def test_one_time_hosts(self):
        """One-time hosts are created invalid, unlabelled and ACL-free."""
        self._create_job_helper(one_time_hosts=['testhost'])
        host = models.Host.objects.get(hostname='testhost')
        self.assertEqual(host.invalid, True)
        self.assertEqual(host.labels.count(), 0)
        self.assertEqual(host.aclgroup_set.count(), 0)


    def test_create_job_duplicate_hosts(self):
        self.assertRaises(model_logic.ValidationError, self._create_job_helper,
                          hosts=[1, 1])


    def test_create_unrunnable_metahost_job(self):
        self.assertRaises(error.NoEligibleHostException,
                          self._create_job_helper, meta_hosts=['unused'])


    def test_create_hostless_job(self):
        """A hostless job gets one HQE with neither host nor meta_host."""
        job_id = self._create_job_helper(hostless=True)
        job = models.Job.objects.get(pk=job_id)
        queue_entries = job.hostqueueentry_set.all()
        self.assertEqual(len(queue_entries), 1)
        self.assertEqual(queue_entries[0].host, None)
        self.assertEqual(queue_entries[0].meta_host, None)


    def _setup_special_tasks(self):
        """Create two jobs on host1 and three VERIFY tasks around them.

        self.task1 completed before job 1 started, self.task2 is active and
        attached to job 2's queue entry, and self.task3 has not run yet.
        """
        host = self.hosts[0]

        job1 = self._create_job(hosts=[1])
        job2 = self._create_job(hosts=[1])

        entry1 = job1.hostqueueentry_set.all()[0]
        entry1.update_object(started_on=datetime.datetime(2009, 1, 2),
                             execution_subdir='host1')
        entry2 = job2.hostqueueentry_set.all()[0]
        entry2.update_object(started_on=datetime.datetime(2009, 1, 3),
                             execution_subdir='host1')

        self.task1 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                time_started=datetime.datetime(2009, 1, 1), # ran before job 1
                is_complete=True, requested_by=models.User.current_user())
        self.task2 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                queue_entry=entry2, # ran with job 2
                is_active=True, requested_by=models.User.current_user())
        self.task3 = models.SpecialTask.objects.create(
                host=host, task=models.SpecialTask.Task.VERIFY,
                requested_by=models.User.current_user()) # not yet run


    def test_get_special_tasks(self):
        """Filtering on queue_entry__isnull returns the job-less tasks."""
        self._setup_special_tasks()
        tasks = rpc_interface.get_special_tasks(host__hostname='host1',
                                                queue_entry__isnull=True)
        self.assertEqual(len(tasks), 2)
        self.assertEqual(tasks[0]['task'], models.SpecialTask.Task.VERIFY)
        self.assertEqual(tasks[0]['is_active'], False)
        self.assertEqual(tasks[0]['is_complete'], True)


    def test_get_latest_special_task(self):
        # a particular usage of get_special_tasks()
        self._setup_special_tasks()
        self.task2.time_started = datetime.datetime(2009, 1, 2)
        self.task2.save()

        tasks = rpc_interface.get_special_tasks(
                host__hostname='host1', task=models.SpecialTask.Task.VERIFY,
                time_started__isnull=False, sort_by=['-time_started'],
                query_limit=1)
        self.assertEqual(len(tasks), 1)
        self.assertEqual(tasks[0]['id'], 2)


    def _common_entry_check(self, entry_dict):
        """Assert entry_dict refers to host1 and job 2."""
        self.assertEqual(entry_dict['host']['hostname'], 'host1')
        self.assertEqual(entry_dict['job']['id'], 2)


    def test_get_host_queue_entries_and_special_tasks(self):
        """HQEs and special tasks are merged newest-first with exec paths."""
        self._setup_special_tasks()

        host = self.hosts[0].id
        entries_and_tasks = (
                rpc_interface.get_host_queue_entries_and_special_tasks(host))

        paths = [entry['execution_path'] for entry in entries_and_tasks]
        self.assertEqual(paths, ['hosts/host1/3-verify',
                                 '2-autotest_system/host1',
                                 'hosts/host1/2-verify',
                                 '1-autotest_system/host1',
                                 'hosts/host1/1-verify'])

        verify2 = entries_and_tasks[2]
        self._common_entry_check(verify2)
        self.assertEqual(verify2['type'], 'Verify')
        self.assertEqual(verify2['status'], 'Running')
        self.assertEqual(verify2['execution_path'], 'hosts/host1/2-verify')

        entry2 = entries_and_tasks[1]
        self._common_entry_check(entry2)
        self.assertEqual(entry2['type'], 'Job')
        self.assertEqual(entry2['status'], 'Queued')
        self.assertEqual(entry2['started_on'], '2009-01-03 00:00:00')


    def _create_hqes_and_start_time_index_entries(self):
        """Create HQEs 1-3 and one start-time index row per HQE.

        The HQEs belong to one sharded job and are started on 2017-01-01,
        -02 and -03. Each HostQueueEntryStartTimes row maps an insert_time
        to the highest HQE id inserted by that time.
        """
        shard = models.Shard.objects.create(hostname='shard')
        job = self._create_job(shard=shard, control_file='foo')
        HqeStatus = models.HostQueueEntry.Status

        models.HostQueueEntry(
            id=1, job=job, started_on='2017-01-01',
            status=HqeStatus.QUEUED).save()
        models.HostQueueEntry(
            id=2, job=job, started_on='2017-01-02',
            status=HqeStatus.QUEUED).save()
        models.HostQueueEntry(
            id=3, job=job, started_on='2017-01-03',
            status=HqeStatus.QUEUED).save()

        models.HostQueueEntryStartTimes(
            insert_time='2017-01-03', highest_hqe_id=3).save()
        models.HostQueueEntryStartTimes(
            insert_time='2017-01-02', highest_hqe_id=2).save()
        models.HostQueueEntryStartTimes(
            insert_time='2017-01-01', highest_hqe_id=1).save()


    def test_get_host_queue_entries_by_insert_time(self):
        """Check the insert_time_after and insert_time_before constraints."""
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-01')
        self.assertEqual(len(hqes), 3)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-02')
        self.assertEqual(len(hqes), 2)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_after='2017-01-03')
        self.assertEqual(len(hqes), 1)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-01')
        self.assertEqual(len(hqes), 1)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02')
        self.assertEqual(len(hqes), 2)

        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-03')
        self.assertEqual(len(hqes), 3)


    def test_get_host_queue_entries_by_insert_time_with_missing_index_row(self):
        """Shows that the constraints are approximate.

        The query may return rows which are actually outside of the bounds
        given, if the index table does not have an entry for the specific time.
        """
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2016-12-01')
        self.assertEqual(len(hqes), 1)


    def test_get_hqe_by_insert_time_with_before_and_after(self):
        self._create_hqes_and_start_time_index_entries()
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02',
            insert_time_after='2017-01-02')
        self.assertEqual(len(hqes), 1)


    def test_get_hqe_by_insert_time_and_id_constraint(self):
        self._create_hqes_and_start_time_index_entries()
        # The time constraint is looser than the id constraint, so the time
        # constraint should take precedence.
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-02',
            id__lte=1)
        self.assertEqual(len(hqes), 1)

        # Now make the time constraint tighter than the id constraint.
        hqes = rpc_interface.get_host_queue_entries_by_insert_time(
            insert_time_before='2017-01-01',
            id__lte=42)
        self.assertEqual(len(hqes), 1)


    def test_view_invalid_host(self):
        # RPCs used by View Host page should work for invalid hosts
        self._create_job_helper(hosts=[1])
        host = self.hosts[0]
        host.delete()

        self.assertEqual(1, rpc_interface.get_num_hosts(hostname='host1',
                                                        valid_only=False))
        data = rpc_interface.get_hosts(hostname='host1', valid_only=False)
        self.assertEqual(1, len(data))

        self.assertEqual(1, rpc_interface.get_num_host_queue_entries(
                host__hostname='host1'))
        data = rpc_interface.get_host_queue_entries(host__hostname='host1')
        self.assertEqual(1, len(data))

        count = rpc_interface.get_num_host_queue_entries_and_special_tasks(
                host=host.id)
        self.assertEqual(1, count)
        data = rpc_interface.get_host_queue_entries_and_special_tasks(
                host=host.id)
        self.assertEqual(1, len(data))


    def test_reverify_hosts(self):
        """reverify_hosts queues one VERIFY task per selected host."""
        hostname_list = rpc_interface.reverify_hosts(id__in=[1, 2])
        self.assertEqual(hostname_list, ['host1', 'host2'])
        tasks = rpc_interface.get_special_tasks()
        self.assertEqual(len(tasks), 2)
        self.assertEqual(set(task['host']['id'] for task in tasks),
                         set([1, 2]))

        task = tasks[0]
        self.assertEqual(task['task'], models.SpecialTask.Task.VERIFY)
        self.assertEqual(task['requested_by'], 'autotest_system')


    def test_repair_hosts(self):
        """repair_hosts queues one REPAIR task per selected host."""
        hostname_list = rpc_interface.repair_hosts(id__in=[1, 2])
        self.assertEqual(hostname_list, ['host1', 'host2'])
        tasks = rpc_interface.get_special_tasks()
        self.assertEqual(len(tasks), 2)
        self.assertEqual(set(task['host']['id'] for task in tasks),
                         set([1, 2]))

        task = tasks[0]
        self.assertEqual(task['task'], models.SpecialTask.Task.REPAIR)
        self.assertEqual(task['requested_by'], 'autotest_system')


    def _modify_host_helper(self, on_shard=False, host_on_shard=False):
        """Lock the first host via modify_host and check RPC forwarding.

        frontend_wrappers.RetryingAFE is patched, so no real RPC is made.

        @param on_shard: If True, set the SHARD/shard_hostname config value to
                'shard1' so modify_host runs as it would on that shard.
        @param host_on_shard: If True, assign the host to shard 'shard1'.
        """
        shard_hostname = 'shard1'
        if on_shard:
            global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', shard_hostname)

        host = models.Host.objects.all()[0]
        if host_on_shard:
            shard = models.Shard.objects.create(hostname=shard_hostname)
            host.shard = shard
            host.save()

        self.assertFalse(host.locked)

        afe_instance = MagicMock()
        with patch.object(frontend_wrappers, 'RetryingAFE',
                          return_value=afe_instance) as mock_afe:
            rpc_interface.modify_host(id=host.id,
                                      locked=True,
                                      lock_reason='_modify_host_helper lock',
                                      lock_time=datetime.datetime(
                                              2015, 12, 15))

            host = models.Host.objects.get(pk=host.id)
            if on_shard:
                # modify_host on shard does nothing but routing the RPC to
                # main.
                self.assertFalse(host.locked)
            else:
                self.assertTrue(host.locked)
            if host_on_shard and not on_shard:
                mock_afe.assert_called_with(server=shard_hostname, user=None)
                afe_instance.run.assert_called_with(
                        'modify_host_local',
                        id=host.id,
                        locked=True,
                        lock_reason='_modify_host_helper lock',
                        lock_time=datetime.datetime(2015, 12, 15))
            elif on_shard:
                mock_afe.assert_called_with(
                        server=server_utils.get_global_afe_hostname(),
                        user=None)
                afe_instance.run.assert_called_with(
                        'modify_host',
                        id=host.id,
                        locked=True,
                        lock_reason='_modify_host_helper lock',
                        lock_time=datetime.datetime(2015, 12, 15))


    def test_modify_host_on_main_host_on_main(self):
        """Call modify_host to main for host in main."""
        self._modify_host_helper()


    def test_modify_host_on_main_host_on_shard(self):
        """Call modify_host to main for host in shard."""
        self._modify_host_helper(host_on_shard=True)


    def test_modify_host_on_shard(self):
        """Call modify_host to shard for host in shard."""
        self._modify_host_helper(on_shard=True, host_on_shard=True)


    def test_modify_hosts_on_main_host_on_shard(self):
        """Ensure calls to modify_hosts are correctly forwarded to shards."""
        host1 = models.Host.objects.all()[0]
        host2 = models.Host.objects.all()[1]

        shard1 = models.Shard.objects.create(hostname='shard1')
        host1.shard = shard1
        host1.save()

        shard2 = models.Shard.objects.create(hostname='shard2')
        host2.shard = shard2
        host2.save()

        self.assertFalse(host1.locked)
        self.assertFalse(host2.locked)

        def make_update_data():
            # Built fresh on every call so the expected value can never be
            # aliased to (or mutated through) the dict handed to the RPC.
            return {
                    'locked': True,
                    'lock_reason': 'Testing forward to shard',
                    'lock_time': datetime.datetime(2015, 12, 15),
            }

        mock_afe1 = MagicMock()
        mock_afe2 = MagicMock()

        def afe_for_server(server='', user=None):
            # Return a distinct fake AFE per shard so each forwarded call can
            # be checked independently.
            return mock_afe1 if server == 'shard1' else mock_afe2

        with patch.object(frontend_wrappers, 'RetryingAFE',
                          side_effect=afe_for_server):
            # The statuses of one host might differ on main and shard.
            # Filters are always applied on the main. So the host on the shard
            # will be affected no matter what the host status is.
            filters_to_use = {'status': 'Ready'}

            rpc_interface.modify_hosts(host_filter_data=filters_to_use,
                                       update_data=make_update_data())

            host1 = models.Host.objects.get(pk=host1.id)
            self.assertTrue(host1.locked)
            host2 = models.Host.objects.get(pk=host2.id)
            self.assertTrue(host2.locked)
            # The forwarded filter selects the affected *hosts* by id (not
            # the shards), so both shards receive the same host id list.
            expected_filter = {'id__in': [host1.id, host2.id]}
            mock_afe1.run.assert_called_with(
                    'modify_hosts_local',
                    host_filter_data=expected_filter,
                    update_data=make_update_data())
            mock_afe2.run.assert_called_with(
                    'modify_hosts_local',
                    host_filter_data=expected_filter,
                    update_data=make_update_data())


    def test_delete_host(self):
        """Ensure an RPC is made on delete a host, if it is on a shard."""
        host1 = models.Host.objects.all()[0]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host1.shard = shard1
        host1.save()
        host1_id = host1.id

        mock_afe1 = MagicMock()
        with patch.object(frontend_wrappers, 'RetryingAFE',
                          return_value=mock_afe1) as mock_afe:
            rpc_interface.delete_host(id=host1.id)

            self.assertRaises(models.Host.DoesNotExist, models.Host.smart_get,
                              host1_id)

            mock_afe.assert_called_with(server='shard1', user=None)
            mock_afe1.run.assert_called_with('delete_host', id=host1.id)


    def test_delete_shard(self):
        """Ensure the RPC can delete a shard."""
        host1 = models.Host.objects.all()[0]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host1.shard = shard1
        host1.save()

        rpc_interface.delete_shard(hostname=shard1.hostname)

        host1 = models.Host.smart_get(host1.id)
        self.assertIsNone(host1.shard)
        self.assertRaises(models.Shard.DoesNotExist,
                          models.Shard.smart_get, shard1.hostname)


    def test_modify_label(self):
        """modify_label updates locally and forwards to affected shards."""
        label1 = models.Label.objects.all()[0]
        self.assertEqual(label1.invalid, 0)

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe1 = MagicMock()
        with patch.object(frontend_wrappers, 'RetryingAFE',
                          return_value=mock_afe1) as mock_afe:
            rpc_interface.modify_label(label1.id, invalid=1)

            self.assertEqual(models.Label.objects.all()[0].invalid, 1)
            mock_afe.assert_called_with(server='shard1', user=None)
            mock_afe1.run.assert_called_with('modify_label',
                                             id=label1.id,
                                             invalid=1)


    def test_delete_label(self):
        """delete_label deletes locally and forwards to affected shards."""
        label1 = models.Label.objects.all()[0]

        host2 = models.Host.objects.all()[1]
        shard1 = models.Shard.objects.create(hostname='shard1')
        host2.shard = shard1
        host2.labels.add(label1)
        host2.save()

        mock_afe1 = MagicMock()
        with patch.object(frontend_wrappers, 'RetryingAFE',
                          return_value=mock_afe1) as mock_afe:
            rpc_interface.delete_label(id=label1.id)

            self.assertRaises(models.Label.DoesNotExist,
                              models.Label.smart_get, label1.id)
            mock_afe.assert_called_with(server='shard1', user=None)
            mock_afe1.run.assert_called_with('delete_label', id=label1.id)


    def test_get_image_for_job_with_keyval_build(self):
        """The image is taken from the job's 'build' keyval."""
        keyval_dict = {'build': 'cool-image'}
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          keyvals=keyval_dict)
        job = models.Job.objects.get(id=job_id)
        self.assertIsNotNone(job)
        image = rpc_interface._get_image_for_job(job, True)
        self.assertEqual('cool-image', image)


    def test_get_image_for_job_with_keyval_builds(self):
        """The image is taken from the 'builds' keyval's cros-version."""
        keyval_dict = {'builds': {'cros-version': 'cool-image'}}
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'],
                                          keyvals=keyval_dict)
        job = models.Job.objects.get(id=job_id)
        self.assertIsNotNone(job)
        image = rpc_interface._get_image_for_job(job, True)
        self.assertEqual('cool-image', image)


    def test_get_image_for_job_with_control_build(self):
        """The image is parsed from a `build=` line in the control file."""
        CONTROL_FILE = """build='cool-image'
        """
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'])
        job = models.Job.objects.get(id=job_id)
        self.assertIsNotNone(job)
        job.control_file = CONTROL_FILE
        image = rpc_interface._get_image_for_job(job, True)
        self.assertEqual('cool-image', image)


    def test_get_image_for_job_with_control_builds(self):
        """The image is parsed from a `builds=` dict in the control file."""
        CONTROL_FILE = """builds={'cros-version': 'cool-image'}
        """
        job_id = rpc_interface.create_job(name='test',
                                          priority=priorities.Priority.DEFAULT,
                                          control_file='foo',
                                          control_type=CLIENT,
                                          hosts=['host1'])
        job = models.Job.objects.get(id=job_id)
        self.assertIsNotNone(job)
        job.control_file = CONTROL_FILE
        image = rpc_interface._get_image_for_job(job, True)
        self.assertEqual('cool-image', image)
1370
1371
1372class ExtraRpcInterfaceTest(frontend_test_utils.FrontendTestMixin,
1373                            ShardHeartbeatTest):
1374    """Unit tests for functions originally in site_rpc_interface.py.
1375
1376    @var _NAME: fake suite name.
1377    @var _BOARD: fake board to reimage.
1378    @var _BUILD: fake build with which to reimage.
1379    @var _PRIORITY: fake priority with which to reimage.
1380    """
1381    _NAME = 'name'
1382    _BOARD = 'link'
1383    _BUILD = 'link-release/R36-5812.0.0'
1384    _BUILDS = {provision.CROS_VERSION_PREFIX: _BUILD}
1385    _PRIORITY = priorities.Priority.DEFAULT
1386    _TIMEOUT = 24
1387
1388
1389    def setUp(self):
1390        super(ExtraRpcInterfaceTest, self).setUp()
1391        self._SUITE_NAME = suite_common.canonicalize_suite_name(
1392            self._NAME)
1393        patcher = patch.object(dev_server, 'ImageServer')
1394        self.dev_server = patcher.start()
1395        self.addCleanup(patcher.stop)
1396        self._frontend_common_setup(fill_data=False)
1397
1398
1399    def tearDown(self):
1400        self._frontend_common_teardown()
1401        if self.dev_server.resolve.call_count > 0:
1402            self.dev_server.resolve.assert_called_with(self._BUILD,
1403                                                       None,
1404                                                       ban_list=None)
1405
1406    def _setupDevserver(self):
1407        self.dev_server.resolve.return_value = self.dev_server
1408
1409
1410    def _mockDevServerGetter(self, get_control_file=True):
1411        self._setupDevserver()
1412        if get_control_file:
1413            patcher = patch.object(control_file_getter, 'DevServerGetter')
1414            self.getter = patcher.start()
1415            self.getter.create.return_value = self.getter
1416            self.addCleanup(patcher.stop)
1417
1418
1419    def _mockRpcUtils(self, to_return, control_file_substring=''):
1420        """Fake out the autotest rpc_utils module with a mockable class.
1421
1422        @param to_return: the value that rpc_utils.create_job_common() should
1423                          be mocked out to return.
1424        @param control_file_substring: A substring that is expected to appear
1425                                       in the control file output string that
1426                                       is passed to create_job_common.
1427                                       Default: ''
1428        """
1429        download_started_time = constants.DOWNLOAD_STARTED_TIME
1430        payload_finished_time = constants.PAYLOAD_FINISHED_TIME
1431        patcher = patch.object(rpc_utils,
1432                               'create_job_common',
1433                               return_value=to_return)
1434        self.rpc_utils = patcher.start()
1435        self.addCleanup(patcher.stop)
1436
1437    def testStageBuildFail(self):
1438        """Ensure that a failure to stage the desired build fails the RPC."""
1439        self._setupDevserver()
1440
1441        self.dev_server.hostname = 'mox_url'
1442        self.dev_server.stage_artifacts.side_effect = (
1443                dev_server.DevServerException())
1444        self.assertRaises(error.StageControlFileFailure,
1445                          rpc_interface.create_suite_job,
1446                          name=self._NAME,
1447                          board=self._BOARD,
1448                          builds=self._BUILDS,
1449                          pool=None)
1450        self.dev_server.stage_artifacts.assert_called_with(
1451                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1452
1453
1454    def testGetControlFileFail(self):
1455        """Ensure that a failure to get needed control file fails the RPC."""
1456        self._mockDevServerGetter()
1457
1458        self.dev_server.hostname = 'mox_url'
1459        self.dev_server.stage_artifacts.return_value = True
1460        self.getter.get_control_file_contents_by_name.return_value = None
1461
1462        self.assertRaises(error.ControlFileEmpty,
1463                          rpc_interface.create_suite_job,
1464                          name=self._NAME,
1465                          board=self._BOARD,
1466                          builds=self._BUILDS,
1467                          pool=None)
1468        self.dev_server.stage_artifacts.assert_called_with(
1469                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1470        self.getter.get_control_file_contents_by_name.assert_called_with(
1471                self._SUITE_NAME)
1472
1473
1474    def testGetControlFileListFail(self):
1475        """Ensure that a failure to get needed control file fails the RPC."""
1476        self._mockDevServerGetter()
1477
1478        self.dev_server.hostname = 'mox_url'
1479        self.dev_server.stage_artifacts.return_value = True
1480        self.getter.get_control_file_contents_by_name.side_effect = (
1481                error.NoControlFileList())
1482
1483        self.assertRaises(error.NoControlFileList,
1484                          rpc_interface.create_suite_job,
1485                          name=self._NAME,
1486                          board=self._BOARD,
1487                          builds=self._BUILDS,
1488                          pool=None)
1489        self.dev_server.stage_artifacts.assert_called_with(
1490                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1491        self.getter.get_control_file_contents_by_name.assert_called_with(
1492                self._SUITE_NAME)
1493
1494
1495    def testCreateSuiteJobFail(self):
1496        """Ensure that failure to schedule the suite job fails the RPC."""
1497        self._mockDevServerGetter()
1498
1499        self.dev_server.hostname = 'mox_url'
1500        self.dev_server.stage_artifacts.return_value = True
1501        self.getter.get_control_file_contents_by_name.return_value = 'f'
1502        self.dev_server.url.return_value = 'mox_url'
1503        self._mockRpcUtils(-1)
1504
1505        self.assertEquals(
1506            rpc_interface.create_suite_job(name=self._NAME,
1507                                           board=self._BOARD,
1508                                           builds=self._BUILDS, pool=None),
1509            -1)
1510        self.dev_server.stage_artifacts.assert_called_with(
1511                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1512        self.getter.get_control_file_contents_by_name.assert_called_with(
1513                self._SUITE_NAME)
1514
1515
1516    def testCreateSuiteJobSuccess(self):
1517        """Ensures that success results in a successful RPC."""
1518        self._mockDevServerGetter()
1519
1520        self.dev_server.hostname = 'mox_url'
1521        self.dev_server.stage_artifacts.return_value = True
1522        self.getter.get_control_file_contents_by_name.return_value = 'f'
1523        self.dev_server.url.return_value = 'mox_url'
1524        job_id = 5
1525        self._mockRpcUtils(job_id)
1526
1527        self.assertEquals(
1528            rpc_interface.create_suite_job(name=self._NAME,
1529                                           board=self._BOARD,
1530                                           builds=self._BUILDS,
1531                                           pool=None),
1532            job_id)
1533        self.dev_server.stage_artifacts.assert_called_with(
1534                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1535        self.getter.get_control_file_contents_by_name.assert_called_with(
1536                self._SUITE_NAME)
1537
1538
1539    def testCreateSuiteJobNoHostCheckSuccess(self):
1540        """Ensures that success results in a successful RPC."""
1541        self._mockDevServerGetter()
1542
1543        self.dev_server.hostname = 'mox_url'
1544        self.dev_server.stage_artifacts.return_value = True
1545        self.getter.get_control_file_contents_by_name.return_value = 'f'
1546        self.dev_server.url.return_value = 'mox_url'
1547        job_id = 5
1548        self._mockRpcUtils(job_id)
1549
1550        self.assertEquals(
1551                rpc_interface.create_suite_job(name=self._NAME,
1552                                               board=self._BOARD,
1553                                               builds=self._BUILDS,
1554                                               pool=None,
1555                                               check_hosts=False), job_id)
1556        self.getter.get_control_file_contents_by_name.assert_called_with(
1557                self._SUITE_NAME)
1558        self.dev_server.stage_artifacts.assert_called_with(
1559                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1560
1561
1562    def testCreateSuiteJobControlFileSupplied(self):
1563        """Ensure we can supply the control file to create_suite_job."""
1564        self._mockDevServerGetter(get_control_file=False)
1565
1566        self.dev_server.hostname = 'mox_url'
1567        self.dev_server.stage_artifacts.return_value = True
1568        self.dev_server.url.return_value = 'mox_url'
1569        job_id = 5
1570        self._mockRpcUtils(job_id)
1571        self.assertEquals(
1572            rpc_interface.create_suite_job(name='%s/%s' % (self._NAME,
1573                                                           self._BUILD),
1574                                           board=None,
1575                                           builds=self._BUILDS,
1576                                           pool=None,
1577                                           control_file='CONTROL FILE'),
1578            job_id)
1579        self.dev_server.stage_artifacts.assert_called_with(
1580                image=self._BUILD, artifacts=['test_suites', 'control_files'])
1581
1582
1583    def _get_records_for_sending_to_main(self):
1584        return [{'control_file': 'foo',
1585                 'control_type': 1,
1586                 'created_on': datetime.datetime(2014, 8, 21),
1587                 'drone_set': None,
1588                 'email_list': '',
1589                 'max_runtime_hrs': 72,
1590                 'max_runtime_mins': 1440,
1591                 'name': 'stub',
1592                 'owner': 'autotest_system',
1593                 'parse_failed_repair': True,
1594                 'priority': 40,
1595                 'reboot_after': 0,
1596                 'reboot_before': 1,
1597                 'run_reset': True,
1598                 'run_verify': False,
1599                 'synch_count': 0,
1600                 'test_retry': 0,
1601                 'timeout': 24,
1602                 'timeout_mins': 1440,
1603                 'id': 1
1604                 }], [{
1605                    'aborted': False,
1606                    'active': False,
1607                    'complete': False,
1608                    'deleted': False,
1609                    'execution_subdir': '',
1610                    'finished_on': None,
1611                    'started_on': None,
1612                    'status': 'Queued',
1613                    'id': 1
1614                }]
1615
1616
1617    def _send_records_to_main_helper(
1618        self, jobs, hqes, shard_hostname='host1',
1619        exception_to_throw=error.UnallowedRecordsSentToMain, aborted=False):
1620        job_id = rpc_interface.create_job(
1621                name='stub',
1622                priority=self._PRIORITY,
1623                control_file='foo',
1624                control_type=SERVER,
1625                hostless=True)
1626        job = models.Job.objects.get(pk=job_id)
1627        shard = models.Shard.objects.create(hostname='host1')
1628        job.shard = shard
1629        job.save()
1630
1631        if aborted:
1632            job.hostqueueentry_set.update(aborted=True)
1633            job.shard = None
1634            job.save()
1635
1636        hqe = job.hostqueueentry_set.all()[0]
1637        if not exception_to_throw:
1638            self._do_heartbeat_and_assert_response(
1639                shard_hostname=shard_hostname,
1640                upload_jobs=jobs, upload_hqes=hqes)
1641        else:
1642            self.assertRaises(
1643                exception_to_throw,
1644                self._do_heartbeat_and_assert_response,
1645                shard_hostname=shard_hostname,
1646                upload_jobs=jobs, upload_hqes=hqes)
1647
1648
1649    def testSendingRecordsToMain(self):
1650        """Send records to the main and ensure they are persisted."""
1651        jobs, hqes = self._get_records_for_sending_to_main()
1652        hqes[0]['status'] = 'Completed'
1653        self._send_records_to_main_helper(
1654            jobs=jobs, hqes=hqes, exception_to_throw=None)
1655
1656        # Check the entry was actually written to db
1657        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
1658                         'Completed')
1659
1660
1661    def testSendingRecordsToMainAbortedOnMain(self):
1662        """Send records to the main and ensure they are persisted."""
1663        jobs, hqes = self._get_records_for_sending_to_main()
1664        hqes[0]['status'] = 'Completed'
1665        self._send_records_to_main_helper(
1666            jobs=jobs, hqes=hqes, exception_to_throw=None, aborted=True)
1667
1668        # Check the entry was actually written to db
1669        self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
1670                         'Completed')
1671
1672
1673    def testSendingRecordsToMainJobAssignedToDifferentShard(self):
1674        """Ensure records belonging to different shard are silently rejected."""
1675        shard1 = models.Shard.objects.create(hostname='shard1')
1676        shard2 = models.Shard.objects.create(hostname='shard2')
1677        job1 = self._create_job(shard=shard1, control_file='foo1')
1678        job2 = self._create_job(shard=shard2, control_file='foo2')
1679        job1_id = job1.id
1680        job2_id = job2.id
1681        hqe1 = models.HostQueueEntry.objects.create(job=job1)
1682        hqe2 = models.HostQueueEntry.objects.create(job=job2)
1683        hqe1_id = hqe1.id
1684        hqe2_id = hqe2.id
1685        job1_record = job1.serialize(include_dependencies=False)
1686        job2_record = job2.serialize(include_dependencies=False)
1687        hqe1_record = hqe1.serialize(include_dependencies=False)
1688        hqe2_record = hqe2.serialize(include_dependencies=False)
1689
1690        # Prepare a bogus job record update from the wrong shard. The update
1691        # should not throw an exception. Non-bogus jobs in the same update
1692        # should happily update.
1693        job1_record.update({'control_file': 'bar1'})
1694        job2_record.update({'control_file': 'bar2'})
1695        hqe1_record.update({'status': 'Aborted'})
1696        hqe2_record.update({'status': 'Aborted'})
1697        self._do_heartbeat_and_assert_response(
1698            shard_hostname='shard2', upload_jobs=[job1_record, job2_record],
1699            upload_hqes=[hqe1_record, hqe2_record])
1700
1701        # Job and HQE record for wrong job should not be modified, because the
1702        # rpc came from the wrong shard. Job and HQE record for valid job are
1703        # modified.
1704        self.assertEqual(models.Job.objects.get(id=job1_id).control_file,
1705                         'foo1')
1706        self.assertEqual(models.Job.objects.get(id=job2_id).control_file,
1707                         'bar2')
1708        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe1_id).status,
1709                         '')
1710        self.assertEqual(models.HostQueueEntry.objects.get(id=hqe2_id).status,
1711                         'Aborted')
1712
1713
1714    def testSendingRecordsToMainNotExistingJob(self):
1715        """Ensure update for non existing job gets rejected."""
1716        jobs, hqes = self._get_records_for_sending_to_main()
1717        jobs[0]['id'] = 3
1718
1719        self._send_records_to_main_helper(
1720            jobs=jobs, hqes=hqes)
1721
1722
1723    def _createShardAndHostWithLabel(self, shard_hostname='shard1',
1724                                     host_hostname='host1',
1725                                     label_name='board:lumpy'):
1726        """Create a label, host, shard, and assign host to shard."""
1727        try:
1728            label = models.Label.objects.create(name=label_name)
1729        except:
1730            label = models.Label.smart_get(label_name)
1731
1732        shard = models.Shard.objects.create(hostname=shard_hostname)
1733        shard.labels.add(label)
1734
1735        host = models.Host.objects.create(hostname=host_hostname, leased=False,
1736                                          shard=shard)
1737        host.labels.add(label)
1738
1739        return shard, host, label
1740
1741
1742    def testShardLabelRemovalInvalid(self):
1743        """Ensure you cannot remove the wrong label from shard."""
1744        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1745        stumpy_label = models.Label.objects.create(
1746                name='board:stumpy', platform=True)
1747        with self.assertRaises(error.RPCException):
1748            rpc_interface.remove_board_from_shard(
1749                    shard1.hostname, stumpy_label.name)
1750
1751
1752    def testShardHeartbeatLabelRemoval(self):
1753        """Ensure label removal from shard works."""
1754        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1755
1756        self.assertEqual(host1.shard, shard1)
1757        six.assertCountEqual(self, shard1.labels.all(), [lumpy_label])
1758        rpc_interface.remove_board_from_shard(
1759                shard1.hostname, lumpy_label.name)
1760        host1 = models.Host.smart_get(host1.id)
1761        shard1 = models.Shard.smart_get(shard1.id)
1762        self.assertEqual(host1.shard, None)
1763        six.assertCountEqual(self, shard1.labels.all(), [])
1764
1765    def testCreateListShard(self):
1766        """Retrieve a list of all shards."""
1767        lumpy_label = models.Label.objects.create(name='board:lumpy',
1768                                                  platform=True)
1769        stumpy_label = models.Label.objects.create(name='board:stumpy',
1770                                                  platform=True)
1771        peppy_label = models.Label.objects.create(name='board:peppy',
1772                                                  platform=True)
1773
1774        shard_id = rpc_interface.add_shard(
1775            hostname='host1', labels='board:lumpy,board:stumpy')
1776        self.assertRaises(error.RPCException,
1777                          rpc_interface.add_shard,
1778                          hostname='host1', labels='board:lumpy,board:stumpy')
1779        self.assertRaises(model_logic.ValidationError,
1780                          rpc_interface.add_shard,
1781                          hostname='host1', labels='board:peppy')
1782        shard = models.Shard.objects.get(pk=shard_id)
1783        self.assertEqual(shard.hostname, 'host1')
1784        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
1785        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
1786
1787        self.assertEqual(rpc_interface.get_shards(),
1788                         [{'labels': ['board:lumpy','board:stumpy'],
1789                           'hostname': 'host1',
1790                           'id': 1}])
1791
1792
1793    def testAddBoardsToShard(self):
1794        """Add boards to a given shard."""
1795        shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
1796        stumpy_label = models.Label.objects.create(name='board:stumpy',
1797                                                   platform=True)
1798        shard_id = rpc_interface.add_board_to_shard(
1799            hostname='shard1', labels='board:stumpy')
1800        # Test whether raise exception when board label does not exist.
1801        self.assertRaises(models.Label.DoesNotExist,
1802                          rpc_interface.add_board_to_shard,
1803                          hostname='shard1', labels='board:test')
1804        # Test whether raise exception when board already sharded.
1805        self.assertRaises(error.RPCException,
1806                          rpc_interface.add_board_to_shard,
1807                          hostname='shard1', labels='board:lumpy')
1808        shard = models.Shard.objects.get(pk=shard_id)
1809        self.assertEqual(shard.hostname, 'shard1')
1810        self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
1811        self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
1812
1813        self.assertEqual(rpc_interface.get_shards(),
1814                         [{'labels': ['board:lumpy','board:stumpy'],
1815                           'hostname': 'shard1',
1816                           'id': 1}])
1817
1818
1819    def testShardHeartbeatFetchHostlessJob(self):
1820        shard1, host1, label1 = self._createShardAndHostWithLabel()
1821        self._testShardHeartbeatFetchHostlessJobHelper(host1)
1822
1823
1824    def testShardHeartbeatIncorrectHosts(self):
1825        shard1, host1, label1 = self._createShardAndHostWithLabel()
1826        self._testShardHeartbeatIncorrectHostsHelper(host1)
1827
1828
1829    def testShardHeartbeatLabelRemovalRace(self):
1830        shard1, host1, label1 = self._createShardAndHostWithLabel()
1831        self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)
1832
1833
1834    def testShardRetrieveJobs(self):
1835        shard1, host1, label1 = self._createShardAndHostWithLabel()
1836        shard2, host2, label2 = self._createShardAndHostWithLabel(
1837                'shard2', 'host2', 'board:grumpy')
1838        self._testShardRetrieveJobsHelper(shard1, host1, label1,
1839                                          shard2, host2, label2)
1840
1841
1842    def testResendJobsAfterFailedHeartbeat(self):
1843        shard1, host1, label1 = self._createShardAndHostWithLabel()
1844        self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)
1845
1846
1847    def testResendHostsAfterFailedHeartbeat(self):
1848        shard1, host1, label1 = self._createShardAndHostWithLabel()
1849        self._testResendHostsAfterFailedHeartbeatHelper(host1)
1850
1851
1852if __name__ == '__main__':
1853    unittest.main()
1854