# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import collections import copy import json import logging import os import shlex import tempfile from oslo_concurrency import processutils import yaml from os_faults.api import error LOG = logging.getLogger(__name__) STDOUT_LIMIT = 4096 # Symbols count ANSIBLE_FORKS = 100 STATUS_OK = 'OK' STATUS_FAILED = 'FAILED' STATUS_UNREACHABLE = 'UNREACHABLE' STATUS_SKIPPED = 'SKIPPED' DEFAULT_ERROR_STATUSES = {STATUS_FAILED, STATUS_UNREACHABLE} SSH_COMMON_ARGS = ('-o UserKnownHostsFile=/dev/null ' '-o StrictHostKeyChecking=no ' '-o ConnectTimeout=60') class AnsibleExecutionException(Exception): pass class AnsibleExecutionUnreachable(AnsibleExecutionException): pass AnsibleExecutionRecord = collections.namedtuple( 'AnsibleExecutionRecord', ['host', 'status', 'task', 'payload']) def find_ansible(): stdout, stderr = processutils.execute( *shlex.split('which ansible-playbook'), check_exit_code=[0, 1]) if not stdout: raise AnsibleExecutionException( 'Ansible executable is not found in $PATH') return stdout[:-1] def resolve_relative_path(file_name): path = os.path.normpath(os.path.join( os.path.dirname(__import__('os_faults').__file__), '../', file_name)) if os.path.exists(path): return path MODULE_PATHS = { resolve_relative_path('os_faults/ansible/modules'), } def get_module_paths(): global MODULE_PATHS return MODULE_PATHS def add_module_paths(paths): global MODULE_PATHS for path in paths: if not os.path.exists(path): raise error.OSFError('{!r} does not exist'.format(path)) # find all subfolders dirs = [x[0] for x in os.walk(path)] MODULE_PATHS.update(dirs) def make_module_path_option(): # now it is a list of strings (MUST have > 1 element) module_path = list(get_module_paths()) if len(module_path) == 1: module_path += [module_path[0]] return module_path Options = collections.namedtuple( 'Options', ['connection', 'module_path', 'forks']) class AnsibleRunner(object): def __init__(self, auth=None, forks=ANSIBLE_FORKS, serial=None): super(AnsibleRunner, self).__init__() self.default_host_vars = self._build_auth_host_vars(auth) self.options = Options( connection='smart', module_path=make_module_path_option(), forks=forks) self.serial = serial or 10 self.ansible = find_ansible() @staticmethod def _build_proxy_arg(jump_user, jump_host, private_key_file=None): key = '-i ' + private_key_file if private_key_file else '' return (' -o ProxyCommand="ssh %(key)s -W %%h:%%p %(ssh_args)s ' '%(user)s@%(host)s"' % dict(key=key, user=jump_user, host=jump_host, ssh_args=SSH_COMMON_ARGS)) def _run_play(self, play_source, host_vars): inventory = {} for host, variables in host_vars.items(): host_vars = dict((k, v) for k, v in self.default_host_vars.items() if v) host_vars.update(dict((k, v) for k, v in variables.items() if v)) inventory[host] = host_vars inventory[host]['ansible_connection'] = self.options.connection full_inventory = {'all': {'hosts': inventory}} temp_dir = tempfile.mkdtemp(prefix='os-faults') inventory_file_name = os.path.join(temp_dir, 'inventory') playbook_file_name = os.path.join(temp_dir, 'playbook') with open(inventory_file_name, 'w') as fd: cnt = yaml.safe_dump(full_inventory, default_flow_style=False) print(cnt, file=fd) LOG.debug('Inventory:\n%s', cnt) play = { 'hosts': 'all', 'gather_facts': 'no', 'tasks': play_source['tasks'], } with open(playbook_file_name, 'w') as fd: cnt = yaml.safe_dump([play], default_flow_style=False) print(cnt, file=fd) LOG.debug('Playbook:\n%s', cnt) cmd = ('%(ansible)s --module-path %(module_path)s ' '-i %(inventory)s %(playbook)s' % {'ansible': self.ansible, 'module_path': ':'.join(self.options.module_path), 'inventory': inventory_file_name, 'playbook': playbook_file_name}) logging.info('Executing %s' % cmd) command_stdout, command_stderr = processutils.execute( *shlex.split(cmd), env_variables={'ANSIBLE_STDOUT_CALLBACK': 'json'}, check_exit_code=False) d = json.loads(command_stdout[command_stdout.find('{'):]) h = d['plays'][0]['tasks'][0]['hosts'] recs = [] for h, hv in h.items(): if hv.get('unreachable'): status = STATUS_UNREACHABLE elif hv.get('failed'): status = STATUS_FAILED else: status = STATUS_OK r = AnsibleExecutionRecord(host=h, status=status, task='', payload=hv) recs.append(r) return recs def run_playbook(self, playbook, host_vars): result = [] for play_source in playbook: play_source['gather_facts'] = 'no' result += self._run_play(play_source, host_vars) return result def execute(self, hosts, task, raise_on_statuses=None): """Executes the task on every host from the list Raises exception if any of the commands fails with one of specified statuses. :param hosts: list of host addresses :param task: Ansible task :param raise_on_statuses: raise exception if any of commands return any of these statuses :return: execution result, type AnsibleExecutionRecord """ if raise_on_statuses is None: raise_on_statuses = DEFAULT_ERROR_STATUSES LOG.debug('Executing task: %s on hosts: %s with serial: %s', task, hosts, self.serial) host_vars = {h.ip: self._build_auth_host_vars(h.auth) for h in hosts} task_play = {'hosts': [h.ip for h in hosts], 'tasks': [task], 'serial': self.serial} result = self.run_playbook([task_play], host_vars) log_result = copy.deepcopy(result) LOG.debug('Execution completed with %s result(s):' % len(log_result)) for lr in log_result: if 'stdout' in lr.payload: if len(lr.payload['stdout']) > STDOUT_LIMIT: lr.payload['stdout'] = ( lr.payload['stdout'][:STDOUT_LIMIT] + '... ') if 'stdout_lines' in lr.payload: del lr.payload['stdout_lines'] LOG.debug(lr) if raise_on_statuses: errors = [] only_unreachable = True for r in result: if r.status in raise_on_statuses: if r.status != STATUS_UNREACHABLE: only_unreachable = False errors.append(r) if errors: msg = 'Execution failed: %s' % ', '.join(( '(host: %s, status: %s)' % (r.host, r.status)) for r in errors) ek = (AnsibleExecutionUnreachable if only_unreachable else AnsibleExecutionException) raise ek(msg) return result def _build_auth_host_vars(self, auth): if not auth: return {} ssh_common_args = SSH_COMMON_ARGS if 'jump' in auth: ssh_common_args += self._build_proxy_arg( jump_host=auth['jump']['host'], jump_user=auth['jump'].get( 'username', auth.get('username')), private_key_file=auth['jump'].get('private_key_file')) return { 'ansible_user': auth.get('username'), 'ansible_ssh_pass': auth.get('password'), 'ansible_become_user': auth.get('become_username'), 'ansible_become_pass': auth.get('become_password'), 'ansible_become_method': auth.get('become_method'), 'ansible_ssh_private_key_file': auth.get('private_key_file'), 'ansible_ssh_common_args': ssh_common_args, }