diff --git a/keystonemiddleware/auth_token/__init__.py b/keystonemiddleware/auth_token/__init__.py index 75a6ec5e..80710e8a 100644 --- a/keystonemiddleware/auth_token/__init__.py +++ b/keystonemiddleware/auth_token/__init__.py @@ -376,6 +376,7 @@ CONF = cfg.CONF CONF.register_opts(_OPTS, group=_base.AUTHTOKEN_GROUP) _LOG = logging.getLogger(__name__) +_CACHE_INVALID_INDICATOR = 'invalid' class _BIND_MODE(object): @@ -835,18 +836,29 @@ class AuthProtocol(BaseAuthProtocol): cached = self._cache_get_hashes(token_hashes) if cached: - data = cached + if cached == _CACHE_INVALID_INDICATOR: + self._LOG.debug('Cached token is marked unauthorized') + raise ksm_exceptions.InvalidToken() if self._check_revocations_for_cached: # A token might have been revoked, regardless of initial # mechanism used to validate it, and needs to be checked. self._revocations.check(token_hashes) + + # NOTE(jamielennox): Cached values used to be stored as a tuple + # of data and expiry time. They no longer are but we have to + # allow some time to transition the old format so if it's a + # tuple just use the data. + if len(cached) == 2: + cached = cached[0] + + data = cached else: data = self._validate_offline(token, token_hashes) if not data: data = self._identity_server.verify_token(token) - self._token_cache.store(token_hashes[0], data) + self._token_cache.set(token_hashes[0], data) except (ksa_exceptions.ConnectFailure, ksa_exceptions.RequestTimeout, @@ -857,7 +869,8 @@ class AuthProtocol(BaseAuthProtocol): except ksm_exceptions.InvalidToken: self.log.debug('Token validation failure.', exc_info=True) if token_hashes: - self._token_cache.store_invalid(token_hashes[0]) + self._token_cache.set(token_hashes[0], + _CACHE_INVALID_INDICATOR) self.log.warning(_LW('Authorization failed for token')) raise diff --git a/keystonemiddleware/auth_token/_cache.py b/keystonemiddleware/auth_token/_cache.py index 9cd5b00b..e9b62c8a 100644 --- a/keystonemiddleware/auth_token/_cache.py +++ b/keystonemiddleware/auth_token/_cache.py @@ -103,15 +103,13 @@ class TokenCache(object): initialize() must be called before calling the other methods. - Store a valid token in the cache using store(); mark a token as invalid in - the cache using store_invalid(). + Store data in the cache store. Check if a token is in the cache and retrieve it using get(). """ _CACHE_KEY_TEMPLATE = 'tokens/%s' - _INVALID_INDICATOR = 'invalid' def __init__(self, log, cache_time=None, env_cache_name=None, memcached_servers=None, @@ -144,16 +142,6 @@ class TokenCache(object): self._cache_pool = self._get_cache_pool(env.get(self._env_cache_name)) self._initialized = True - def store(self, token_id, data): - """Put token data into the cache.""" - self._LOG.debug('Storing token in cache') - self._cache_store(token_id, data) - - def store_invalid(self, token_id): - """Store invalid token in cache.""" - self._LOG.debug('Marking token as unauthorized in cache') - self._cache_store(token_id, self._INVALID_INDICATOR) - def _get_cache_key(self, token_id): """Get a unique key for this token id. @@ -229,32 +217,13 @@ class TokenCache(object): serialized = serialized.encode('utf8') data = self._deserialize(serialized, context) - # Note that _INVALID_INDICATOR and (data, expires) are the only - # valid types of serialized cache entries, so there is not - # a collision with jsonutils.loads(serialized) == None. if not isinstance(data, six.string_types): data = data.decode('utf-8') - cached = jsonutils.loads(data) - if cached == self._INVALID_INDICATOR: - self._LOG.debug('Cached Token is marked unauthorized') - raise exc.InvalidToken(_('Token authorization failed')) - # NOTE(jamielennox): Cached values used to be stored as a tuple of data - # and expiry time. They no longer are but we have to allow some time to - # transition the old format so if it's a tuple just return the data. - try: - data, expires = cached - except ValueError: - data = cached + return jsonutils.loads(data) - return data - - def _cache_store(self, token_id, data): - """Store value into memcache. - - data may be _INVALID_INDICATOR or a tuple like (data, expires) - - """ + def set(self, token_id, data): + """Store value into memcache.""" data = jsonutils.dumps(data) if isinstance(data, six.text_type): data = data.encode('utf-8') diff --git a/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py b/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py index ae34a6e1..39cd818c 100644 --- a/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py +++ b/keystonemiddleware/tests/unit/auth_token/test_auth_token_middleware.py @@ -1000,8 +1000,8 @@ class CommonAuthTokenMiddlewareTest(object): token = 'invalid-token' self.call_middleware(headers={'X-Auth-Token': token}, expected_status=401) - self.assertRaises(ksm_exceptions.InvalidToken, - self._get_cached_token, token) + self.assertEqual(auth_token._CACHE_INVALID_INDICATOR, + self._get_cached_token(token)) def test_memcache_set_expired(self, extra_conf={}, extra_environ={}): token_cache_time = 10 @@ -1825,8 +1825,9 @@ class v3AuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest, now = datetime.datetime.utcnow() delta = datetime.timedelta(hours=1) expires = strtime(at=(now + delta)) - self.middleware._token_cache.store(token, (data, expires)) - self.assertEqual(self.middleware._token_cache.get(token), data) + self.middleware._token_cache.set(token, (data, expires)) + new_data = self.middleware.fetch_token(token) + self.assertEqual(data, new_data) class DelayedAuthTests(BaseAuthTokenMiddlewareTest): diff --git a/keystonemiddleware/tests/unit/auth_token/test_cache.py b/keystonemiddleware/tests/unit/auth_token/test_cache.py index 0d69d3bd..df677bf7 100644 --- a/keystonemiddleware/tests/unit/auth_token/test_cache.py +++ b/keystonemiddleware/tests/unit/auth_token/test_cache.py @@ -118,7 +118,7 @@ class TestLiveMemcache(base.BaseAuthTokenTestCase): token_cache = self.create_simple_middleware(conf=conf)._token_cache token_cache.initialize({}) - token_cache._cache_store(token, data) + token_cache.set(token, data) self.assertEqual(token_cache.get(token), data) def test_sign_cache_data(self): @@ -134,7 +134,7 @@ class TestLiveMemcache(base.BaseAuthTokenTestCase): token_cache = self.create_simple_middleware(conf=conf)._token_cache token_cache.initialize({}) - token_cache._cache_store(token, data) + token_cache.set(token, data) self.assertEqual(token_cache.get(token), data) def test_no_memcache_protection(self): @@ -148,5 +148,5 @@ class TestLiveMemcache(base.BaseAuthTokenTestCase): token_cache = self.create_simple_middleware(conf=conf)._token_cache token_cache.initialize({}) - token_cache._cache_store(token, data) + token_cache.set(token, data) self.assertEqual(token_cache.get(token), data)