Fixes Python34 bugs

This fix is needed due to the different encodings on Python27 and
Python34. Some tests failed and some string values were not
compatible.

Change-Id: I7206a0549051971da76b3b30221aea9f86eed904
This commit is contained in:
Robert Tingirica 2014-10-24 19:49:02 +03:00
parent 98be94fd27
commit b0aa3b56b8
21 changed files with 131 additions and 60 deletions

View File

@ -19,6 +19,7 @@ from oslo.config import cfg
from cloudbaseinit.metadata.services import base from cloudbaseinit.metadata.services import base
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils import x509constants from cloudbaseinit.utils import x509constants
opts = [ opts = [
@ -48,7 +49,8 @@ class BaseOpenStackService(base.BaseMetadataService):
path = posixpath.normpath( path = posixpath.normpath(
posixpath.join('openstack', version, 'meta_data.json')) posixpath.join('openstack', version, 'meta_data.json'))
data = self._get_cache_data(path) data = self._get_cache_data(path)
return json.loads(data.decode('utf8')) if data:
return json.loads(encoding.get_as_string(data))
def get_instance_id(self): def get_instance_id(self):
return self._get_meta_data().get('uuid') return self._get_meta_data().get('uuid')
@ -98,10 +100,12 @@ class BaseOpenStackService(base.BaseMetadataService):
i += 1 i += 1
if not cert_data: if not cert_data:
# Look if the user_data contains a PEM certificate # Look if the user_data contains a PEM certificate
try: try:
user_data = self.get_user_data() user_data = self.get_user_data()
if user_data.startswith(x509constants.PEM_HEADER): if user_data.startswith(
x509constants.PEM_HEADER.encode()):
cert_data = user_data cert_data = user_data
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
LOG.debug("user_data metadata not present") LOG.debug("user_data metadata not present")

View File

@ -19,6 +19,8 @@ import os
import subprocess import subprocess
import sys import sys
from cloudbaseinit.utils import encoding
class BaseOSUtils(object): class BaseOSUtils(object):
PROTOCOL_TCP = "TCP" PROTOCOL_TCP = "TCP"
@ -34,7 +36,9 @@ class BaseOSUtils(object):
# On Windows os.urandom() uses CryptGenRandom, which is a # On Windows os.urandom() uses CryptGenRandom, which is a
# cryptographically secure pseudorandom number generator # cryptographically secure pseudorandom number generator
b64_password = base64.b64encode(os.urandom(256)) b64_password = base64.b64encode(os.urandom(256))
return b64_password.replace('/', '').replace('+', '')[:length] b64_password = encoding.get_as_string(b64_password).replace(
'/', '').replace('+', '')[:length]
return b64_password
def execute_process(self, args, shell=True, decode_output=False): def execute_process(self, args, shell=True, decode_output=False):
p = subprocess.Popen(args, p = subprocess.Popen(args,

View File

@ -28,6 +28,7 @@ import wmi
from cloudbaseinit import exception from cloudbaseinit import exception
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import base from cloudbaseinit.osutils import base
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import network from cloudbaseinit.utils.windows import network
@ -707,9 +708,12 @@ class WindowsUtils(base.BaseOSUtils):
while i < forward_table.dwNumEntries: while i < forward_table.dwNumEntries:
row = table[i] row = table[i]
routing_table.append(( routing_table.append((
Ws2_32.inet_ntoa(row.dwForwardDest), encoding.get_as_string(Ws2_32.inet_ntoa(
Ws2_32.inet_ntoa(row.dwForwardMask), row.dwForwardDest)),
Ws2_32.inet_ntoa(row.dwForwardNextHop), encoding.get_as_string(Ws2_32.inet_ntoa(
row.dwForwardMask)),
encoding.get_as_string(Ws2_32.inet_ntoa(
row.dwForwardNextHop)),
row.dwForwardIfIndex, row.dwForwardIfIndex,
row.dwForwardMetric1)) row.dwForwardMetric1))
i += 1 i += 1

View File

@ -22,6 +22,7 @@ from cloudbaseinit.metadata.services import base as service_base
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory from cloudbaseinit.osutils import factory as osutils_factory
from cloudbaseinit.plugins import base as plugin_base from cloudbaseinit.plugins import base as plugin_base
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -58,6 +59,7 @@ class NetworkConfigPlugin(plugin_base.BasePlugin):
content_path = network_config['content_path'] content_path = network_config['content_path']
content_name = content_path.rsplit('/', 1)[-1] content_name = content_path.rsplit('/', 1)[-1]
debian_network_conf = service.get_content(content_name) debian_network_conf = service.get_content(content_name)
debian_network_conf = encoding.get_as_string(debian_network_conf)
LOG.debug('network config content:\n%s' % debian_network_conf) LOG.debug('network config content:\n%s' % debian_network_conf)

View File

@ -47,7 +47,7 @@ class SetUserPasswordPlugin(base.BasePlugin):
def _get_ssh_public_key(self, service): def _get_ssh_public_key(self, service):
public_keys = service.get_public_keys() public_keys = service.get_public_keys()
if public_keys: if public_keys:
return public_keys[0] return list(public_keys)[0]
def _get_password(self, service, osutils): def _get_password(self, service, osutils):
if CONF.inject_user_password: if CONF.inject_user_password:

View File

@ -23,6 +23,7 @@ from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.plugins import base from cloudbaseinit.plugins import base
from cloudbaseinit.plugins.windows.userdataplugins import factory from cloudbaseinit.plugins.windows.userdataplugins import factory
from cloudbaseinit.plugins.windows import userdatautils from cloudbaseinit.plugins.windows import userdatautils
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -40,6 +41,7 @@ class UserDataPlugin(base.BasePlugin):
if not user_data: if not user_data:
return (base.PLUGIN_EXECUTION_DONE, False) return (base.PLUGIN_EXECUTION_DONE, False)
LOG.debug('User data content length: %d' % len(user_data))
user_data = self._check_gzip_compression(user_data) user_data = self._check_gzip_compression(user_data)
return self._process_user_data(user_data) return self._process_user_data(user_data)
@ -53,14 +55,16 @@ class UserDataPlugin(base.BasePlugin):
return user_data return user_data
def _parse_mime(self, user_data): def _parse_mime(self, user_data):
return email.message_from_string(user_data).walk() user_data_str = encoding.get_as_string(user_data)
LOG.debug('User data content:\n%s' % user_data_str)
return email.message_from_string(user_data_str).walk()
def _process_user_data(self, user_data): def _process_user_data(self, user_data):
plugin_status = base.PLUGIN_EXECUTION_DONE plugin_status = base.PLUGIN_EXECUTION_DONE
reboot = False reboot = False
LOG.debug('User data content:\n%s' % user_data) if user_data.startswith(b'Content-Type: multipart'):
if user_data.startswith('Content-Type: multipart'):
user_data_plugins = factory.load_plugins() user_data_plugins = factory.load_plugins()
user_handlers = {} user_handlers = {}
@ -158,7 +162,7 @@ class UserDataPlugin(base.BasePlugin):
return (plugin_status, reboot) return (plugin_status, reboot)
def _process_non_multi_part(self, user_data): def _process_non_multi_part(self, user_data):
if user_data.startswith('#cloud-config'): if user_data.startswith(b'#cloud-config'):
user_data_plugins = factory.load_plugins() user_data_plugins = factory.load_plugins()
cloud_config_plugin = user_data_plugins.get('text/cloud-config') cloud_config_plugin = user_data_plugins.get('text/cloud-config')
ret_val = cloud_config_plugin.process(user_data) ret_val = cloud_config_plugin.process(user_data)

View File

@ -20,6 +20,7 @@ from oslo.config import cfg
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.plugins.windows.userdataplugins import base from cloudbaseinit.plugins.windows.userdataplugins import base
from cloudbaseinit.plugins.windows import userdatautils from cloudbaseinit.plugins.windows import userdatautils
from cloudbaseinit.utils import encoding
opts = [ opts = [
cfg.StrOpt('heat_config_dir', default='C:\\cfn', help='The directory ' cfg.StrOpt('heat_config_dir', default='C:\\cfn', help='The directory '
@ -47,8 +48,7 @@ class HeatPlugin(base.BaseUserDataPlugin):
file_name = os.path.join(CONF.heat_config_dir, part.get_filename()) file_name = os.path.join(CONF.heat_config_dir, part.get_filename())
self._check_dir(file_name) self._check_dir(file_name)
with open(file_name, 'wb') as f: encoding.write_file(file_name, part.get_payload())
f.write(part.get_payload())
if part.get_filename() == self._heat_user_data_filename: if part.get_filename() == self._heat_user_data_filename:
return userdatautils.execute_user_data_script(part.get_payload()) return userdatautils.execute_user_data_script(part.get_payload())

View File

@ -19,6 +19,7 @@ import tempfile
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.plugins.windows import fileexecutils from cloudbaseinit.plugins.windows import fileexecutils
from cloudbaseinit.plugins.windows.userdataplugins import base from cloudbaseinit.plugins.windows.userdataplugins import base
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -32,8 +33,7 @@ class ShellScriptPlugin(base.BaseUserDataPlugin):
target_path = os.path.join(tempfile.gettempdir(), file_name) target_path = os.path.join(tempfile.gettempdir(), file_name)
try: try:
with open(target_path, 'wb') as f: encoding.write_file(target_path, part.get_payload())
f.write(part.get_payload())
return fileexecutils.exec_file(target_path) return fileexecutils.exec_file(target_path)
except Exception as ex: except Exception as ex:

View File

@ -19,6 +19,7 @@ import uuid
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory from cloudbaseinit.osutils import factory as osutils_factory
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -54,8 +55,7 @@ def execute_user_data_script(user_data):
return 0 return 0
try: try:
with open(target_path, 'wb') as f: encoding.write_file(target_path, user_data)
f.write(user_data)
if powershell: if powershell:
(out, err, (out, err,

View File

@ -126,7 +126,7 @@ class BaseOpenStackServiceTest(unittest.TestCase):
response = self._service.get_client_auth_certs() response = self._service.get_client_auth_certs()
mock_get_meta_data.assert_called_once_with() mock_get_meta_data.assert_called_once_with()
if 'meta' in meta_data: if 'meta' in meta_data:
self.assertEqual(['fake cert'], response) self.assertEqual([b'fake cert'], response)
elif type(ret_value) is str and ret_value.startswith( elif type(ret_value) is str and ret_value.startswith(
x509constants.PEM_HEADER): x509constants.PEM_HEADER):
mock_get_user_data.assert_called_once_with() mock_get_user_data.assert_called_once_with()
@ -136,11 +136,11 @@ class BaseOpenStackServiceTest(unittest.TestCase):
def test_get_client_auth_certs(self): def test_get_client_auth_certs(self):
self._test_get_client_auth_certs( self._test_get_client_auth_certs(
meta_data={'meta': {'admin_cert0': 'fake cert'}}) meta_data={'meta': {'admin_cert0': b'fake cert'}})
def test_get_client_auth_certs_no_cert_data(self): def test_get_client_auth_certs_no_cert_data(self):
self._test_get_client_auth_certs( self._test_get_client_auth_certs(
meta_data={}, ret_value=x509constants.PEM_HEADER) meta_data={}, ret_value=x509constants.PEM_HEADER.encode())
def test_get_client_auth_certs_no_cert_data_exception(self): def test_get_client_auth_certs_no_cert_data_exception(self):
self._test_get_client_auth_certs( self._test_get_client_auth_certs(

View File

@ -55,10 +55,12 @@ class SetUserPasswordPluginTests(unittest.TestCase):
def _test_get_ssh_public_key(self, data_exists): def _test_get_ssh_public_key(self, data_exists):
mock_service = mock.MagicMock() mock_service = mock.MagicMock()
public_keys = self.fake_data['public_keys'] public_keys = self.fake_data['public_keys']
mock_service.get_public_keys.return_value = public_keys mock_service.get_public_keys.return_value = public_keys.values()
response = self._setpassword_plugin._get_ssh_public_key(mock_service) response = self._setpassword_plugin._get_ssh_public_key(mock_service)
mock_service.get_public_keys.assert_called_with() mock_service.get_public_keys.assert_called_with()
self.assertEqual(public_keys[0], response) self.assertEqual(list(public_keys.values())[0], response)
def test_get_ssh_plublic_key(self): def test_get_ssh_plublic_key(self):
self._test_get_ssh_public_key(data_exists=True) self._test_get_ssh_public_key(data_exists=True)

View File

@ -58,7 +58,7 @@ class UserDataPluginTest(unittest.TestCase):
self.assertEqual(response, mock_process_user_data.return_value) self.assertEqual(response, mock_process_user_data.return_value)
def test_execute(self): def test_execute(self):
self._test_execute(ret_val=mock.sentinel.fake_data) self._test_execute(ret_val='fake_data')
def test_execute_no_data(self): def test_execute_no_data(self):
self._test_execute(ret_val=None) self._test_execute(ret_val=None)
@ -85,10 +85,15 @@ class UserDataPluginTest(unittest.TestCase):
self.assertEqual(data, response) self.assertEqual(data, response)
@mock.patch('email.message_from_string') @mock.patch('email.message_from_string')
def test_parse_mime(self, mock_message_from_string): @mock.patch('cloudbaseinit.utils.encoding.get_as_string')
def test_parse_mime(self, mock_get_as_string, mock_message_from_string):
fake_user_data = 'fake data' fake_user_data = 'fake data'
response = self._userdata._parse_mime(user_data=fake_user_data) response = self._userdata._parse_mime(user_data=fake_user_data)
mock_message_from_string.assert_called_once_with(fake_user_data)
mock_get_as_string.assert_called_once_with(fake_user_data)
mock_message_from_string.assert_called_once_with(
mock_get_as_string.return_value)
self.assertEqual(response, mock_message_from_string().walk()) self.assertEqual(response, mock_message_from_string().walk())
@mock.patch('cloudbaseinit.plugins.windows.userdataplugins.factory.' @mock.patch('cloudbaseinit.plugins.windows.userdataplugins.factory.'
@ -110,7 +115,8 @@ class UserDataPluginTest(unittest.TestCase):
mock_process_part.return_value = (base.PLUGIN_EXECUTION_DONE, reboot) mock_process_part.return_value = (base.PLUGIN_EXECUTION_DONE, reboot)
response = self._userdata._process_user_data(user_data=user_data) response = self._userdata._process_user_data(user_data=user_data)
if user_data.startswith('Content-Type: multipart'):
if user_data.startswith(b'Content-Type: multipart'):
mock_load_plugins.assert_called_once_with() mock_load_plugins.assert_called_once_with()
mock_parse_mime.assert_called_once_with(user_data) mock_parse_mime.assert_called_once_with(user_data)
mock_process_part.assert_called_once_with(mock_part, mock_process_part.assert_called_once_with(mock_part,
@ -122,15 +128,15 @@ class UserDataPluginTest(unittest.TestCase):
response) response)
def test_process_user_data_multipart_reboot_true(self): def test_process_user_data_multipart_reboot_true(self):
self._test_process_user_data(user_data='Content-Type: multipart', self._test_process_user_data(user_data=b'Content-Type: multipart',
reboot=True) reboot=True)
def test_process_user_data_multipart_reboot_false(self): def test_process_user_data_multipart_reboot_false(self):
self._test_process_user_data(user_data='Content-Type: multipart', self._test_process_user_data(user_data=b'Content-Type: multipart',
reboot=False) reboot=False)
def test_process_user_data_non_multipart(self): def test_process_user_data_non_multipart(self):
self._test_process_user_data(user_data='Content-Type: non-multipart', self._test_process_user_data(user_data=b'Content-Type: non-multipart',
reboot=False) reboot=False)
@mock.patch('cloudbaseinit.plugins.windows.userdata.UserDataPlugin' @mock.patch('cloudbaseinit.plugins.windows.userdata.UserDataPlugin'
@ -250,7 +256,7 @@ class UserDataPluginTest(unittest.TestCase):
'._get_plugin_return_value') '._get_plugin_return_value')
def test_process_non_multi_part(self, mock_get_plugin_return_value, def test_process_non_multi_part(self, mock_get_plugin_return_value,
mock_execute_user_data_script): mock_execute_user_data_script):
user_data = 'fake' user_data = b'fake'
response = self._userdata._process_non_multi_part(user_data=user_data) response = self._userdata._process_non_multi_part(user_data=user_data)
mock_execute_user_data_script.assert_called_once_with(user_data) mock_execute_user_data_script.assert_called_once_with(user_data)
mock_get_plugin_return_value.assert_called_once_with( mock_get_plugin_return_value.assert_called_once_with(
@ -263,7 +269,7 @@ class UserDataPluginTest(unittest.TestCase):
'._get_plugin_return_value') '._get_plugin_return_value')
def test_process_non_multi_part_cloud_config( def test_process_non_multi_part_cloud_config(
self, mock_get_plugin_return_value, mock_load_plugins): self, mock_get_plugin_return_value, mock_load_plugins):
user_data = '#cloud-config' user_data = b'#cloud-config'
mock_return_value = mock.sentinel.return_value mock_return_value = mock.sentinel.return_value
mock_cloud_config_plugin = mock.Mock() mock_cloud_config_plugin = mock.Mock()
mock_cloud_config_plugin.process.return_value = mock_return_value mock_cloud_config_plugin.process.return_value = mock_return_value

View File

@ -37,8 +37,9 @@ class UserDataUtilsTest(unittest.TestCase):
@mock.patch('os.path.expandvars') @mock.patch('os.path.expandvars')
@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') @mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
@mock.patch('uuid.uuid4') @mock.patch('uuid.uuid4')
def _test_execute_user_data_script(self, mock_uuid4, mock_get_os_utils, @mock.patch('cloudbaseinit.utils.encoding.write_file')
mock_path_expandvars, def _test_execute_user_data_script(self, mock_write_file, mock_uuid4,
mock_get_os_utils, mock_path_expandvars,
mock_path_exists, mock_os_remove, mock_path_exists, mock_os_remove,
mock_gettempdir, mock_re_search, mock_gettempdir, mock_re_search,
fake_user_data): fake_user_data):
@ -51,6 +52,8 @@ class UserDataUtilsTest(unittest.TestCase):
powershell = False powershell = False
mock_get_os_utils.return_value = mock_osutils mock_get_os_utils.return_value = mock_osutils
mock_path_exists.return_value = True mock_path_exists.return_value = True
extension = ''
if fake_user_data == '^rem cmd\s': if fake_user_data == '^rem cmd\s':
side_effect = [match_instance] side_effect = [match_instance]
number_of_calls = 1 number_of_calls = 1
@ -88,12 +91,14 @@ class UserDataUtilsTest(unittest.TestCase):
mock_re_search.side_effect = side_effect mock_re_search.side_effect = side_effect
with mock.patch('cloudbaseinit.plugins.windows.userdatautils.open',
mock.mock_open(), create=True):
response = userdatautils.execute_user_data_script(fake_user_data) response = userdatautils.execute_user_data_script(fake_user_data)
mock_gettempdir.assert_called_once_with() mock_gettempdir.assert_called_once_with()
self.assertEqual(number_of_calls, mock_re_search.call_count) self.assertEqual(number_of_calls, mock_re_search.call_count)
if args: if args:
mock_write_file.assert_called_once_with(path + extension,
fake_user_data)
mock_osutils.execute_process.assert_called_with(args, shell) mock_osutils.execute_process.assert_called_with(args, shell)
mock_os_remove.assert_called_once_with(path + extension) mock_os_remove.assert_called_once_with(path + extension)
self.assertEqual(None, response) self.assertEqual(None, response)
@ -106,24 +111,24 @@ class UserDataUtilsTest(unittest.TestCase):
self.assertEqual(0, response) self.assertEqual(0, response)
def test_handle_batch(self): def test_handle_batch(self):
fake_user_data = '^rem cmd\s' fake_user_data = b'^rem cmd\s'
self._test_execute_user_data_script(fake_user_data=fake_user_data) self._test_execute_user_data_script(fake_user_data=fake_user_data)
def test_handle_python(self): def test_handle_python(self):
self._test_execute_user_data_script(fake_user_data='^#!/usr/bin/env' self._test_execute_user_data_script(
'\spython\s') fake_user_data=b'^#!/usr/bin/env\spython\s')
def test_handle_shell(self): def test_handle_shell(self):
self._test_execute_user_data_script(fake_user_data='^#!') self._test_execute_user_data_script(fake_user_data=b'^#!')
def test_handle_powershell(self): def test_handle_powershell(self):
self._test_execute_user_data_script(fake_user_data='^#ps1\s') self._test_execute_user_data_script(fake_user_data=b'^#ps1\s')
def test_handle_powershell_sysnative(self): def test_handle_powershell_sysnative(self):
self._test_execute_user_data_script(fake_user_data='#ps1_sysnative\s') self._test_execute_user_data_script(fake_user_data=b'#ps1_sysnative\s')
def test_handle_powershell_sysnative_no_sysnative(self): def test_handle_powershell_sysnative_no_sysnative(self):
self._test_execute_user_data_script(fake_user_data='#ps1_x86\s') self._test_execute_user_data_script(fake_user_data=b'#ps1_x86\s')
def test_handle_unsupported_format(self): def test_handle_unsupported_format(self):
self._test_execute_user_data_script(fake_user_data='unsupported') self._test_execute_user_data_script(fake_user_data=b'unsupported')

View File

@ -50,19 +50,19 @@ class HeatUserDataHandlerTests(unittest.TestCase):
'.execute_user_data_script') '.execute_user_data_script')
@mock.patch('cloudbaseinit.plugins.windows.userdataplugins.heat' @mock.patch('cloudbaseinit.plugins.windows.userdataplugins.heat'
'.HeatPlugin._check_dir') '.HeatPlugin._check_dir')
def _test_process(self, mock_check_dir, mock_execute_user_data_script, @mock.patch('cloudbaseinit.utils.encoding.write_file')
filename): def _test_process(self, mock_write_file, mock_check_dir,
mock_execute_user_data_script, filename):
mock_part = mock.MagicMock() mock_part = mock.MagicMock()
mock_part.get_filename.return_value = filename mock_part.get_filename.return_value = filename
with mock.patch('six.moves.builtins.open', mock.mock_open(),
create=True) as handle:
response = self._heat.process(mock_part) response = self._heat.process(mock_part)
handle().write.assert_called_once_with(mock_part.get_payload())
path = os.path.join(CONF.heat_config_dir, filename) path = os.path.join(CONF.heat_config_dir, filename)
mock_check_dir.assert_called_once_with(path) mock_check_dir.assert_called_once_with(path)
mock_part.get_filename.assert_called_with() mock_part.get_filename.assert_called_with()
mock_write_file.assert_called_once_with(
path, mock_part.get_payload.return_value)
if filename == self._heat._heat_user_data_filename: if filename == self._heat._heat_user_data_filename:
mock_execute_user_data_script.assert_called_with( mock_execute_user_data_script.assert_called_with(
mock_part.get_payload()) mock_part.get_payload())

View File

@ -27,8 +27,9 @@ class ShellScriptPluginTests(unittest.TestCase):
@mock.patch('cloudbaseinit.osutils.factory.get_os_utils') @mock.patch('cloudbaseinit.osutils.factory.get_os_utils')
@mock.patch('tempfile.gettempdir') @mock.patch('tempfile.gettempdir')
@mock.patch('cloudbaseinit.plugins.windows.fileexecutils.exec_file') @mock.patch('cloudbaseinit.plugins.windows.fileexecutils.exec_file')
def _test_process(self, mock_exec_file, mock_gettempdir, mock_get_os_utils, @mock.patch('cloudbaseinit.utils.encoding.write_file')
exception=False): def _test_process(self, mock_write_file, mock_exec_file, mock_gettempdir,
mock_get_os_utils, exception=False):
fake_dir_path = os.path.join("fake", "dir") fake_dir_path = os.path.join("fake", "dir")
mock_osutils = mock.MagicMock() mock_osutils = mock.MagicMock()
mock_part = mock.MagicMock() mock_part = mock.MagicMock()
@ -45,6 +46,8 @@ class ShellScriptPluginTests(unittest.TestCase):
response = self._shellscript.process(mock_part) response = self._shellscript.process(mock_part)
mock_part.get_filename.assert_called_once_with() mock_part.get_filename.assert_called_once_with()
mock_write_file.assert_called_once_with(
fake_target, mock_part.get_payload.return_value)
mock_exec_file.assert_called_once_with(fake_target) mock_exec_file.assert_called_once_with(fake_target)
mock_part.get_payload.assert_called_once_with() mock_part.get_payload.assert_called_once_with()
mock_gettempdir.assert_called_once_with() mock_gettempdir.assert_called_once_with()

View File

@ -45,8 +45,7 @@ class DHCPUtilsTests(unittest.TestCase):
data += b'\x00' * 128 data += b'\x00' * 128
data += dhcp._DHCP_COOKIE data += dhcp._DHCP_COOKIE
data += b'\x35\x01\x01' data += b'\x35\x01\x01'
data += b'\x3c' + struct.pack('b', data += b'\x3c' + struct.pack('b', len('fake id')) + 'fake id'.encode(
len('fake id')) + 'fake id'.encode(
'ascii') 'ascii')
data += b'\x3d\x07\x01' data += b'\x3d\x07\x01'
data += fake_mac_address_b data += fake_mac_address_b

View File

@ -281,6 +281,8 @@ class CryptoAPICertManagerTests(unittest.TestCase):
fake_cert_data += x509constants.PEM_HEADER + '\n' fake_cert_data += x509constants.PEM_HEADER + '\n'
fake_cert_data += 'fake cert' + '\n' fake_cert_data += 'fake cert' + '\n'
fake_cert_data += x509constants.PEM_FOOTER fake_cert_data += x509constants.PEM_FOOTER
fake_cert_data = fake_cert_data.encode()
response = self._x509_manager._get_cert_base64(fake_cert_data) response = self._x509_manager._get_cert_base64(fake_cert_data)
self.assertEqual('fake cert', response) self.assertEqual('fake cert', response)

View File

@ -158,7 +158,7 @@ class CryptManager(object):
key_type_len = struct.unpack('>I', pub_key[offset:offset + 4])[0] key_type_len = struct.unpack('>I', pub_key[offset:offset + 4])[0]
offset += 4 offset += 4
key_type = pub_key[offset:offset + key_type_len] key_type = pub_key[offset:offset + key_type_len].decode('utf-8')
offset += key_type_len offset += key_type_len
if key_type not in ['ssh-rsa', 'rsa', 'rsa1']: if key_type not in ['ssh-rsa', 'rsa', 'rsa1']:

View File

@ -0,0 +1,35 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2014 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import six
def get_as_string(value):
if value is None or isinstance(value, six.text_type):
return value
else:
try:
return value.decode()
except Exception:
pass
def write_file(target_path, data, mode='wb'):
if isinstance(data, six.text_type) and 'b' in mode:
data = data.encode()
with open(target_path, mode) as f:
f.write(data)

View File

@ -102,8 +102,8 @@ CRYPT_STRING_BASE64 = 1
PKCS_7_ASN_ENCODING = 65536 PKCS_7_ASN_ENCODING = 65536
PROV_RSA_FULL = 1 PROV_RSA_FULL = 1
X509_ASN_ENCODING = 1 X509_ASN_ENCODING = 1
szOID_PKIX_KP_SERVER_AUTH = "1.3.6.1.5.5.7.3.1" szOID_PKIX_KP_SERVER_AUTH = b"1.3.6.1.5.5.7.3.1"
szOID_RSA_SHA1RSA = "1.2.840.113549.1.1.5" szOID_RSA_SHA1RSA = b"1.2.840.113549.1.1.5"
advapi32 = windll.advapi32 advapi32 = windll.advapi32
crypt32 = windll.crypt32 crypt32 = windll.crypt32

View File

@ -21,6 +21,7 @@ import uuid
from ctypes import wintypes from ctypes import wintypes
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import cryptoapi from cloudbaseinit.utils.windows import cryptoapi
from cloudbaseinit.utils import x509constants from cloudbaseinit.utils import x509constants
@ -202,7 +203,7 @@ class CryptoAPICertManager(object):
free(subject_encoded) free(subject_encoded)
def _get_cert_base64(self, cert_data): def _get_cert_base64(self, cert_data):
base64_cert_data = cert_data base64_cert_data = encoding.get_as_string(cert_data)
if base64_cert_data.startswith(x509constants.PEM_HEADER): if base64_cert_data.startswith(x509constants.PEM_HEADER):
base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):] base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):]
if base64_cert_data.endswith(x509constants.PEM_FOOTER): if base64_cert_data.endswith(x509constants.PEM_FOOTER):