Respawns process with new service credentials
After updating the service user credentials with a randomly generated password, the process is re-spawned using the new credentials. Change-Id: Ie793411a34e9d90ef15de2d85bddf055c69194e0 Co-Authored-By: Alexandru Coman <acoman@cloudbasesolutions.com> Co-Authored-By: Stefan Caraiman <scaraiman@cloudbasesolutions.com> Closes-Bug: #1641342
This commit is contained in:
parent
4895b36344
commit
ed57b43833
@ -121,14 +121,6 @@ class GlobalOptions(conf_base.Options):
|
||||
'the password is a clear text password, coming from the '
|
||||
'metadata. The last option is `no`, when the user is '
|
||||
'never forced to change the password.'),
|
||||
cfg.BoolOpt(
|
||||
'reset_service_password', default=True,
|
||||
help='If set to True, the service user password will be '
|
||||
'reset at each execution with a new random value of '
|
||||
'appropriate length and complexity, unless the user is '
|
||||
'a built-in or domain account.'
|
||||
'This is needed to avoid "pass the hash" attacks on '
|
||||
'Windows cloned instances.'),
|
||||
cfg.ListOpt(
|
||||
'metadata_services',
|
||||
default=[
|
||||
@ -199,9 +191,21 @@ class GlobalOptions(conf_base.Options):
|
||||
'plugins ordered by priority.'),
|
||||
]
|
||||
|
||||
self._cli_options = [
|
||||
cfg.BoolOpt(
|
||||
'reset_service_password', default=True,
|
||||
help='If set to True, the service user password will be '
|
||||
'reset at each execution with a new random value of '
|
||||
'appropriate length and complexity, unless the user is '
|
||||
'a built-in or domain account.'
|
||||
'This is needed to avoid "pass the hash" attacks on '
|
||||
'Windows cloned instances.'),
|
||||
]
|
||||
|
||||
def register(self):
|
||||
"""Register the current options to the global ConfigOpts object."""
|
||||
self._config.register_opts(self._options)
|
||||
self._config.register_cli_opts(self._cli_options)
|
||||
self._config.register_opts(self._options + self._cli_options)
|
||||
|
||||
def list(self):
|
||||
"""Return a list which contains all the available options."""
|
||||
|
@ -13,6 +13,7 @@
|
||||
# under the License.
|
||||
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
|
||||
from oslo_log import log as oslo_logging
|
||||
@ -22,6 +23,7 @@ from cloudbaseinit.metadata import factory as metadata_factory
|
||||
from cloudbaseinit.osutils import factory as osutils_factory
|
||||
from cloudbaseinit.plugins.common import base as plugins_base
|
||||
from cloudbaseinit.plugins import factory as plugins_factory
|
||||
from cloudbaseinit.utils import log as logging
|
||||
from cloudbaseinit import version
|
||||
|
||||
|
||||
@ -113,13 +115,55 @@ class InitManager(object):
|
||||
|
||||
return reboot_required
|
||||
|
||||
def configure_host(self):
|
||||
LOG.info('Cloudbase-Init version: %s', version.get_version())
|
||||
@staticmethod
|
||||
def _reset_service_password_and_respawn(osutils):
|
||||
"""Avoid pass the hash attacks from cloned instances."""
|
||||
credentials = osutils.reset_service_password()
|
||||
if not credentials:
|
||||
return
|
||||
|
||||
service_domain, service_user, service_password = credentials
|
||||
_, current_user = osutils.get_current_user()
|
||||
# Notes(alexcoman): No need to check domain as password reset applies
|
||||
# to local users only.
|
||||
if current_user != service_user:
|
||||
LOG.debug("No need to respawn process. Current user: "
|
||||
"%(current_user)s. Service user: "
|
||||
"%(service_user)s",
|
||||
{"current_user": current_user,
|
||||
"service_user": service_user})
|
||||
return
|
||||
|
||||
# Note(alexcoman): In order to avoid conflicts caused by the logging
|
||||
# handlers being shared between the current process and the new one,
|
||||
# any logging handlers for the current logger object will be closed.
|
||||
# By doing so, the next time the logger is called, it will be created
|
||||
# under the newly updated proccess, thus avoiding any issues or
|
||||
# conflicts where the logging can't be done.
|
||||
logging.release_logging_handlers("cloudbaseinit")
|
||||
|
||||
# Note(alexcoman): In some edge cases the sys.args doesn't contain
|
||||
# the python executable. In order to avoid this kind of issue the
|
||||
# sys.executable will be injected into the arguments if it's necessary.
|
||||
arguments = sys.argv + ["--noreset_service_password"]
|
||||
if os.path.basename(arguments[0]).endswith(".py"):
|
||||
arguments.insert(0, sys.executable)
|
||||
|
||||
LOG.info("Respawning current process with updated credentials.")
|
||||
token = osutils.create_user_logon_session(
|
||||
service_user, service_password, service_domain,
|
||||
logon_type=osutils.LOGON32_LOGON_BATCH)
|
||||
exit_code = osutils.execute_process_as_user(token, arguments)
|
||||
LOG.info("Process execution ended with exit code: %s", exit_code)
|
||||
sys.exit(exit_code)
|
||||
|
||||
def configure_host(self):
|
||||
osutils = osutils_factory.get_os_utils()
|
||||
if CONF.reset_service_password:
|
||||
# Avoid pass the hash attacks from cloned instances
|
||||
osutils.reset_service_password()
|
||||
|
||||
if CONF.reset_service_password and sys.platform == 'win32':
|
||||
self._reset_service_password_and_respawn(osutils)
|
||||
|
||||
LOG.info('Cloudbase-Init version: %s', version.get_version())
|
||||
osutils.wait_for_boot_completion()
|
||||
|
||||
reboot_required = self._handle_plugins_stage(
|
||||
|
@ -129,3 +129,7 @@ class BaseOSUtils(object):
|
||||
def get_service_username(self, service_name):
|
||||
"""Retrieve the username under which a service runs."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_user(self):
|
||||
"""Retrieve the username under which the current thread runs."""
|
||||
raise NotImplementedError()
|
||||
|
@ -18,6 +18,7 @@ from ctypes import wintypes
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
from oslo_log import log as oslo_logging
|
||||
@ -57,6 +58,7 @@ Ws2_32 = ctypes.windll.Ws2_32
|
||||
setupapi = ctypes.windll.setupapi
|
||||
msvcrt = ctypes.cdll.msvcrt
|
||||
ntdll = ctypes.windll.ntdll
|
||||
secur32 = ctypes.windll.secur32
|
||||
|
||||
|
||||
class Win32_PROFILEINFO(ctypes.Structure):
|
||||
@ -144,6 +146,53 @@ class Win32_STORAGE_DEVICE_NUMBER(ctypes.Structure):
|
||||
]
|
||||
|
||||
|
||||
class Win32_STARTUPINFO_W(ctypes.Structure):
|
||||
_fields_ = [
|
||||
('cb', wintypes.DWORD),
|
||||
('lpReserved', wintypes.LPWSTR),
|
||||
('lpDesktop', wintypes.LPWSTR),
|
||||
('lpTitle', wintypes.LPWSTR),
|
||||
('dwX', wintypes.DWORD),
|
||||
('dwY', wintypes.DWORD),
|
||||
('dwXSize', wintypes.DWORD),
|
||||
('dwYSize', wintypes.DWORD),
|
||||
('dwXCountChars', wintypes.DWORD),
|
||||
('dwYCountChars', wintypes.DWORD),
|
||||
('dwFillAttribute', wintypes.DWORD),
|
||||
('dwFlags', wintypes.DWORD),
|
||||
('wShowWindow', wintypes.WORD),
|
||||
('cbReserved2', wintypes.WORD),
|
||||
('lpReserved2', ctypes.POINTER(wintypes.BYTE)),
|
||||
('hStdInput', wintypes.HANDLE),
|
||||
('hStdOutput', wintypes.HANDLE),
|
||||
('hStdError', wintypes.HANDLE),
|
||||
]
|
||||
|
||||
|
||||
class Win32_PROCESS_INFORMATION(ctypes.Structure):
|
||||
_fields_ = [
|
||||
('hProcess', wintypes.HANDLE),
|
||||
('hThread', wintypes.HANDLE),
|
||||
('dwProcessId', wintypes.DWORD),
|
||||
('dwThreadId', wintypes.DWORD),
|
||||
]
|
||||
|
||||
|
||||
advapi32.CreateProcessAsUserW.argtypes = [wintypes.HANDLE,
|
||||
wintypes.LPCWSTR,
|
||||
wintypes.LPWSTR,
|
||||
ctypes.c_void_p,
|
||||
ctypes.c_void_p,
|
||||
wintypes.BOOL,
|
||||
wintypes.DWORD,
|
||||
ctypes.c_void_p,
|
||||
wintypes.LPCWSTR,
|
||||
ctypes.POINTER(
|
||||
Win32_STARTUPINFO_W),
|
||||
ctypes.POINTER(
|
||||
Win32_PROCESS_INFORMATION)]
|
||||
advapi32.CreateProcessAsUserW.restype = wintypes.BOOL
|
||||
|
||||
msvcrt.malloc.argtypes = [ctypes.c_size_t]
|
||||
msvcrt.malloc.restype = ctypes.c_void_p
|
||||
|
||||
@ -201,6 +250,11 @@ iphlpapi.GetIpForwardTable.restype = wintypes.DWORD
|
||||
|
||||
Ws2_32.inet_ntoa.restype = ctypes.c_char_p
|
||||
|
||||
secur32.GetUserNameExW.argtypes = [wintypes.DWORD,
|
||||
wintypes.LPWSTR,
|
||||
ctypes.POINTER(wintypes.ULONG)]
|
||||
secur32.GetUserNameExW.restype = wintypes.BOOL
|
||||
|
||||
setupapi.SetupDiGetClassDevsW.argtypes = [ctypes.POINTER(disk.GUID),
|
||||
wintypes.LPCWSTR,
|
||||
wintypes.HANDLE,
|
||||
@ -269,6 +323,18 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
|
||||
DRIVE_CDROM = 5
|
||||
|
||||
INFINITE = 0xFFFFFFFF
|
||||
|
||||
CREATE_NEW_CONSOLE = 0x10
|
||||
|
||||
LOGON32_LOGON_BATCH = 4
|
||||
LOGON32_LOGON_INTERACTIVE = 2
|
||||
LOGON32_LOGON_SERVICE = 5
|
||||
|
||||
LOGON32_PROVIDER_DEFAULT = 0
|
||||
|
||||
EXTENDED_NAME_FORMAT_SAM_COMPATIBLE = 2
|
||||
|
||||
SERVICE_STATUS_STOPPED = "Stopped"
|
||||
SERVICE_STATUS_START_PENDING = "Start Pending"
|
||||
SERVICE_STATUS_STOP_PENDING = "Stop Pending"
|
||||
@ -417,11 +483,17 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
pass
|
||||
|
||||
def create_user_logon_session(self, username, password, domain='.',
|
||||
load_profile=True):
|
||||
load_profile=True,
|
||||
logon_type=LOGON32_LOGON_INTERACTIVE):
|
||||
LOG.debug("Creating logon session for user: %(domain)s\\%(username)s",
|
||||
{"username": username, "domain": domain})
|
||||
|
||||
token = wintypes.HANDLE()
|
||||
ret_val = advapi32.LogonUserW(six.text_type(username),
|
||||
six.text_type(domain),
|
||||
six.text_type(password), 2, 0,
|
||||
six.text_type(password),
|
||||
logon_type,
|
||||
self.LOGON32_PROVIDER_DEFAULT,
|
||||
ctypes.byref(token))
|
||||
if not ret_val:
|
||||
raise exception.WindowsCloudbaseInitException(
|
||||
@ -439,6 +511,70 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
|
||||
return token
|
||||
|
||||
def get_current_user(self):
|
||||
"""Get the user account name from the underlying instance."""
|
||||
buf_len = wintypes.ULONG(512)
|
||||
buf = ctypes.create_unicode_buffer(512)
|
||||
|
||||
ret_val = secur32.GetUserNameExW(
|
||||
self.EXTENDED_NAME_FORMAT_SAM_COMPATIBLE,
|
||||
buf, ctypes.byref(buf_len))
|
||||
if not ret_val:
|
||||
raise exception.WindowsCloudbaseInitException(
|
||||
"GetUserNameExW failed: %r")
|
||||
|
||||
return buf.value.split("\\")
|
||||
|
||||
def execute_process_as_user(self, token, args, wait=True,
|
||||
new_console=False):
|
||||
"""Executes processes as an user.
|
||||
|
||||
:param token: Represents the user logon session token, resulted from
|
||||
running the 'create_user_logon_session' method.
|
||||
:param args: The arguments with which the process will be runned with.
|
||||
:param wait: Specifies if it's needed to wait for the process
|
||||
handler to finish up running all the operations
|
||||
on the process object.
|
||||
:param new_console: Specifies whether the process should run
|
||||
under a new console or not.
|
||||
:return: The exit code value resulted from the running process.
|
||||
:rtype: int
|
||||
"""
|
||||
LOG.debug("Executing process as user, command line: %s", args)
|
||||
|
||||
proc_info = Win32_PROCESS_INFORMATION()
|
||||
startup_info = Win32_STARTUPINFO_W()
|
||||
startup_info.cb = ctypes.sizeof(Win32_STARTUPINFO_W)
|
||||
startup_info.lpDesktop = ""
|
||||
|
||||
flags = self.CREATE_NEW_CONSOLE if new_console else 0
|
||||
cmdline = ctypes.create_unicode_buffer(subprocess.list2cmdline(args))
|
||||
|
||||
try:
|
||||
ret_val = advapi32.CreateProcessAsUserW(
|
||||
token, None, cmdline, None, None, False, flags, None, None,
|
||||
ctypes.byref(startup_info), ctypes.byref(proc_info))
|
||||
if not ret_val:
|
||||
raise exception.WindowsCloudbaseInitException(
|
||||
"CreateProcessAsUserW failed: %r")
|
||||
|
||||
if wait and proc_info.hProcess:
|
||||
kernel32.WaitForSingleObject(
|
||||
proc_info.hProcess, self.INFINITE)
|
||||
|
||||
exit_code = wintypes.DWORD()
|
||||
if not kernel32.GetExitCodeProcess(
|
||||
proc_info.hProcess, ctypes.byref(exit_code)):
|
||||
raise exception.WindowsCloudbaseInitException(
|
||||
"GetExitCodeProcess failed: %r")
|
||||
|
||||
return exit_code.value
|
||||
finally:
|
||||
if proc_info.hProcess:
|
||||
kernel32.CloseHandle(proc_info.hProcess)
|
||||
if proc_info.hThread:
|
||||
kernel32.CloseHandle(proc_info.hThread)
|
||||
|
||||
def close_user_logon_session(self, token):
|
||||
kernel32.CloseHandle(token)
|
||||
|
||||
@ -773,19 +909,19 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
"""This is needed to avoid pass the hash attacks."""
|
||||
if not self.check_service_exists(self._service_name):
|
||||
LOG.info("Service does not exist: %s", self._service_name)
|
||||
return False
|
||||
return None
|
||||
|
||||
service_username = self.get_service_username(self._service_name)
|
||||
# Ignore builtin accounts
|
||||
if "\\" not in service_username:
|
||||
LOG.info("Skipping password reset, service running as a built-in "
|
||||
"account: %s", service_username)
|
||||
return False
|
||||
return None
|
||||
domain, username = service_username.split('\\')
|
||||
if domain != ".":
|
||||
LOG.info("Skipping password reset, service running as a domain "
|
||||
"account: %s", service_username)
|
||||
return False
|
||||
return None
|
||||
|
||||
LOG.debug('Resetting password for service user: %s', service_username)
|
||||
maximum_length = self.get_maximum_password_length()
|
||||
@ -793,7 +929,7 @@ class WindowsUtils(base.BaseOSUtils):
|
||||
self.set_user_password(username, password)
|
||||
self.set_service_credentials(
|
||||
self._service_name, service_username, password)
|
||||
return True
|
||||
return domain, username, password
|
||||
|
||||
def terminate(self):
|
||||
# Wait for the service to start. Polling the service "Started" property
|
||||
|
@ -2114,3 +2114,82 @@ class TestWindowsUtils(testutils.CloudbaseInitTestBase):
|
||||
self.assertEqual('dwForwardNextHop', given_route[2])
|
||||
self.assertEqual('dwForwardIfIndex', given_route[3])
|
||||
self.assertEqual('dwForwardMetric1', given_route[4])
|
||||
|
||||
def test_get_current_user(self):
|
||||
response = mock.Mock()
|
||||
response.value.split.return_value = mock.sentinel.user
|
||||
secur32 = self._ctypes_mock.windll.secur32
|
||||
self._ctypes_mock.create_unicode_buffer.return_value = response
|
||||
secur32.GetUserNameExW.side_effect = [True, False]
|
||||
|
||||
self.assertIs(self._winutils.get_current_user(), mock.sentinel.user)
|
||||
with self.assert_raises_windows_message("GetUserNameExW failed: %r",
|
||||
100):
|
||||
self._winutils.get_current_user()
|
||||
|
||||
@mock.patch('cloudbaseinit.osutils.windows.Win32_STARTUPINFO_W')
|
||||
@mock.patch('cloudbaseinit.osutils.windows.Win32_PROCESS_INFORMATION')
|
||||
@mock.patch('subprocess.list2cmdline')
|
||||
def _test_execute_process_as_user(self, mock_list2cmdline, mock_proc_info,
|
||||
mock_startup_info,
|
||||
token, args, wait, new_console):
|
||||
advapi32 = self._windll_mock.advapi32
|
||||
advapi32.CreateProcessAsUserW.return_value = True
|
||||
kernel32 = self._ctypes_mock.windll.kernel32
|
||||
kernel32.GetExitCodeProcess.return_value = True
|
||||
|
||||
proc_info = mock.Mock()
|
||||
proc_info.hProcess = wait
|
||||
proc_info.hThread = wait
|
||||
mock_proc_info.return_value = proc_info
|
||||
|
||||
command_line = mock.sentinel.command_line
|
||||
self._ctypes_mock.create_unicode_buffer.return_value = command_line
|
||||
|
||||
self._winutils.execute_process_as_user(token, args, wait, new_console)
|
||||
|
||||
self.assertEqual(advapi32.CreateProcessAsUserW.call_count, 1)
|
||||
if wait:
|
||||
kernel32.WaitForSingleObject.assert_called_once_with(
|
||||
proc_info.hProcess, self._winutils.INFINITE
|
||||
)
|
||||
self.assertEqual(kernel32.GetExitCodeProcess.call_count, 1)
|
||||
|
||||
if wait:
|
||||
self.assertEqual(kernel32.CloseHandle.call_count, 2)
|
||||
|
||||
mock_list2cmdline.assert_called_once_with(args)
|
||||
|
||||
def test_execute_process_as_user(self):
|
||||
self._test_execute_process_as_user(token=mock.sentinel.token,
|
||||
args=mock.sentinel.args,
|
||||
wait=False, new_console=False)
|
||||
|
||||
def test_execute_process_as_user_with_wait(self):
|
||||
self._test_execute_process_as_user(token=mock.sentinel.token,
|
||||
args=mock.sentinel.args,
|
||||
wait=False, new_console=False)
|
||||
|
||||
@mock.patch('cloudbaseinit.osutils.windows.Win32_STARTUPINFO_W')
|
||||
@mock.patch('cloudbaseinit.osutils.windows.Win32_PROCESS_INFORMATION')
|
||||
@mock.patch('subprocess.list2cmdline')
|
||||
def test_execute_process_as_user_fail(self, mock_list2cmdline,
|
||||
mock_proc_info, mock_startup_info):
|
||||
advapi32 = self._windll_mock.advapi32
|
||||
advapi32.CreateProcessAsUserW.side_effect = [False, True]
|
||||
kernel32 = self._ctypes_mock.windll.kernel32
|
||||
kernel32.GetExitCodeProcess.return_value = False
|
||||
mock_proc_info.hProcess = True
|
||||
|
||||
token = mock.sentinel.token
|
||||
args = mock.sentinel.args
|
||||
new_console = mock.sentinel.new_console
|
||||
|
||||
with self.assert_raises_windows_message("CreateProcessAsUserW "
|
||||
"failed: %r", 100):
|
||||
self._winutils.execute_process_as_user(token, args, False,
|
||||
new_console)
|
||||
with self.assert_raises_windows_message("GetExitCodeProcess "
|
||||
"failed: %r", 100):
|
||||
self._winutils.execute_process_as_user(token, args, True,
|
||||
new_console)
|
||||
|
@ -180,6 +180,8 @@ class TestInitManager(unittest.TestCase):
|
||||
def test_handle_plugins_stage_no_fast_reboot(self):
|
||||
self._test_handle_plugins_stage(fast_reboot=False)
|
||||
|
||||
@mock.patch('cloudbaseinit.init.InitManager.'
|
||||
'_reset_service_password_and_respawn')
|
||||
@mock.patch('cloudbaseinit.init.InitManager'
|
||||
'._handle_plugins_stage')
|
||||
@mock.patch('cloudbaseinit.init.InitManager._check_latest_version')
|
||||
@ -190,10 +192,10 @@ class TestInitManager(unittest.TestCase):
|
||||
def _test_configure_host(self, mock_get_metadata_service,
|
||||
mock_get_os_utils, mock_load_plugins,
|
||||
mock_get_version, mock_check_latest_version,
|
||||
mock_handle_plugins_stage,
|
||||
mock_handle_plugins_stage, mock_reset_service,
|
||||
expected_logging,
|
||||
version, name, instance_id, reboot=True):
|
||||
|
||||
sys.platform = 'win32'
|
||||
mock_get_version.return_value = version
|
||||
fake_service = mock.MagicMock()
|
||||
fake_plugin = mock.MagicMock()
|
||||
@ -218,7 +220,8 @@ class TestInitManager(unittest.TestCase):
|
||||
self.assertEqual(expected_logging, snatcher.output)
|
||||
mock_check_latest_version.assert_called_once_with()
|
||||
if CONF.reset_service_password:
|
||||
self.osutils.reset_service_password.assert_called_once_with()
|
||||
mock_reset_service.assert_called_once_with(self.osutils)
|
||||
|
||||
self.osutils.wait_for_boot_completion.assert_called_once_with()
|
||||
mock_get_metadata_service.assert_called_once_with()
|
||||
fake_service.get_name.assert_called_once_with()
|
||||
@ -283,3 +286,57 @@ class TestInitManager(unittest.TestCase):
|
||||
mock_partial.return_value)
|
||||
mock_partial.assert_called_once_with(
|
||||
init.LOG.info, 'Found new version of cloudbase-init %s')
|
||||
|
||||
@mock.patch('os.path.basename')
|
||||
@mock.patch("sys.executable")
|
||||
@mock.patch("sys.argv")
|
||||
@mock.patch("sys.exit")
|
||||
def _test_reset_service_password_and_respawn(self, mock_exit, mock_argv,
|
||||
mock_executable, mock_os_path,
|
||||
credentials, current_user):
|
||||
token = mock.sentinel.token
|
||||
self.osutils.create_user_logon_session.return_value = token
|
||||
self.osutils.execute_process_as_user.return_value = 0
|
||||
self.osutils.reset_service_password.return_value = credentials
|
||||
self.osutils.get_current_user.return_value = current_user
|
||||
expected_logging = []
|
||||
arguments = sys.argv + ["--noreset_service_password"]
|
||||
|
||||
with testutils.LogSnatcher('cloudbaseinit.init') as snatcher:
|
||||
self._init._reset_service_password_and_respawn(self.osutils)
|
||||
|
||||
if not credentials:
|
||||
return
|
||||
|
||||
if credentials[1] != current_user[1]:
|
||||
expected_logging = [
|
||||
"No need to respawn process. Current user: "
|
||||
"%(current_user)s. Service user: %(service_user)s" %
|
||||
{"current_user": current_user[1],
|
||||
"service_user": credentials[1]}
|
||||
]
|
||||
self.assertEqual(expected_logging, snatcher.output)
|
||||
else:
|
||||
self.osutils.create_user_logon_session.assert_called_once_with(
|
||||
credentials[1], credentials[2], credentials[0],
|
||||
logon_type=self.osutils.LOGON32_LOGON_BATCH)
|
||||
self.osutils.execute_process_as_user.assert_called_once_with(
|
||||
token, arguments)
|
||||
mock_exit.assert_called_once_with(0)
|
||||
|
||||
def test_reset_service_password_and_respawn(self):
|
||||
current_user = [mock.sentinel.domain, mock.sentinel.current_user]
|
||||
self._test_reset_service_password_and_respawn(
|
||||
credentials=None,
|
||||
current_user=current_user
|
||||
)
|
||||
self._test_reset_service_password_and_respawn(
|
||||
credentials=[mock.sentinel.domain, mock.sentinel.user,
|
||||
mock.sentinel.password],
|
||||
current_user=current_user
|
||||
)
|
||||
self._test_reset_service_password_and_respawn(
|
||||
credentials=[mock.sentinel.domain, mock.sentinel.current_user,
|
||||
mock.sentinel.password],
|
||||
current_user=current_user
|
||||
)
|
||||
|
@ -19,7 +19,6 @@ try:
|
||||
import unittest.mock as mock
|
||||
except ImportError:
|
||||
import mock
|
||||
import six
|
||||
|
||||
from cloudbaseinit import conf as cloudbaseinit_conf
|
||||
|
||||
@ -43,33 +42,12 @@ class SerialPortHandlerTests(unittest.TestCase):
|
||||
self._old_value = CONF.get('logging_serial_port_settings')
|
||||
CONF.set_override('logging_serial_port_settings', "COM1,115200,N,8")
|
||||
self._serial_port_handler = self.log.SerialPortHandler()
|
||||
self._unicode_stream = self._serial_port_handler._UnicodeToBytesStream(
|
||||
self._stream)
|
||||
self._serial_port_handler._port = mock.MagicMock()
|
||||
self._serial_port_handler.stream = mock.MagicMock()
|
||||
|
||||
def tearDown(self):
|
||||
self._module_patcher.stop()
|
||||
CONF.set_override('logging_serial_port_settings', self._old_value)
|
||||
|
||||
def test_init(self):
|
||||
mock_Serial = self._serial.Serial
|
||||
mock_Serial.return_value.isOpen.return_value = False
|
||||
|
||||
self.log.SerialPortHandler()
|
||||
|
||||
mock_Serial.assert_called_with(bytesize=8, baudrate=115200,
|
||||
port='COM1', parity='N')
|
||||
mock_Serial.return_value.isOpen.assert_called_with()
|
||||
mock_Serial.return_value.open.assert_called_once_with()
|
||||
|
||||
def test_close(self):
|
||||
self._serial_port_handler._port.isOpen.return_value = True
|
||||
|
||||
self._serial_port_handler.close()
|
||||
|
||||
self._serial_port_handler._port.isOpen.assert_called_once_with()
|
||||
self._serial_port_handler._port.close.assert_called_once_with()
|
||||
|
||||
@mock.patch('oslo_log.log.setup')
|
||||
@mock.patch('oslo_log.log.getLogger')
|
||||
@mock.patch('cloudbaseinit.utils.log.SerialPortHandler')
|
||||
@ -89,28 +67,3 @@ class SerialPortHandlerTests(unittest.TestCase):
|
||||
|
||||
mock_SerialPortHandler().setFormatter.assert_called_once_with(
|
||||
mock_ContextFormatter())
|
||||
|
||||
def _test_unicode_write(self, is_six_instance=False):
|
||||
self._stream.isOpen.return_value = False
|
||||
if is_six_instance:
|
||||
fake_data = mock.MagicMock(spec=six.text_type)
|
||||
fake_data.encode = mock.MagicMock()
|
||||
else:
|
||||
fake_data = mock.MagicMock()
|
||||
|
||||
self._unicode_stream.write(fake_data)
|
||||
|
||||
self._stream.isOpen.assert_called_once_with()
|
||||
self._stream.open.assert_called_once_with()
|
||||
if is_six_instance:
|
||||
self._stream.write.assert_called_once_with(
|
||||
fake_data.encode.return_value)
|
||||
fake_data.encode.assert_called_once_with('utf-8')
|
||||
else:
|
||||
self._stream.write.assert_called_once_with(fake_data)
|
||||
|
||||
def test_unicode_write(self):
|
||||
self._test_unicode_write()
|
||||
|
||||
def test_unicode_write_with_encode(self):
|
||||
self._test_unicode_write(is_six_instance=True)
|
||||
|
@ -25,45 +25,70 @@ CONF = cloudbaseinit_conf.CONF
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_write(function):
|
||||
"""Avoid issues related to unicode strings handling."""
|
||||
def _wrapper(message):
|
||||
# Unicode strings are not properly handled by the serial module
|
||||
if isinstance(message, six.text_type):
|
||||
function(message.encode("utf-8"))
|
||||
else:
|
||||
function(message)
|
||||
return _wrapper
|
||||
|
||||
|
||||
def release_logging_handlers(product_name):
|
||||
"""Closes any currently used logging port handlers.
|
||||
|
||||
Resulting in the stream, file and serial port handler being closed
|
||||
and removed from the logging object.
|
||||
"""
|
||||
log_root = log.getLogger(product_name).logger
|
||||
for handler in log_root.handlers:
|
||||
log_root.removeHandler(handler)
|
||||
handler.close()
|
||||
|
||||
|
||||
class SerialPortHandler(logging.StreamHandler):
|
||||
|
||||
class _UnicodeToBytesStream(object):
|
||||
|
||||
def __init__(self, stream):
|
||||
self._stream = stream
|
||||
|
||||
def write(self, data):
|
||||
if self._stream and not self._stream.isOpen():
|
||||
self._stream.open()
|
||||
|
||||
if isinstance(data, six.text_type):
|
||||
self._stream.write(data.encode("utf-8"))
|
||||
else:
|
||||
self._stream.write(data)
|
||||
|
||||
def __init__(self):
|
||||
self._port = None
|
||||
super(SerialPortHandler, self).__init__(None)
|
||||
self.stream = None
|
||||
|
||||
@staticmethod
|
||||
def _open():
|
||||
serial_port = None
|
||||
if CONF.logging_serial_port_settings:
|
||||
settings = CONF.logging_serial_port_settings.split(',')
|
||||
|
||||
try:
|
||||
self._port = serial.Serial(port=settings[0],
|
||||
baudrate=int(settings[1]),
|
||||
parity=settings[2],
|
||||
bytesize=int(settings[3]))
|
||||
if not self._port.isOpen():
|
||||
self._port.open()
|
||||
except serial.SerialException as ex:
|
||||
# Log to other handlers
|
||||
LOG.exception(ex)
|
||||
serial_port = serial.Serial(port=settings[0],
|
||||
baudrate=int(settings[1]),
|
||||
parity=settings[2],
|
||||
bytesize=int(settings[3]))
|
||||
if not serial_port.isOpen():
|
||||
serial_port.open()
|
||||
serial_port.write = _safe_write(serial_port.write)
|
||||
except serial.SerialException as exc:
|
||||
LOG.debug(exc)
|
||||
return serial_port
|
||||
|
||||
# Unicode strings are not properly handled by the serial module
|
||||
super(SerialPortHandler, self).__init__(
|
||||
self._UnicodeToBytesStream(self._port))
|
||||
def emit(self, record):
|
||||
"""Emit a record."""
|
||||
if self.stream is None:
|
||||
self.stream = self._open()
|
||||
|
||||
super(SerialPortHandler, self).emit(record)
|
||||
|
||||
def close(self):
|
||||
if self._port and self._port.isOpen():
|
||||
self._port.close()
|
||||
"""Closes the serial port."""
|
||||
self.acquire()
|
||||
try:
|
||||
serial_port = self.stream
|
||||
if serial_port and serial_port.isOpen():
|
||||
self.stream = None
|
||||
serial_port.close()
|
||||
logging.Handler.close(self)
|
||||
finally:
|
||||
self.release()
|
||||
|
||||
|
||||
def setup(product_name):
|
||||
|
Loading…
x
Reference in New Issue
Block a user