Further improvements; Throttle maintains its own storage and no longer requires it to be supplied as an arg.

This commit is contained in:
Johnny 2018-09-21 17:38:31 +00:00
parent b50364038c
commit 791ace73bc
2 changed files with 62 additions and 53 deletions

View file

@ -88,32 +88,38 @@ class ThrottleTest(EvenniaTest):
def test_throttle(self):
ips = ('94.100.176.153', '45.56.148.77', '5.196.1.129')
kwargs = {
'maxlim': 5,
'timeout': 5 * 60
'limit': 5,
'timeout': 15 * 60
}
throttle = Throttle(**kwargs)
for ip in ips:
# Throttle should not be engaged by default
self.assertFalse(Throttle.check(ip, **kwargs))
self.assertFalse(throttle.check(ip))
# Pretend to fail a bunch of events
for x in xrange(5):
obj = Throttle.update(ip)
for x in xrange(50):
obj = throttle.update(ip)
self.assertFalse(obj)
# Next ones should be blocked
self.assertTrue(Throttle.check(ip, **kwargs))
self.assertTrue(throttle.check(ip))
for x in xrange(Throttle.cache_size * 2):
obj = Throttle.update(ip)
for x in xrange(throttle.cache_size * 2):
obj = throttle.update(ip)
self.assertFalse(obj)
# Should still be blocked
self.assertTrue(Throttle.check(ip, **kwargs))
self.assertTrue(throttle.check(ip))
# Number of values should be limited by cache size
self.assertEqual(Throttle.cache_size, len(Throttle.get(ip)))
self.assertEqual(throttle.cache_size, len(throttle.get(ip)))
cache = throttle.get()
# Make sure there are entries for each IP
self.assertEqual(len(ips), len(cache.keys()))
# There should only be (cache_size * num_ips) total in the Throttle cache
cache = Throttle.get()
self.assertEqual(sum([len(cache[x]) for x in cache.keys()]), Throttle.cache_size * len(ips))
self.assertEqual(sum([len(cache[x]) for x in cache.keys()]), throttle.cache_size * len(ips))

View file

@ -1,8 +1,6 @@
from collections import defaultdict, deque
import time
_LATEST_FAILURES = defaultdict(deque)
class Throttle(object):
"""
Keeps a running count of failed actions per IP address.
@ -17,12 +15,26 @@ class Throttle(object):
"""
error_msg = 'Too many failed attempts; you must wait a few minutes before trying again.'
cache_size = 20
@classmethod
def get(cls, ip=None, storage=_LATEST_FAILURES):
def __init__(self, **kwargs):
"""
Convenience function that appends a new event to the table.
Allows setting of throttle parameters.
Kwargs:
limit (int): Max number of failures before imposing limiter
timeout (int): number of timeout seconds after
max number of tries has been reached.
cache_size (int): Max number of attempts to record per IP within a
rolling window; this is NOT the same as the limit after which
the throttle is imposed!
"""
self.storage = defaultdict(deque)
self.cache_size = self.limit = kwargs.get('limit', 5)
self.timeout = kwargs.get('timeout', 5 * 60)
def get(self, ip=None):
"""
Convenience function that returns the storage table, or part of.
Args:
ip (str, optional): IP address of requestor
@ -35,64 +47,55 @@ class Throttle(object):
timestamps of recent failures only for that IP.
"""
if ip: return storage.get(ip, deque(maxlen=cls.cache_size))
return storage
if ip: return self.storage.get(ip, deque(maxlen=self.cache_size))
else: return self.storage
@classmethod
def update(cls, ip):
def update(self, ip):
"""
Convenience function that appends a new event to the table.
Store the time of the latest failure/
Args:
ip (str): IP address of requestor
Returns:
throttled (False): Always returns False
None
"""
return cls.check(ip)
# Enforce length limits
if not self.storage[ip].maxlen:
self.storage[ip] = deque(maxlen=self.cache_size)
self.storage[ip].append(time.time())
@classmethod
def check(cls, ip, maxlim=None, timeout=None, storage=_LATEST_FAILURES):
def check(self, ip):
"""
This will check the session's address against the
_LATEST_FAILURES dictionary to check they haven't
spammed too many fails recently.
storage dictionary to check they haven't spammed too many
fails recently.
Args:
ip (str): IP address of requestor
maxlim (int): max number of attempts to allow
timeout (int): number of timeout seconds after
max number of tries has been reached.
Returns:
throttled (bool): True if throttling is active,
False otherwise.
Notes:
If maxlim and/or timeout are set, the function will
just do the comparison, not append a new datapoint.
"""
now = time.time()
ip = str(ip)
if maxlim and timeout:
# checking mode
latest_fails = storage[ip]
if latest_fails and len(latest_fails) >= maxlim:
# too many fails recently
if now - latest_fails[-1] < timeout:
# too soon - timeout in play
return True
else:
# timeout has passed. clear faillist
del(storage[ip])
return False
# checking mode
latest_fails = self.storage[ip]
if latest_fails and len(latest_fails) >= self.limit:
# too many fails recently
if now - latest_fails[-1] < self.timeout:
# too soon - timeout in play
return True
else:
# timeout has passed. clear faillist
del(self.storage[ip])
return False
else:
# store the time of the latest fail
if ip not in storage or not storage[ip].maxlen:
storage[ip] = deque(maxlen=cls.cache_size)
storage[ip].append(time.time())
return False
return False