diff --git a/picard/webservice/api_helpers.py b/picard/webservice/api_helpers.py index d31f9b710..fb96df92c 100644 --- a/picard/webservice/api_helpers.py +++ b/picard/webservice/api_helpers.py @@ -239,10 +239,9 @@ class MBAPIHelper(APIHelper): return self.get_collection(None, handler) @staticmethod - def _collection_request(collection_id, releases): - while releases: - ids = ";".join(releases if len(releases) <= 400 else releases[:400]) - releases = releases[400:] + def _collection_request(collection_id, releases, batchsize=400): + for i in range(0, len(releases), batchsize): + ids = ";".join(releases[i:i+batchsize]) yield ("collection", collection_id, "releases", ids) @staticmethod diff --git a/test/test_api_helpers.py b/test/test_api_helpers.py index f130a51b3..0c88648a4 100644 --- a/test/test_api_helpers.py +++ b/test/test_api_helpers.py @@ -179,3 +179,15 @@ class MBAPITest(PicardTestCase): '' '' ) + + def test_collection_request(self): + releases = tuple("r"+str(i) for i in range(13)) + generator = self.api._collection_request("test", releases, batchsize=5) + batch = next(generator) + self.assertEqual(batch, ('collection', 'test', 'releases', 'r0;r1;r2;r3;r4')) + batch = next(generator) + self.assertEqual(batch, ('collection', 'test', 'releases', 'r5;r6;r7;r8;r9')) + batch = next(generator) + self.assertEqual(batch, ('collection', 'test', 'releases', 'r10;r11;r12')) + with self.assertRaises(StopIteration): + next(generator)