diff --git a/cloudbaseinit/metadata/services/baseopenstackservice.py b/cloudbaseinit/metadata/services/baseopenstackservice.py index 300f113d..7973df0b 100644 --- a/cloudbaseinit/metadata/services/baseopenstackservice.py +++ b/cloudbaseinit/metadata/services/baseopenstackservice.py @@ -19,6 +19,7 @@ from oslo.config import cfg from cloudbaseinit.metadata.services import base from cloudbaseinit.openstack.common import log as logging +from cloudbaseinit.utils import encoding from cloudbaseinit.utils import x509constants opts = [ @@ -48,7 +49,8 @@ class BaseOpenStackService(base.BaseMetadataService): path = posixpath.normpath( posixpath.join('openstack', version, 'meta_data.json')) data = self._get_cache_data(path) - return json.loads(data.decode('utf8')) + if data: + return json.loads(encoding.get_as_string(data)) def get_instance_id(self): return self._get_meta_data().get('uuid') @@ -98,10 +100,12 @@ class BaseOpenStackService(base.BaseMetadataService): i += 1 if not cert_data: + # Look if the user_data contains a PEM certificate try: user_data = self.get_user_data() - if user_data.startswith(x509constants.PEM_HEADER): + if user_data.startswith( + x509constants.PEM_HEADER.encode()): cert_data = user_data except base.NotExistingMetadataException: LOG.debug("user_data metadata not present") diff --git a/cloudbaseinit/osutils/base.py b/cloudbaseinit/osutils/base.py index ace015ae..55267ca1 100644 --- a/cloudbaseinit/osutils/base.py +++ b/cloudbaseinit/osutils/base.py @@ -19,6 +19,8 @@ import os import subprocess import sys +from cloudbaseinit.utils import encoding + class BaseOSUtils(object): PROTOCOL_TCP = "TCP" @@ -34,7 +36,9 @@ class BaseOSUtils(object): # On Windows os.urandom() uses CryptGenRandom, which is a # cryptographically secure pseudorandom number generator b64_password = base64.b64encode(os.urandom(256)) - return b64_password.replace('/', '').replace('+', '')[:length] + b64_password = encoding.get_as_string(b64_password).replace( + '/', '').replace('+', '')[:length] + return b64_password def execute_process(self, args, shell=True, decode_output=False): p = subprocess.Popen(args, diff --git a/cloudbaseinit/osutils/windows.py b/cloudbaseinit/osutils/windows.py index 3c4cb0eb..5eccf575 100644 --- a/cloudbaseinit/osutils/windows.py +++ b/cloudbaseinit/osutils/windows.py @@ -28,6 +28,7 @@ import wmi from cloudbaseinit import exception from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import base +from cloudbaseinit.utils import encoding from cloudbaseinit.utils.windows import network @@ -707,9 +708,12 @@ class WindowsUtils(base.BaseOSUtils): while i < forward_table.dwNumEntries: row = table[i] routing_table.append(( - Ws2_32.inet_ntoa(row.dwForwardDest), - Ws2_32.inet_ntoa(row.dwForwardMask), - Ws2_32.inet_ntoa(row.dwForwardNextHop), + encoding.get_as_string(Ws2_32.inet_ntoa( + row.dwForwardDest)), + encoding.get_as_string(Ws2_32.inet_ntoa( + row.dwForwardMask)), + encoding.get_as_string(Ws2_32.inet_ntoa( + row.dwForwardNextHop)), row.dwForwardIfIndex, row.dwForwardMetric1)) i += 1 diff --git a/cloudbaseinit/plugins/windows/networkconfig.py b/cloudbaseinit/plugins/windows/networkconfig.py index b912bbf5..1d9d60cd 100644 --- a/cloudbaseinit/plugins/windows/networkconfig.py +++ b/cloudbaseinit/plugins/windows/networkconfig.py @@ -22,6 +22,7 @@ from cloudbaseinit.metadata.services import base as service_base from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import factory as osutils_factory from cloudbaseinit.plugins import base as plugin_base +from cloudbaseinit.utils import encoding LOG = logging.getLogger(__name__) @@ -58,6 +59,7 @@ class NetworkConfigPlugin(plugin_base.BasePlugin): content_path = network_config['content_path'] content_name = content_path.rsplit('/', 1)[-1] debian_network_conf = service.get_content(content_name) + debian_network_conf = encoding.get_as_string(debian_network_conf) LOG.debug('network config content:\n%s' % debian_network_conf) diff --git a/cloudbaseinit/plugins/windows/setuserpassword.py b/cloudbaseinit/plugins/windows/setuserpassword.py index 287966ff..fad79ea3 100644 --- a/cloudbaseinit/plugins/windows/setuserpassword.py +++ b/cloudbaseinit/plugins/windows/setuserpassword.py @@ -47,7 +47,7 @@ class SetUserPasswordPlugin(base.BasePlugin): def _get_ssh_public_key(self, service): public_keys = service.get_public_keys() if public_keys: - return public_keys[0] + return list(public_keys)[0] def _get_password(self, service, osutils): if CONF.inject_user_password: diff --git a/cloudbaseinit/plugins/windows/userdata.py b/cloudbaseinit/plugins/windows/userdata.py index 67d30776..4b151775 100644 --- a/cloudbaseinit/plugins/windows/userdata.py +++ b/cloudbaseinit/plugins/windows/userdata.py @@ -23,6 +23,7 @@ from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.plugins import base from cloudbaseinit.plugins.windows.userdataplugins import factory from cloudbaseinit.plugins.windows import userdatautils +from cloudbaseinit.utils import encoding LOG = logging.getLogger(__name__) @@ -40,6 +41,7 @@ class UserDataPlugin(base.BasePlugin): if not user_data: return (base.PLUGIN_EXECUTION_DONE, False) + LOG.debug('User data content length: %d' % len(user_data)) user_data = self._check_gzip_compression(user_data) return self._process_user_data(user_data) @@ -53,14 +55,16 @@ class UserDataPlugin(base.BasePlugin): return user_data def _parse_mime(self, user_data): - return email.message_from_string(user_data).walk() + user_data_str = encoding.get_as_string(user_data) + LOG.debug('User data content:\n%s' % user_data_str) + + return email.message_from_string(user_data_str).walk() def _process_user_data(self, user_data): plugin_status = base.PLUGIN_EXECUTION_DONE reboot = False - LOG.debug('User data content:\n%s' % user_data) - if user_data.startswith('Content-Type: multipart'): + if user_data.startswith(b'Content-Type: multipart'): user_data_plugins = factory.load_plugins() user_handlers = {} @@ -158,7 +162,7 @@ class UserDataPlugin(base.BasePlugin): return (plugin_status, reboot) def _process_non_multi_part(self, user_data): - if user_data.startswith('#cloud-config'): + if user_data.startswith(b'#cloud-config'): user_data_plugins = factory.load_plugins() cloud_config_plugin = user_data_plugins.get('text/cloud-config') ret_val = cloud_config_plugin.process(user_data) diff --git a/cloudbaseinit/plugins/windows/userdataplugins/heat.py b/cloudbaseinit/plugins/windows/userdataplugins/heat.py index 072770e3..d063c4fc 100644 --- a/cloudbaseinit/plugins/windows/userdataplugins/heat.py +++ b/cloudbaseinit/plugins/windows/userdataplugins/heat.py @@ -20,6 +20,7 @@ from oslo.config import cfg from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.plugins.windows.userdataplugins import base from cloudbaseinit.plugins.windows import userdatautils +from cloudbaseinit.utils import encoding opts = [ cfg.StrOpt('heat_config_dir', default='C:\\cfn', help='The directory ' @@ -47,8 +48,7 @@ class HeatPlugin(base.BaseUserDataPlugin): file_name = os.path.join(CONF.heat_config_dir, part.get_filename()) self._check_dir(file_name) - with open(file_name, 'wb') as f: - f.write(part.get_payload()) + encoding.write_file(file_name, part.get_payload()) if part.get_filename() == self._heat_user_data_filename: return userdatautils.execute_user_data_script(part.get_payload()) diff --git a/cloudbaseinit/plugins/windows/userdataplugins/shellscript.py b/cloudbaseinit/plugins/windows/userdataplugins/shellscript.py index 1d288879..fad0f4c5 100644 --- a/cloudbaseinit/plugins/windows/userdataplugins/shellscript.py +++ b/cloudbaseinit/plugins/windows/userdataplugins/shellscript.py @@ -19,6 +19,7 @@ import tempfile from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.plugins.windows import fileexecutils from cloudbaseinit.plugins.windows.userdataplugins import base +from cloudbaseinit.utils import encoding LOG = logging.getLogger(__name__) @@ -32,8 +33,7 @@ class ShellScriptPlugin(base.BaseUserDataPlugin): target_path = os.path.join(tempfile.gettempdir(), file_name) try: - with open(target_path, 'wb') as f: - f.write(part.get_payload()) + encoding.write_file(target_path, part.get_payload()) return fileexecutils.exec_file(target_path) except Exception as ex: diff --git a/cloudbaseinit/plugins/windows/userdatautils.py b/cloudbaseinit/plugins/windows/userdatautils.py index 9d4915c5..2037d639 100644 --- a/cloudbaseinit/plugins/windows/userdatautils.py +++ b/cloudbaseinit/plugins/windows/userdatautils.py @@ -19,6 +19,7 @@ import uuid from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.utils import encoding LOG = logging.getLogger(__name__) @@ -54,8 +55,7 @@ def execute_user_data_script(user_data): return 0 try: - with open(target_path, 'wb') as f: - f.write(user_data) + encoding.write_file(target_path, user_data) if powershell: (out, err, diff --git a/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py b/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py index c24a52a2..bc2b41f7 100644 --- a/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py +++ b/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py @@ -126,7 +126,7 @@ class BaseOpenStackServiceTest(unittest.TestCase): response = self._service.get_client_auth_certs() mock_get_meta_data.assert_called_once_with() if 'meta' in meta_data: - self.assertEqual(['fake cert'], response) + self.assertEqual([b'fake cert'], response) elif type(ret_value) is str and ret_value.startswith( x509constants.PEM_HEADER): mock_get_user_data.assert_called_once_with() @@ -136,11 +136,11 @@ class BaseOpenStackServiceTest(unittest.TestCase): def test_get_client_auth_certs(self): self._test_get_client_auth_certs( - meta_data={'meta': {'admin_cert0': 'fake cert'}}) + meta_data={'meta': {'admin_cert0': b'fake cert'}}) def test_get_client_auth_certs_no_cert_data(self): self._test_get_client_auth_certs( - meta_data={}, ret_value=x509constants.PEM_HEADER) + meta_data={}, ret_value=x509constants.PEM_HEADER.encode()) def test_get_client_auth_certs_no_cert_data_exception(self): self._test_get_client_auth_certs( diff --git a/cloudbaseinit/tests/plugins/windows/test_setuserpassword.py b/cloudbaseinit/tests/plugins/windows/test_setuserpassword.py index 2915307d..ced1102d 100644 --- a/cloudbaseinit/tests/plugins/windows/test_setuserpassword.py +++ b/cloudbaseinit/tests/plugins/windows/test_setuserpassword.py @@ -55,10 +55,12 @@ class SetUserPasswordPluginTests(unittest.TestCase): def _test_get_ssh_public_key(self, data_exists): mock_service = mock.MagicMock() public_keys = self.fake_data['public_keys'] - mock_service.get_public_keys.return_value = public_keys + mock_service.get_public_keys.return_value = public_keys.values() + response = self._setpassword_plugin._get_ssh_public_key(mock_service) + mock_service.get_public_keys.assert_called_with() - self.assertEqual(public_keys[0], response) + self.assertEqual(list(public_keys.values())[0], response) def test_get_ssh_plublic_key(self): self._test_get_ssh_public_key(data_exists=True) diff --git a/cloudbaseinit/tests/plugins/windows/test_userdata.py b/cloudbaseinit/tests/plugins/windows/test_userdata.py index e92ac1e2..bd4501bc 100644 --- a/cloudbaseinit/tests/plugins/windows/test_userdata.py +++ b/cloudbaseinit/tests/plugins/windows/test_userdata.py @@ -58,7 +58,7 @@ class UserDataPluginTest(unittest.TestCase): self.assertEqual(response, mock_process_user_data.return_value) def test_execute(self): - self._test_execute(ret_val=mock.sentinel.fake_data) + self._test_execute(ret_val='fake_data') def test_execute_no_data(self): self._test_execute(ret_val=None) @@ -85,10 +85,15 @@ class UserDataPluginTest(unittest.TestCase): self.assertEqual(data, response) @mock.patch('email.message_from_string') - def test_parse_mime(self, mock_message_from_string): + @mock.patch('cloudbaseinit.utils.encoding.get_as_string') + def test_parse_mime(self, mock_get_as_string, mock_message_from_string): fake_user_data = 'fake data' + response = self._userdata._parse_mime(user_data=fake_user_data) - mock_message_from_string.assert_called_once_with(fake_user_data) + + mock_get_as_string.assert_called_once_with(fake_user_data) + mock_message_from_string.assert_called_once_with( + mock_get_as_string.return_value) self.assertEqual(response, mock_message_from_string().walk()) @mock.patch('cloudbaseinit.plugins.windows.userdataplugins.factory.' @@ -110,7 +115,8 @@ class UserDataPluginTest(unittest.TestCase): mock_process_part.return_value = (base.PLUGIN_EXECUTION_DONE, reboot) response = self._userdata._process_user_data(user_data=user_data) - if user_data.startswith('Content-Type: multipart'): + + if user_data.startswith(b'Content-Type: multipart'): mock_load_plugins.assert_called_once_with() mock_parse_mime.assert_called_once_with(user_data) mock_process_part.assert_called_once_with(mock_part, @@ -122,15 +128,15 @@ class UserDataPluginTest(unittest.TestCase): response) def test_process_user_data_multipart_reboot_true(self): - self._test_process_user_data(user_data='Content-Type: multipart', + self._test_process_user_data(user_data=b'Content-Type: multipart', reboot=True) def test_process_user_data_multipart_reboot_false(self): - self._test_process_user_data(user_data='Content-Type: multipart', + self._test_process_user_data(user_data=b'Content-Type: multipart', reboot=False) def test_process_user_data_non_multipart(self): - self._test_process_user_data(user_data='Content-Type: non-multipart', + self._test_process_user_data(user_data=b'Content-Type: non-multipart', reboot=False) @mock.patch('cloudbaseinit.plugins.windows.userdata.UserDataPlugin' @@ -250,7 +256,7 @@ class UserDataPluginTest(unittest.TestCase): '._get_plugin_return_value') def test_process_non_multi_part(self, mock_get_plugin_return_value, mock_execute_user_data_script): - user_data = 'fake' + user_data = b'fake' response = self._userdata._process_non_multi_part(user_data=user_data) mock_execute_user_data_script.assert_called_once_with(user_data) mock_get_plugin_return_value.assert_called_once_with( @@ -263,7 +269,7 @@ class UserDataPluginTest(unittest.TestCase): '._get_plugin_return_value') def test_process_non_multi_part_cloud_config( self, mock_get_plugin_return_value, mock_load_plugins): - user_data = '#cloud-config' + user_data = b'#cloud-config' mock_return_value = mock.sentinel.return_value mock_cloud_config_plugin = mock.Mock() mock_cloud_config_plugin.process.return_value = mock_return_value diff --git a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py index ae32c201..ffc2617c 100644 --- a/cloudbaseinit/tests/plugins/windows/test_userdatautils.py +++ b/cloudbaseinit/tests/plugins/windows/test_userdatautils.py @@ -37,8 +37,9 @@ class UserDataUtilsTest(unittest.TestCase): @mock.patch('os.path.expandvars') @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') @mock.patch('uuid.uuid4') - def _test_execute_user_data_script(self, mock_uuid4, mock_get_os_utils, - mock_path_expandvars, + @mock.patch('cloudbaseinit.utils.encoding.write_file') + def _test_execute_user_data_script(self, mock_write_file, mock_uuid4, + mock_get_os_utils, mock_path_expandvars, mock_path_exists, mock_os_remove, mock_gettempdir, mock_re_search, fake_user_data): @@ -51,6 +52,8 @@ class UserDataUtilsTest(unittest.TestCase): powershell = False mock_get_os_utils.return_value = mock_osutils mock_path_exists.return_value = True + extension = '' + if fake_user_data == '^rem cmd\s': side_effect = [match_instance] number_of_calls = 1 @@ -88,12 +91,14 @@ class UserDataUtilsTest(unittest.TestCase): mock_re_search.side_effect = side_effect - with mock.patch('cloudbaseinit.plugins.windows.userdatautils.open', - mock.mock_open(), create=True): - response = userdatautils.execute_user_data_script(fake_user_data) + response = userdatautils.execute_user_data_script(fake_user_data) + mock_gettempdir.assert_called_once_with() + self.assertEqual(number_of_calls, mock_re_search.call_count) if args: + mock_write_file.assert_called_once_with(path + extension, + fake_user_data) mock_osutils.execute_process.assert_called_with(args, shell) mock_os_remove.assert_called_once_with(path + extension) self.assertEqual(None, response) @@ -106,24 +111,24 @@ class UserDataUtilsTest(unittest.TestCase): self.assertEqual(0, response) def test_handle_batch(self): - fake_user_data = '^rem cmd\s' + fake_user_data = b'^rem cmd\s' self._test_execute_user_data_script(fake_user_data=fake_user_data) def test_handle_python(self): - self._test_execute_user_data_script(fake_user_data='^#!/usr/bin/env' - '\spython\s') + self._test_execute_user_data_script( + fake_user_data=b'^#!/usr/bin/env\spython\s') def test_handle_shell(self): - self._test_execute_user_data_script(fake_user_data='^#!') + self._test_execute_user_data_script(fake_user_data=b'^#!') def test_handle_powershell(self): - self._test_execute_user_data_script(fake_user_data='^#ps1\s') + self._test_execute_user_data_script(fake_user_data=b'^#ps1\s') def test_handle_powershell_sysnative(self): - self._test_execute_user_data_script(fake_user_data='#ps1_sysnative\s') + self._test_execute_user_data_script(fake_user_data=b'#ps1_sysnative\s') def test_handle_powershell_sysnative_no_sysnative(self): - self._test_execute_user_data_script(fake_user_data='#ps1_x86\s') + self._test_execute_user_data_script(fake_user_data=b'#ps1_x86\s') def test_handle_unsupported_format(self): - self._test_execute_user_data_script(fake_user_data='unsupported') + self._test_execute_user_data_script(fake_user_data=b'unsupported') diff --git a/cloudbaseinit/tests/plugins/windows/userdataplugins/test_heat.py b/cloudbaseinit/tests/plugins/windows/userdataplugins/test_heat.py index cf6b4b84..27a0b74e 100644 --- a/cloudbaseinit/tests/plugins/windows/userdataplugins/test_heat.py +++ b/cloudbaseinit/tests/plugins/windows/userdataplugins/test_heat.py @@ -50,19 +50,19 @@ class HeatUserDataHandlerTests(unittest.TestCase): '.execute_user_data_script') @mock.patch('cloudbaseinit.plugins.windows.userdataplugins.heat' '.HeatPlugin._check_dir') - def _test_process(self, mock_check_dir, mock_execute_user_data_script, - filename): + @mock.patch('cloudbaseinit.utils.encoding.write_file') + def _test_process(self, mock_write_file, mock_check_dir, + mock_execute_user_data_script, filename): mock_part = mock.MagicMock() mock_part.get_filename.return_value = filename - with mock.patch('six.moves.builtins.open', mock.mock_open(), - create=True) as handle: - response = self._heat.process(mock_part) - - handle().write.assert_called_once_with(mock_part.get_payload()) + response = self._heat.process(mock_part) path = os.path.join(CONF.heat_config_dir, filename) mock_check_dir.assert_called_once_with(path) mock_part.get_filename.assert_called_with() + mock_write_file.assert_called_once_with( + path, mock_part.get_payload.return_value) + if filename == self._heat._heat_user_data_filename: mock_execute_user_data_script.assert_called_with( mock_part.get_payload()) diff --git a/cloudbaseinit/tests/plugins/windows/userdataplugins/test_shellscript.py b/cloudbaseinit/tests/plugins/windows/userdataplugins/test_shellscript.py index f08e1cae..71680aef 100644 --- a/cloudbaseinit/tests/plugins/windows/userdataplugins/test_shellscript.py +++ b/cloudbaseinit/tests/plugins/windows/userdataplugins/test_shellscript.py @@ -27,8 +27,9 @@ class ShellScriptPluginTests(unittest.TestCase): @mock.patch('cloudbaseinit.osutils.factory.get_os_utils') @mock.patch('tempfile.gettempdir') @mock.patch('cloudbaseinit.plugins.windows.fileexecutils.exec_file') - def _test_process(self, mock_exec_file, mock_gettempdir, mock_get_os_utils, - exception=False): + @mock.patch('cloudbaseinit.utils.encoding.write_file') + def _test_process(self, mock_write_file, mock_exec_file, mock_gettempdir, + mock_get_os_utils, exception=False): fake_dir_path = os.path.join("fake", "dir") mock_osutils = mock.MagicMock() mock_part = mock.MagicMock() @@ -45,6 +46,8 @@ class ShellScriptPluginTests(unittest.TestCase): response = self._shellscript.process(mock_part) mock_part.get_filename.assert_called_once_with() + mock_write_file.assert_called_once_with( + fake_target, mock_part.get_payload.return_value) mock_exec_file.assert_called_once_with(fake_target) mock_part.get_payload.assert_called_once_with() mock_gettempdir.assert_called_once_with() diff --git a/cloudbaseinit/tests/utils/test_dhcp.py b/cloudbaseinit/tests/utils/test_dhcp.py index 614d0318..8de6bd5e 100644 --- a/cloudbaseinit/tests/utils/test_dhcp.py +++ b/cloudbaseinit/tests/utils/test_dhcp.py @@ -45,9 +45,8 @@ class DHCPUtilsTests(unittest.TestCase): data += b'\x00' * 128 data += dhcp._DHCP_COOKIE data += b'\x35\x01\x01' - data += b'\x3c' + struct.pack('b', - len('fake id')) + 'fake id'.encode( - 'ascii') + data += b'\x3c' + struct.pack('b', len('fake id')) + 'fake id'.encode( + 'ascii') data += b'\x3d\x07\x01' data += fake_mac_address_b data += b'\x37' + struct.pack('b', len([100])) diff --git a/cloudbaseinit/tests/utils/windows/test_x509.py b/cloudbaseinit/tests/utils/windows/test_x509.py index 782704f8..8ce9c4ce 100644 --- a/cloudbaseinit/tests/utils/windows/test_x509.py +++ b/cloudbaseinit/tests/utils/windows/test_x509.py @@ -281,6 +281,8 @@ class CryptoAPICertManagerTests(unittest.TestCase): fake_cert_data += x509constants.PEM_HEADER + '\n' fake_cert_data += 'fake cert' + '\n' fake_cert_data += x509constants.PEM_FOOTER + fake_cert_data = fake_cert_data.encode() + response = self._x509_manager._get_cert_base64(fake_cert_data) self.assertEqual('fake cert', response) diff --git a/cloudbaseinit/utils/crypt.py b/cloudbaseinit/utils/crypt.py index ece1f422..1a4deb70 100644 --- a/cloudbaseinit/utils/crypt.py +++ b/cloudbaseinit/utils/crypt.py @@ -158,7 +158,7 @@ class CryptManager(object): key_type_len = struct.unpack('>I', pub_key[offset:offset + 4])[0] offset += 4 - key_type = pub_key[offset:offset + key_type_len] + key_type = pub_key[offset:offset + key_type_len].decode('utf-8') offset += key_type_len if key_type not in ['ssh-rsa', 'rsa', 'rsa1']: diff --git a/cloudbaseinit/utils/encoding.py b/cloudbaseinit/utils/encoding.py new file mode 100644 index 00000000..7a09e90d --- /dev/null +++ b/cloudbaseinit/utils/encoding.py @@ -0,0 +1,35 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2014 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import six + + +def get_as_string(value): + if value is None or isinstance(value, six.text_type): + return value + else: + try: + return value.decode() + except Exception: + pass + + +def write_file(target_path, data, mode='wb'): + if isinstance(data, six.text_type) and 'b' in mode: + data = data.encode() + + with open(target_path, mode) as f: + f.write(data) diff --git a/cloudbaseinit/utils/windows/cryptoapi.py b/cloudbaseinit/utils/windows/cryptoapi.py index 7288fba2..76b12ddf 100644 --- a/cloudbaseinit/utils/windows/cryptoapi.py +++ b/cloudbaseinit/utils/windows/cryptoapi.py @@ -102,8 +102,8 @@ CRYPT_STRING_BASE64 = 1 PKCS_7_ASN_ENCODING = 65536 PROV_RSA_FULL = 1 X509_ASN_ENCODING = 1 -szOID_PKIX_KP_SERVER_AUTH = "1.3.6.1.5.5.7.3.1" -szOID_RSA_SHA1RSA = "1.2.840.113549.1.1.5" +szOID_PKIX_KP_SERVER_AUTH = b"1.3.6.1.5.5.7.3.1" +szOID_RSA_SHA1RSA = b"1.2.840.113549.1.1.5" advapi32 = windll.advapi32 crypt32 = windll.crypt32 diff --git a/cloudbaseinit/utils/windows/x509.py b/cloudbaseinit/utils/windows/x509.py index 7e2de462..392d22cd 100644 --- a/cloudbaseinit/utils/windows/x509.py +++ b/cloudbaseinit/utils/windows/x509.py @@ -21,6 +21,7 @@ import uuid from ctypes import wintypes +from cloudbaseinit.utils import encoding from cloudbaseinit.utils.windows import cryptoapi from cloudbaseinit.utils import x509constants @@ -202,7 +203,7 @@ class CryptoAPICertManager(object): free(subject_encoded) def _get_cert_base64(self, cert_data): - base64_cert_data = cert_data + base64_cert_data = encoding.get_as_string(cert_data) if base64_cert_data.startswith(x509constants.PEM_HEADER): base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):] if base64_cert_data.endswith(x509constants.PEM_FOOTER):