diff --git a/picard/webservice/__init__.py b/picard/webservice/__init__.py index cb8e033cd..d66a0d186 100644 --- a/picard/webservice/__init__.py +++ b/picard/webservice/__init__.py @@ -260,6 +260,64 @@ class WSPostRequest(WSRequest): super()._init_headers() +class RequestTask(namedtuple('RequestTask', 'hostkey func priority')): + + @staticmethod + def from_request(request, func): + # priority is a boolean + return RequestTask(request.get_host_key(), func, int(request.priority)) + + +class RequestPriorityQueue: + + def __init__(self, ratecontrol): + self._queues = defaultdict(lambda: defaultdict(deque)) + self._ratecontrol = ratecontrol + self._count = 0 + + def count(self): + return self._count + + def add_task(self, task, important=False): + (hostkey, func, prio) = task + queue = self._queues[prio][hostkey] + if important: + queue.appendleft(func) + else: + queue.append(func) + self._count += 1 + return RequestTask(hostkey, func, prio) + + def remove_task(self, task): + hostkey, func, prio = task + try: + self._queues[prio][hostkey].remove(func) + self._count -= 1 + except Exception as e: + log.debug(e) + + def run_ready_tasks(self): + delay = sys.maxsize + for prio in sorted(self._queues.keys(), reverse=True): + prio_queue = self._queues[prio] + if not prio_queue: + del(self._queues[prio]) + continue + for hostkey in sorted(prio_queue.keys(), + key=self._ratecontrol.current_delay): + queue = self._queues[prio][hostkey] + if not queue: + del(self._queues[prio][hostkey]) + continue + wait, d = self._ratecontrol.get_delay_to_next_request(hostkey) + if not wait: + queue.popleft()() + self._count -= 1 + if d < delay: + delay = d + return delay + + class WebService(QtCore.QObject): PARSERS = dict() @@ -308,9 +366,8 @@ class WebService(QtCore.QObject): def _init_queues(self): self._active_requests = {} - self._queues = defaultdict(lambda: defaultdict(deque)) + self._queue = RequestPriorityQueue(ratecontrol) self.num_pending_web_requests = 0 - self._last_num_pending_web_requests = -1 def _init_timers(self): self._timer_run_next_task = QtCore.QTimer(self) @@ -478,8 +535,8 @@ class WebService(QtCore.QObject): else: redirect = reply.attribute(QNetworkRequest.RedirectionTargetAttribute) - fromCache = reply.attribute(QNetworkRequest.SourceIsFromCacheAttribute) - cached = ' (CACHED)' if fromCache else '' + from_cache = reply.attribute(QNetworkRequest.SourceIsFromCacheAttribute) + cached = ' (CACHED)' if from_cache else '' log.debug("Received reply for %s: HTTP %d (%s) %s", url, response_code, @@ -561,45 +618,21 @@ class WebService(QtCore.QObject): self._init_queues() def _count_pending_requests(self): - count = len(self._active_requests) - for prio_queue in self._queues.values(): - for queue in prio_queue.values(): - count += len(queue) - self.num_pending_web_requests = count - if count != self._last_num_pending_web_requests: - self._last_num_pending_web_requests = count + count = len(self._active_requests) + self._queue.count() + if count != self.num_pending_web_requests: + self.num_pending_web_requests = count self.tagger.tagger_stats_changed.emit() if count: self._timer_count_pending_requests.start(COUNT_REQUESTS_DELAY_MS) def _run_next_task(self): - delay = sys.maxsize - for prio in sorted(self._queues.keys(), reverse=True): - prio_queue = self._queues[prio] - if not prio_queue: - del(self._queues[prio]) - continue - for hostkey in sorted(prio_queue.keys(), - key=ratecontrol.current_delay): - queue = self._queues[prio][hostkey] - if not queue: - del(self._queues[prio][hostkey]) - continue - wait, d = ratecontrol.get_delay_to_next_request(hostkey) - if not wait: - queue.popleft()() - if d < delay: - delay = d + delay = self._queue.run_ready_tasks() if delay < sys.maxsize: self._timer_run_next_task.start(delay) def add_task(self, func, request): - hostkey = request.get_host_key() - prio = int(request.priority) # priority is a boolean - if request.important: - self._queues[prio][hostkey].appendleft(func) - else: - self._queues[prio][hostkey].append(func) + task = RequestTask.from_request(request, func) + self._queue.add_task(task, request.important) if not self._timer_run_next_task.isActive(): self._timer_run_next_task.start(0) @@ -607,19 +640,15 @@ class WebService(QtCore.QObject): if not self._timer_count_pending_requests.isActive(): self._timer_count_pending_requests.start(0) - return (hostkey, func, prio) + return task def add_request(self, request): return self.add_task(partial(self._start_request, request), request) def remove_task(self, task): - hostkey, func, prio = task - try: - self._queues[prio][hostkey].remove(func) - if not self._timer_count_pending_requests.isActive(): - self._timer_count_pending_requests.start(0) - except Exception as e: - log.debug(e) + self._queue.remove_task(task) + if not self._timer_count_pending_requests.isActive(): + self._timer_count_pending_requests.start(0) @classmethod def add_parser(cls, response_type, mimetype, parser): diff --git a/test/test_webservice.py b/test/test_webservice.py index eea40cc27..f5a16e3eb 100644 --- a/test/test_webservice.py +++ b/test/test_webservice.py @@ -5,7 +5,7 @@ # Copyright (C) 2017 Sambhav Kothari # Copyright (C) 2017-2018 Wieland Hoffmann # Copyright (C) 2018, 2020 Laurent Monin -# Copyright (C) 2019-2020 Philipp Wolfer +# Copyright (C) 2019-2021 Philipp Wolfer # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License @@ -22,6 +22,7 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +import sys from unittest.mock import ( MagicMock, patch, @@ -33,6 +34,8 @@ from test.picardtestcase import PicardTestCase from picard import config from picard.webservice import ( + RequestPriorityQueue, + RequestTask, UnknownResponseParserError, WebService, WSRequest, @@ -98,109 +101,196 @@ class WebServiceTaskTest(PicardTestCase): 'network_transfer_timeout_seconds': 30, }) self.ws = WebService() + self.queue = self.ws._queue = MagicMock() # Patching the QTimers since they can only be started in a QThread self.ws._timer_run_next_task = MagicMock() self.ws._timer_count_pending_requests = MagicMock() def test_add_task(self): + request = WSRequest("", "abc.xyz", 80, "", None) + func = 1 + task = self.ws.add_task(func, request) + self.assertEqual((request.get_host_key(), func, 0), task) + self.ws._queue.add_task.assert_called_with(task, False) + request.important = True + task = self.ws.add_task(func, request) + self.ws._queue.add_task.assert_called_with(task, True) + def test_add_task_calls_timers(self): mock_timer1 = self.ws._timer_run_next_task mock_timer2 = self.ws._timer_count_pending_requests - - host = "abc.xyz" - port = 80 - request = WSRequest("", host, port, "", None) - key = request.get_host_key() + request = WSRequest("", "abc.xyz", 80, "", None) self.ws.add_task(0, request) - request.priority = True - self.ws.add_task(0, request) - request.important = True - self.ws.add_task(1, request) + mock_timer1.start.assert_not_called() + mock_timer2.start.assert_not_called() # Test if timer start was called in case it was inactive mock_timer1.isActive.return_value = False mock_timer2.isActive.return_value = False - request.priority = False - self.ws.add_task(1, request) - self.assertIn('start', repr(mock_timer1.method_calls)) - - # Test if key was added to prio queue - self.assertEqual(len(self.ws._queues[1]), 1) - self.assertIn(key, self.ws._queues[1]) - - # Test if 2 requests were added in prio queue - self.assertEqual(len(self.ws._queues[1][key]), 2) - - # Test if important request was added ahead in the queue - self.assertEqual(self.ws._queues[0][key][0], 1) + self.ws.add_task(0, request) + mock_timer1.start.assert_called_with(0) + mock_timer2.start.assert_called_with(0) def test_remove_task(self): - host = "abc.xyz" - port = 80 - request = WSRequest("", host, port, "", None) - key = request.get_host_key() + task = RequestTask(('example.com', 80), lambda: 1, priority=0) + self.ws.remove_task(task) + self.ws._queue.remove_task.assert_called_with(task) - # Add a task and check for its existance - task = self.ws.add_task(0, request) - self.assertIn(key, self.ws._queues[0]) - self.assertEqual(len(self.ws._queues[0][key]), 1) + def test_remove_task_calls_timers(self): + mock_timer = self.ws._timer_count_pending_requests + task = RequestTask(('example.com', 80), lambda: 1, priority=0) + self.ws.remove_task(task) + mock_timer.start.assert_not_called() + mock_timer.isActive.return_value = False + self.ws.remove_task(task) + mock_timer.start.assert_called_with(0) + + def test_run_next_task(self): + mock_timer = self.ws._timer_run_next_task + self.ws._queue.run_ready_tasks.return_value = sys.maxsize + self.ws._run_next_task() + self.ws._queue.run_ready_tasks.assert_called() + mock_timer.start.assert_not_called() + + def test_run_next_task_starts_next(self): + mock_timer = self.ws._timer_run_next_task + delay = 42 + self.ws._queue.run_ready_tasks.return_value = delay + self.ws._run_next_task() + self.ws._queue.run_ready_tasks.assert_called() + mock_timer.start.assert_called_with(42) + + +class RequestTaskTest(PicardTestCase): + + def test_from_request(self): + request = WSRequest('', 'example.com', 443, '', None, priority=True) + func = 1 + task = RequestTask.from_request(request, func) + self.assertEqual(request.get_host_key(), task.hostkey) + self.assertEqual(func, task.func) + self.assertEqual(1, task.priority) + self.assertEqual((request.get_host_key(), func, 1), task) + + +class RequestPriorityQueueTest(PicardTestCase): + + def test_add_task(self): + queue = RequestPriorityQueue(ratecontrol) + key = ("abc.xyz", 80) + + task1 = RequestTask(key, lambda: 1, priority=0) + queue.add_task(task1) + task2 = RequestTask(key, lambda: 1, priority=1) + queue.add_task(task2) + task3 = RequestTask(key, lambda: 1, priority=0) + queue.add_task(task3, important=True) + task4 = RequestTask(key, lambda: 1, priority=1) + queue.add_task(task4, important=True) + + # Test if 2 requests were added in each queue + self.assertEqual(len(queue._queues[0][key]), 2) + self.assertEqual(len(queue._queues[1][key]), 2) + + # Test if important request was added ahead in the queue + self.assertEqual(queue._queues[0][key][0], task3.func) + self.assertEqual(queue._queues[0][key][1], task1.func) + self.assertEqual(queue._queues[1][key][0], task4.func) + self.assertEqual(queue._queues[1][key][1], task2.func) + + def test_remove_task(self): + queue = RequestPriorityQueue(ratecontrol) + key = ("abc.xyz", 80) + + # Add a task and check for its existence + task = RequestTask(key, lambda: 1, priority=0) + task = queue.add_task(task) + self.assertIn(key, queue._queues[0]) + self.assertEqual(len(queue._queues[0][key]), 1) # Remove the task and check - self.ws.remove_task(task) - self.assertIn(key, self.ws._queues[0]) - self.assertEqual(len(self.ws._queues[0][key]), 0) + queue.remove_task(task) + self.assertIn(key, queue._queues[0]) + self.assertEqual(len(queue._queues[0][key]), 0) # Try to remove a non existing task and check for errors non_existing_task = (1, "a", "b") - self.ws.remove_task(non_existing_task) + queue.remove_task(non_existing_task) def test_run_task(self): - host = "abc.xyz" - port = 80 - request = WSRequest("", host, port, "", None) - key = request.get_host_key() + mock_ratecontrol = MagicMock() + delay_func = mock_ratecontrol.get_delay_to_next_request = MagicMock() - mock_task = MagicMock() - mock_task2 = MagicMock() - delay_func = ratecontrol.get_delay_to_next_request = MagicMock() + queue = RequestPriorityQueue(mock_ratecontrol) + key = ("abc.xyz", 80) # Patching the get delay function to delay the 2nd task on queue to the next call delay_func.side_effect = [(False, 0), (True, 0), (False, 0), (False, 0), (False, 0), (False, 0)] - self.ws.add_task(mock_task, request) - request.priority = True - self.ws.add_task(mock_task2, request) - request.priority = False - self.ws.add_task(mock_task, request) - self.ws.add_task(mock_task, request) + func1 = MagicMock() + task1 = RequestTask(key, func1, priority=0) + queue.add_task(task1) + func2 = MagicMock() + task2 = RequestTask(key, func2, priority=1) + queue.add_task(task2) + task3 = RequestTask(key, func1, priority=0) + queue.add_task(task3) + task4 = RequestTask(key, func1, priority=0) + queue.add_task(task4) # Ensure no tasks are run before run_next_task is called - self.assertEqual(mock_task.call_count, 0) - self.ws._run_next_task() + self.assertEqual(func1.call_count, 0) + queue.run_ready_tasks() # Ensure priority task is run first - self.assertEqual(mock_task2.call_count, 1) - self.assertEqual(mock_task.call_count, 0) - self.assertIn(key, self.ws._queues[1]) + self.assertEqual(func2.call_count, 1) + self.assertEqual(func1.call_count, 0) + self.assertIn(key, queue._queues[1]) # Ensure that the calls are run as expected - self.ws._run_next_task() - self.assertEqual(mock_task.call_count, 1) + queue.run_ready_tasks() + self.assertEqual(func1.call_count, 1) # Checking if the cleanup occurred on the prio queue - self.assertNotIn(key, self.ws._queues[1]) + self.assertNotIn(key, queue._queues[1]) # Check the call counts on proper execution of tasks - self.ws._run_next_task() - self.assertEqual(mock_task.call_count, 2) - self.ws._run_next_task() - self.assertEqual(mock_task.call_count, 3) + queue.run_ready_tasks() + self.assertEqual(func1.call_count, 2) + queue.run_ready_tasks() + self.assertEqual(func1.call_count, 3) # Ensure that the clean up happened on the normal queue - self.ws._run_next_task() - self.assertEqual(mock_task.call_count, 3) - self.assertNotIn(key, self.ws._queues[0]) + queue.run_ready_tasks() + self.assertEqual(func1.call_count, 3) + self.assertNotIn(key, queue._queues[0]) + + def test_count(self): + queue = RequestPriorityQueue(ratecontrol) + key = ("abc.xyz", 80) + + self.assertEqual(0, queue.count()) + task1 = RequestTask(key, lambda: 1, priority=0) + queue.add_task(task1) + self.assertEqual(1, queue.count()) + task2 = RequestTask(key, lambda: 1, priority=1) + queue.add_task(task2) + self.assertEqual(2, queue.count()) + task3 = RequestTask(key, lambda: 1, priority=0) + queue.add_task(task3, important=True) + self.assertEqual(3, queue.count()) + task4 = RequestTask(key, lambda: 1, priority=1) + queue.add_task(task4, important=True) + self.assertEqual(4, queue.count()) + queue.remove_task(task1) + self.assertEqual(3, queue.count()) + queue.remove_task(task2) + self.assertEqual(2, queue.count()) + queue.remove_task(task3) + self.assertEqual(1, queue.count()) + queue.remove_task(task4) + self.assertEqual(0, queue.count()) class WebServiceProxyTest(PicardTestCase):