Skip to content

Commit

Permalink
Add full coverage for the scheduler/rpc/server.py
Browse files Browse the repository at this point in the history
module
  • Loading branch information
mihaelabalutoiu committed Feb 28, 2024
1 parent 9cdeaa2 commit 6246684
Showing 1 changed file with 122 additions and 2 deletions.
124 changes: 122 additions & 2 deletions coriolis/tests/scheduler/rpc/test_server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# Copyright 2023 Cloudbase Solutions Srl
# Copyright 2024 Cloudbase Solutions Srl
# All Rights Reserved.

import logging
from unittest import mock

import ddt

from coriolis import constants
from coriolis.db import api as db_api
from coriolis import exception
from coriolis.scheduler.filters import trivial_filters
from coriolis.scheduler.rpc import server
from coriolis.tests import test_base
from coriolis.tests import testutils
from coriolis import utils


@ddt.ddt
Expand All @@ -20,6 +24,122 @@ def setUp(self):
super(SchedulerServerEndpointTestCase, self).setUp()
self.server = server.SchedulerServerEndpoint()

@mock.patch.object(utils, "get_diagnostics_info")
def test_get_diagnostics(self, mock_get_diagnostics_info):
result = self.server.get_diagnostics(mock.sentinel.context)

mock_get_diagnostics_info.assert_called_once_with()
self.assertEqual(result, mock_get_diagnostics_info.return_value)

@mock.patch.object(trivial_filters, 'TopicFilter', autospec=True)
@mock.patch.object(db_api, 'get_services')
def test_get_all_worker_services(self, mock_get_services,
mock_topic_filter_cls):
mock_get_services.return_value = mock.sentinel.services

mock_topic_filter_cls.return_value.filter_services.return_value = \
mock.sentinel.filtered_services

result = self.server._get_all_worker_services(mock.sentinel.context)

mock_get_services.assert_called_once_with(mock.sentinel.context)
mock_topic_filter_cls.assert_called_once_with(
constants.WORKER_MAIN_MESSAGING_TOPIC)
mock_topic_filter_cls.return_value.filter_services.\
assert_called_once_with(mock.sentinel.services)

self.assertEqual(result, mock.sentinel.filtered_services)

@mock.patch.object(db_api, 'get_services')
def test_get_all_worker_services_no_services(self, mock_get_services):
mock_get_services.return_value = []

self.assertRaises(exception.NoWorkerServiceError,
self.server._get_all_worker_services,
mock.sentinel.context)

mock_get_services.assert_called_once_with(mock.sentinel.context)

def test_get_weighted_filtered_services_no_filters(self):
services = [mock.Mock(id=1), mock.Mock(id=2)]

with self.assertLogs('coriolis.scheduler.rpc.server',
level=logging.WARN):
result = self.server._get_weighted_filtered_services(services,
None)
expected_result = [(services[0], 100), (services[1], 100)]
self.assertEqual(result, expected_result)

def test_get_weighted_filtered_services_with_filters_reject(self):
services = [mock.Mock(id=1), mock.Mock(id=2)]
filters = [mock.Mock(), mock.Mock()]
filters[0].rate_service.return_value = 50
filters[1].rate_service.return_value = 0

self.assertRaises(exception.NoSuitableWorkerServiceError,
self.server._get_weighted_filtered_services,
services, filters)

def test_get_weighted_filtered_services_with_filters_accept(self):
services = [mock.Mock(id=1), mock.Mock(id=2)]
filters = [mock.Mock(), mock.Mock()]
filters[0].rate_service.return_value = 50
filters[1].rate_service.return_value = 100

result = self.server._get_weighted_filtered_services(services,
filters)
expected_result = [(services[0], 150), (services[1], 150)]
self.assertEqual(result, expected_result)

@mock.patch.object(db_api, 'get_regions')
def test__filter_regions_check_all_exist_false(self, mock_get_regions):
mock_get_regions.return_value = [
mock.Mock(id='region1', enabled=True),
mock.Mock(id='region2', enabled=True),
]
region_ids = ['region1', 'region2']

result = self.server._filter_regions(None, region_ids,
check_all_exist=False)

self.assertEqual(result, mock_get_regions.return_value)

@mock.patch.object(db_api, 'get_regions')
def test__filter_regions_all_disabled(self, mock_get_regions):
mock_get_regions.return_value = [
mock.Mock(id='region1', enabled=False),
mock.Mock(id='region2', enabled=False),
]
region_ids = ['region1', 'region2']

result = self.server._filter_regions(None, region_ids, enabled=False)

self.assertEqual(result, mock_get_regions.return_value)

@mock.patch.object(db_api, 'get_regions')
def test__filter_regions_some_enabled_some_disabled(self,
mock_get_regions):
mock_get_regions.return_value = [
mock.Mock(id='region1', enabled=True),
mock.Mock(id='region2', enabled=False),
]
region_ids = ['region1', 'region2']

result = self.server._filter_regions(None, region_ids)

self.assertEqual(result, [mock_get_regions.return_value[0]])

@mock.patch.object(db_api, 'get_regions')
def test__filter_regions_some_missing(self, mock_get_regions):
mock_get_regions.return_value = [
mock.Mock(id='region1', enabled=True),
mock.Mock(id='region2', enabled=True),
]
region_ids = ['region1', 'region2', 'region3']

self.assertRaises(exception.RegionNotFound,
self.server._filter_regions, None, region_ids)

@mock.patch.object(trivial_filters, 'ProviderTypesFilter', autospec=True)
@mock.patch.object(trivial_filters, 'RegionsFilter', autospec=True)
@mock.patch.object(trivial_filters, 'EnabledFilter', autospec=True)
Expand Down Expand Up @@ -52,7 +172,7 @@ def test_get_workers_for_specs(
provider_requirements = config.get("provider_requirements", None)

# Convert the config dict to an object, skipping the providers
# providers is the only field used as dict in the code
# as it's the only field used as dict in the code
config_obj = testutils.DictToObject(config, skip_attrs=["providers"])
mock_get_all_worker_services.return_value = (
config_obj.services_db or []
Expand Down

0 comments on commit 6246684

Please sign in to comment.