diff --git a/zun/api/controllers/v1/containers.py b/zun/api/controllers/v1/containers.py index 049e14786..c8a2f4a8f 100644 --- a/zun/api/controllers/v1/containers.py +++ b/zun/api/controllers/v1/containers.py @@ -286,6 +286,7 @@ class ContainersController(base.Controller): container_dict['status'] = consts.CREATING extra_spec = {} extra_spec['hints'] = container_dict.get('hints', None) + extra_spec['pci_requests'] = pci_req new_container = objects.Container(context, **container_dict) new_container.create(context) diff --git a/zun/scheduler/filter_scheduler.py b/zun/scheduler/filter_scheduler.py index ddb538b0c..e0597e00a 100644 --- a/zun/scheduler/filter_scheduler.py +++ b/zun/scheduler/filter_scheduler.py @@ -21,6 +21,7 @@ from zun.common import exception from zun.common.i18n import _ import zun.conf from zun import objects +from zun.pci import stats as pci_stats from zun.scheduler import driver from zun.scheduler import filters from zun.scheduler.host_state import HostState @@ -111,5 +112,7 @@ class FilterScheduler(driver.Scheduler): host_state.cpu_used = node.cpu_used host_state.numa_topology = node.numa_topology host_state.labels = node.labels + host_state.pci_stats = pci_stats.PciDeviceStats( + stats=node.pci_device_pools) host_states.append(host_state) return host_states diff --git a/zun/scheduler/filters/pci_passthrough_filter.py b/zun/scheduler/filters/pci_passthrough_filter.py new file mode 100644 index 000000000..b139a486e --- /dev/null +++ b/zun/scheduler/filters/pci_passthrough_filter.py @@ -0,0 +1,52 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_log import log as logging + +from zun.scheduler import filters + +LOG = logging.getLogger(__name__) + + +class PciPassthroughFilter(filters.BaseHostFilter): + """Pci Passthrough Filter based on PCI request + + Filter that schedules containers on a host if the host has devices + to meet the device requests in the 'extra_specs'. + + PCI resource tracker provides updated summary information about the + PCI devices for each host, like:: + + | [{"count": 5, "vendor_id": "8086", "product_id": "1520", + | "extra_info":'{}'}], + + and container requests PCI devices via PCI requests, like:: + + | [{"count": 1, "vendor_id": "8086", "product_id": "1520",}]. + + The filter checks if the host passes or not based on this information. + + """ + + def host_passes(self, host_state, container, extra_spec): + """Return true if the host has the required PCI devices.""" + pci_requests = extra_spec['pci_requests'] + if not pci_requests or not pci_requests.requests: + return True + if (not host_state.pci_stats or + not host_state.pci_stats.support_requests( + pci_requests.requests)): + LOG.debug("%(host_state)s doesn't have the required PCI devices" + " (%(requests)s)", + {'host_state': host_state, 'requests': pci_requests}) + return False + return True diff --git a/zun/tests/unit/scheduler/filters/test_pci_passthrough_filters.py b/zun/tests/unit/scheduler/filters/test_pci_passthrough_filters.py new file mode 100644 index 000000000..24c70177d --- /dev/null +++ b/zun/tests/unit/scheduler/filters/test_pci_passthrough_filters.py @@ -0,0 +1,90 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import mock + +from zun import objects +from zun.pci import stats +from zun.scheduler.filters import pci_passthrough_filter +from zun.scheduler.host_state import HostState +from zun.tests import base + + +class TestPCIPassthroughFilter(base.TestCase): + + def setUp(self): + super(TestPCIPassthroughFilter, self).setUp() + self.filt_cls = pci_passthrough_filter.PciPassthroughFilter() + + def test_pci_passthrough_pass(self): + pci_stats_mock = mock.MagicMock() + pci_stats_mock.support_requests.return_value = True + request = objects.ContainerPCIRequest( + count=1, spec=[{'vendor_id': '8086'}]) + requests = objects.ContainerPCIRequests(requests=[request]) + container = objects.Container(self.context) + host = HostState('testhost') + host.pci_stats = pci_stats_mock + extra_spec = {'pci_requests': requests} + self.assertTrue(self.filt_cls.host_passes(host, container, extra_spec)) + pci_stats_mock.support_requests.assert_called_once_with( + requests.requests) + + def test_pci_passthrough_fail(self): + pci_stats_mock = mock.MagicMock() + pci_stats_mock.support_requests.return_value = False + request = objects.ContainerPCIRequest( + count=1, spec=[{'vendor_id': '8086'}]) + requests = objects.ContainerPCIRequests(requests=[request]) + container = objects.Container(self.context) + host = HostState('testhost') + host.pci_stats = pci_stats_mock + extra_spec = {'pci_requests': requests} + self.assertFalse(self.filt_cls.host_passes(host, container, + extra_spec)) + pci_stats_mock.support_requests.assert_called_once_with( + requests.requests) + + def test_pci_passthrough_no_pci_request(self): + container = objects.Container(self.context) + host = HostState('testhost') + extra_spec = {'pci_requests': None} + self.assertTrue(self.filt_cls.host_passes(host, container, extra_spec)) + + def test_pci_passthrough_empty_pci_request_obj(self): + requests = objects.ContainerPCIRequests(requests=[]) + container = objects.Container(self.context) + host = HostState('testhost') + extra_spec = {'pci_requests': requests} + self.assertTrue(self.filt_cls.host_passes(host, container, extra_spec)) + + def test_pci_passthrough_no_pci_stats(self): + request = objects.ContainerPCIRequest( + count=1, spec=[{'vendor_id': '8086'}]) + requests = objects.ContainerPCIRequests(requests=[request]) + container = objects.Container(self.context) + host = HostState('testhost') + host.pci_stats = stats.PciDeviceStats() + extra_spec = {'pci_requests': requests} + self.assertFalse(self.filt_cls.host_passes(host, container, + extra_spec)) + + def test_pci_passthrough_with_pci_stats_none(self): + request = objects.ContainerPCIRequest( + count=1, spec=[{'vendor_id': '8086'}]) + requests = objects.ContainerPCIRequests(requests=[request]) + container = objects.Container(self.context) + host = HostState('testhost') + host.pci_stats = None + extra_spec = {'pci_requests': requests} + self.assertFalse(self.filt_cls.host_passes(host, container, + extra_spec)) diff --git a/zun/tests/unit/scheduler/test_filter_scheduler.py b/zun/tests/unit/scheduler/test_filter_scheduler.py index 2974bd3c7..fe8a2fe12 100644 --- a/zun/tests/unit/scheduler/test_filter_scheduler.py +++ b/zun/tests/unit/scheduler/test_filter_scheduler.py @@ -63,6 +63,7 @@ class FilterSchedulerTestCase(base.TestCase): node1.hostname = 'host1' node1.numa_topology = None node1.labels = {} + node1.pci_device_pools = None node2 = objects.ComputeNode(self.context) node2.cpus = 48 node2.cpu_used = 0.0 @@ -71,6 +72,7 @@ class FilterSchedulerTestCase(base.TestCase): node2.hostname = 'host2' node2.numa_topology = None node2.labels = {} + node2.pci_device_pools = None node3 = objects.ComputeNode(self.context) node3.cpus = 48 node3.cpu_used = 0.0 @@ -79,6 +81,7 @@ class FilterSchedulerTestCase(base.TestCase): node3.hostname = 'host3' node3.numa_topology = None node3.labels = {} + node3.pci_device_pools = None node4 = objects.ComputeNode(self.context) node4.cpus = 48 node4.cpu_used = 0.0 @@ -87,6 +90,7 @@ class FilterSchedulerTestCase(base.TestCase): node4.hostname = 'host4' node4.numa_topology = None node4.labels = {} + node4.pci_device_pools = None nodes = [node1, node2, node3, node4] mock_compute_list.return_value = nodes