Source code for plumpy.base.state_machine

# -*- coding: utf-8 -*-
"""The state machine for processes"""
from collections.abc import Iterable
import enum
import functools
import inspect
import logging
import os
import sys

from plumpy.futures import Future
from .utils import call_with_super_check, super_check

_LOGGER = logging.getLogger(__name__)

# Names exported by ``from <module> import *``
__all__ = ['StateMachine', 'StateMachineMeta', 'event', 'TransitionFailed']


class StateMachineError(Exception):
    """Common base for errors raised by the state machine machinery."""


class StateEntryFailed(Exception):
    """
    Raised while entering a state to signal that the entry did not succeed.

    The exception can optionally name a replacement state to go to, together with
    the positional and keyword arguments used to construct that replacement.
    """

    def __init__(self, state=None, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
        super().__init__('failed to enter state')
        # NOTE: ``args`` deliberately replaces ``Exception.args`` (set by the call above);
        # the state machine reads ``exception.args`` to construct the replacement state.
        self.state, self.args, self.kwargs = state, args, kwargs


class InvalidStateError(Exception):
    """Raised when an operation is attempted in a state that does not permit it."""


class EventError(StateMachineError):
    """
    Raised when an event cannot be handled by the state machine.

    :param evt: label of the event that failed (stored as ``self.event``)
    :param msg: human readable description of the failure
    """

    def __init__(self, evt, msg):
        self.event = evt
        super().__init__(msg)


class TransitionFailed(Exception):
    """
    Raised when the machine fails to move from one state to another.

    :param initial_state: the state the transition started from
    :param final_state: the state the transition was heading to, if known
    :param traceback_str: optional formatted traceback appended to the message
    """

    def __init__(self, initial_state, final_state=None, traceback_str=None):
        self.initial_state = initial_state
        self.final_state = final_state
        self.traceback_str = traceback_str
        super().__init__(self._format_msg())

    def _format_msg(self):
        """Return ``'<initial> -> <final>'``, followed by the traceback on a new line when one is set."""
        header = '{} -> {}'.format(self.initial_state, self.final_state)
        if self.traceback_str is None:
            return header
        return '\n'.join((header, self.traceback_str))


def _normalise_event_states(states, param_name):
    """
    Validate and normalise the ``from_states`` / ``to_states`` argument of :func:`event`.

    :param states: ``'*'``, a single :class:`State` subclass, or an iterable of them
    :param param_name: name of the argument being checked, used in the error message
    :return: ``'*'`` unchanged, otherwise a tuple of :class:`State` subclasses
    :raises TypeError: if any entry is not a :class:`State` subclass
    """
    if states == '*':
        return states
    if inspect.isclass(states):
        states = (states,)
    # Materialise once so one-shot iterables (e.g. generators) are not exhausted by
    # this validation before the event is ever fired.
    states = tuple(states)
    for state in states:
        if not (inspect.isclass(state) and issubclass(state, State)):
            raise TypeError('{} must contain State subclasses, got {!r}'.format(param_name, state))
    return states


def event(from_states='*', to_states='*'):
    """
    Decorator turning a state machine method into a guarded event.

    Usable bare (``@event``) or with arguments (``@event(from_states=..., to_states=...)``).

    :param from_states: ``'*'`` (any state), a :class:`State` subclass, or an iterable of
        them. Calling the event raises :class:`EventError` if the machine's current state
        is not an instance of one of these.
    :param to_states: same forms as ``from_states``; checked after the wrapped method runs,
        unless it returned ``False`` or a :class:`Future`.
    :raises TypeError: if ``from_states`` / ``to_states`` contain anything other than
        :class:`State` subclasses.
    """
    # Bare ``@event`` passes the decorated function in as ``from_states``. Handle it
    # before validation (iterating a function would raise TypeError) and treat it as
    # "valid from any state".
    if inspect.isfunction(from_states):
        return event(to_states=to_states)(from_states)

    from_states = _normalise_event_states(from_states, 'from_states')
    to_states = _normalise_event_states(to_states, 'to_states')

    def wrapper(wrapped):
        evt_label = wrapped.__name__

        @functools.wraps(wrapped)
        def transition(self, *a, **kw):
            initial = self._state

            if from_states != '*' and not any(isinstance(initial, state) for state in from_states):
                raise EventError(evt_label, 'Event {} invalid in state {}'.format(evt_label, initial.LABEL))

            result = wrapped(self, *a, **kw)
            # The target-state check is skipped when the method returns False or a Future
            # (presumably: event declined / transition still pending).
            if not (result is False or isinstance(result, Future)):
                if to_states != '*' and not any(isinstance(self._state, state) for state in to_states):
                    if self._state == initial:
                        raise EventError(evt_label, 'Machine did not transition')

                    raise EventError(
                        evt_label, 'Event produced invalid state transition from '
                        '{} to {}'.format(initial.LABEL, self._state.LABEL)
                    )

            return result

        return transition

    return wrapper


class State:
    """
    One state of a :class:`StateMachine`.

    Subclasses set ``LABEL`` to identify the state and ``ALLOWED`` to the labels of the
    states that may be entered from it. A state with no allowed successors is terminal.
    """

    LABEL = None
    # Labels of the states that can be entered from this one
    ALLOWED = set()

    @classmethod
    def is_terminal(cls):
        """Return ``True`` if no other state may be entered from this one."""
        return not cls.ALLOWED

    def __init__(self, state_machine):
        """
        :param state_machine: The process this state belongs to
        :type state_machine: :class:`StateMachine`
        """
        self.state_machine = state_machine
        self.in_state = False

    def __str__(self):
        return str(self.LABEL)

    @property
    def label(self):
        """Shorthand for the class-level ``LABEL``."""
        return self.LABEL

    @super_check
    def enter(self):
        """Hook run when the state is entered (wrapped by ``super_check``; presumably overrides must call ``super()``)."""

    def execute(self):
        """
        Execute the state, performing the actions that this state is responsible for.
        Return a state to transition to or None if finished.
        """

    @super_check
    def exit(self):
        """
        Hook run when the state is left.

        :raises InvalidStateError: if this is a terminal state
        """
        if not self.is_terminal():
            return
        raise InvalidStateError('Cannot exit a terminal state {}'.format(self.LABEL))

    def create_state(self, state_label, *args, **kwargs):
        """Delegate state construction to the owning state machine."""
        machine = self.state_machine
        return machine.create_state(state_label, *args, **kwargs)

    def do_enter(self):
        """Run :meth:`enter` through ``call_with_super_check``, then mark the state as active."""
        call_with_super_check(self.enter)
        self.in_state = True

    def do_exit(self):
        """Run :meth:`exit` through ``call_with_super_check``, then mark the state as inactive."""
        call_with_super_check(self.exit)
        self.in_state = False
class StateEventHook(enum.Enum):
    """
    Points in the transition procedure where callbacks can be registered.

    Each callback receives a state instance whose meaning depends on the hook,
    as described next to each member.
    """
    # Callback receives the state that is about to be entered
    ENTERING_STATE = 0
    # Callback receives the state that was current before the one just entered
    ENTERED_STATE = 1
    # Callback receives the state that will be entered next
    EXITING_STATE = 2


class StateMachineMeta(type):
    """Metaclass that puts a freshly constructed state machine into its initial state."""

    def __call__(cls, *args, **kwargs):
        """
        Build the state machine, transition it to its initial state, then run ``init``.

        :param args: positional arguments forwarded to the constructor
        :param kwargs: keyword arguments forwarded to the constructor
        :return: the new, initialised state machine instance
        """
        machine = super().__call__(*args, **kwargs)
        initial = machine.create_initial_state()
        machine.transition_to(initial)
        call_with_super_check(machine.init)
        return machine
class StateMachine(metaclass=StateMachineMeta):
    """
    Base class for state machines built from :class:`State` subclasses.

    Subclasses define ``STATES`` (or override :meth:`get_states`) as a sequence of
    :class:`State` subclasses; the first entry is used as the initial state.
    """
    STATES = None
    _STATES_MAP = None
    _transitioning = False
    _transition_failing = False

    @classmethod
    def get_states_map(cls):
        """Return the ``{label: state class}`` mapping, building it on first use."""
        cls.__ensure_built()
        return cls._STATES_MAP

    @classmethod
    def get_states(cls):
        """
        Return the state classes of this machine.

        :raises RuntimeError: if ``STATES`` has not been defined
        """
        if cls.STATES is not None:
            return cls.STATES

        raise RuntimeError('States not defined')

    @classmethod
    def initial_state_label(cls):
        """Return the label of the first entry in ``STATES``."""
        cls.__ensure_built()
        return cls.STATES[0].LABEL  # pylint: disable=unsubscriptable-object

    @classmethod
    def get_state_class(cls, label):
        """Return the state class registered under ``label`` (``KeyError`` if unknown)."""
        cls.__ensure_built()
        return cls._STATES_MAP[label]  # pylint: disable=unsubscriptable-object

    @classmethod
    def __ensure_built(cls):
        """Build and seal the label -> state class map for this class, once."""
        try:
            # Check if it's already been built (and therefore sealed). Generic attribute
            # lookup on the class object appears intended to consult only this class's
            # own ``sealed`` flag, so a sealed parent does not stop a subclass building its map.
            if cls.__getattribute__(cls, 'sealed'):
                return
        except AttributeError:
            pass

        cls.STATES = cls.get_states()
        assert isinstance(cls.STATES, Iterable)

        # Build the states map
        cls._STATES_MAP = {}
        for state_cls in cls.STATES:  # pylint: disable=not-an-iterable
            assert issubclass(state_cls, State)
            label = state_cls.LABEL
            assert label not in cls._STATES_MAP, "Duplicate label '{}'".format(label)  # pylint: disable=unsupported-membership-test
            cls._STATES_MAP[label] = state_cls  # pylint: disable=unsupported-assignment-operation

        cls.sealed = True

    def __init__(self):
        super().__init__()
        self.__ensure_built()
        self._state = None
        self._exception_handler = None
        # Debug is enabled via the PYTHONSMDEBUG environment variable unless Python was run with -E
        self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG'))))
        self._transitioning = False
        self._event_callbacks = {}

    @super_check
    def init(self):
        """Called after entering initial state in `__call__` method of `StateMachineMeta`"""

    def __str__(self):
        return '<{}> ({})'.format(self.__class__.__name__, self.state)

    def create_initial_state(self):
        """Instantiate the initial state for this machine."""
        return self.get_state_class(self.initial_state_label())(self)

    @property
    def state(self):
        """Label of the current state, or ``None`` before the first transition."""
        if self._state is None:
            return None
        return self._state.LABEL

    def add_state_event_callback(self, hook, callback):
        """
        Add a callback to be called on a particular state event hook.
        The callback should have form fn(state_machine, hook, state)

        :param hook: The state event hook
        :param callback: The callback function
        """
        self._event_callbacks.setdefault(hook, []).append(callback)

    def remove_state_event_callback(self, hook, callback):
        """
        Remove a callback previously registered with :meth:`add_state_event_callback`.

        :raises ValueError: if ``callback`` is not registered for ``hook``
        """
        try:
            self._event_callbacks[hook].remove(callback)
        except (KeyError, ValueError):
            raise ValueError("Callback not set for hook '{}'".format(hook)) from None

    def _fire_state_event(self, hook, state):
        """Call every callback registered for ``hook`` as ``callback(self, hook, state)``."""
        # Iterate over a snapshot: a callback that removes itself (or another callback)
        # would otherwise shift the live list and cause the next callback to be skipped.
        for callback in list(self._event_callbacks.get(hook, [])):
            callback(self, hook, state)

    @super_check
    def on_terminated(self):
        """Called when a terminal state is entered"""

    def transition_to(self, new_state, *args, **kwargs):
        """
        Exit the current state and enter ``new_state``.

        :param new_state: a :class:`State` instance, a :class:`State` subclass or a state label;
            classes and labels are instantiated with ``*args`` / ``**kwargs``
        If entering raises :class:`StateEntryFailed`, the state carried by that exception is
        entered instead. Any other failure is passed to :meth:`transition_failed` (a failure
        raised while already handling one is re-raised directly).
        """
        assert not self._transitioning, \
            'Cannot call transition_to when already transitioning state'

        initial_state_label = self._state.LABEL if self._state is not None else None
        label = None
        try:
            self._transitioning = True

            # Make sure we have a state instance
            new_state = self._create_state_instance(new_state, *args, **kwargs)
            label = new_state.LABEL
            self._exit_current_state(new_state)

            try:
                self._enter_next_state(new_state)
            except StateEntryFailed as exception:
                # NOTE(review): self._state is still the previous (already exited) state here,
                # so _exit_current_state below runs its exit again — confirm this is intended.
                new_state = exception.state

                # Make sure we have a state instance
                new_state = self._create_state_instance(new_state, *exception.args, **exception.kwargs)
                label = new_state.LABEL
                self._exit_current_state(new_state)
                self._enter_next_state(new_state)

            if self._state.is_terminal():
                call_with_super_check(self.on_terminated)
        except Exception:  # pylint: disable=broad-except
            self._transitioning = False
            if self._transition_failing:
                raise
            self._transition_failing = True
            self.transition_failed(initial_state_label, label, *sys.exc_info()[1:])
        finally:
            self._transition_failing = False
            self._transitioning = False

    @staticmethod
    def transition_failed(initial_state, final_state, exception, trace):
        """
        Called when a state transitions fails.

        This method can be overwritten to change the default behaviour which is to raise the exception.

        :param exception: The transition failed exception
        :type exception: :class:`Exception`
        """
        raise exception.with_traceback(trace)

    def get_debug(self):
        return self._debug

    def set_debug(self, enabled):
        self._debug = enabled

    def create_state(self, state_label, *args, **kwargs):
        """
        Instantiate the state registered under ``state_label``.

        :raises ValueError: if ``state_label`` is not a known state label
        """
        # Only the lookup is guarded: a KeyError raised inside the state's own constructor
        # must propagate rather than be misreported as an unknown label.
        try:
            state_cls = self.get_states_map()[state_label]  # pylint: disable=unsubscriptable-object
        except KeyError:
            raise ValueError('{} is not a valid state'.format(state_label)) from None
        return state_cls(self, *args, **kwargs)

    def _exit_current_state(self, next_state):
        """ Exit the given state """

        # If we're just being constructed we may not have a state yet to exit,
        # in which case check the new state is the initial state
        if self._state is None:
            if next_state.label != self.initial_state_label():
                raise RuntimeError("Cannot enter state '{}' as the initial state".format(next_state))
            return  # Nothing to exit

        if next_state.LABEL not in self._state.ALLOWED:
            raise RuntimeError('Cannot transition from {} to {}'.format(self._state.LABEL, next_state.label))
        self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
        self._state.do_exit()

    def _enter_next_state(self, next_state):
        """Enter ``next_state`` and make it current, firing ENTERING/ENTERED hooks around it."""
        last_state = self._state
        self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
        # Enter the new state
        next_state.do_enter()
        self._state = next_state
        self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

    def _create_state_instance(self, state, *args, **kwargs):
        """Return ``state`` if it is already a :class:`State` instance, otherwise instantiate it."""
        if isinstance(state, State):
            # It's already a state instance
            return state

        # OK, have to create it
        state_cls = self._ensure_state_class(state)
        return state_cls(self, *args, **kwargs)

    def _ensure_state_class(self, state):
        """
        Resolve ``state`` (a :class:`State` subclass or a label) to a state class.

        :raises ValueError: if ``state`` is a label that is not registered
        """
        if inspect.isclass(state) and issubclass(state, State):
            return state

        try:
            return self.get_states_map()[state]  # pylint: disable=unsubscriptable-object
        except KeyError:
            raise ValueError('{} is not a valid state'.format(state)) from None