diff --git a/picard/__init__.py b/picard/__init__.py index eb5d67bba..954830c5f 100644 --- a/picard/__init__.py +++ b/picard/__init__.py @@ -25,9 +25,15 @@ PICARD_ORG_NAME = "MusicBrainz" PICARD_VERSION = (1, 3, 0, 'dev', 2) +class VersionError(Exception): + pass + + def version_to_string(version, short=False): - assert len(version) == 5 - assert version[3] in ('final', 'dev') + if len(version) != 5: + raise VersionError("Length != 5") + if version[3] not in ('final', 'dev'): + raise VersionError("Should be either 'final' or 'dev'") _version = [] for p in version: try: @@ -47,9 +53,14 @@ def version_to_string(version, short=False): return version_str +_version_re = re.compile("^(\d+)[._](\d+)[._](\d+)[._]?(dev|final)[._]?(\d+)$") def version_from_string(version_str): - g = re.match(r"^(\d+)[._](\d+)[._](\d+)[._]?(dev|final)[._]?(\d+)$", version_str).groups() - return (int(g[0]), int(g[1]), int(g[2]), g[3], int(g[4])) + m = _version_re.search(version_str) + if m: + g = m.groups() + return (int(g[0]), int(g[1]), int(g[2]), g[3], int(g[4])) + raise VersionError("String '%s' do not match regex '%s'" % (version_str, + _version_re.pattern)) __version__ = PICARD_VERSION_STR = version_to_string(PICARD_VERSION) diff --git a/test/test_versions.py b/test/test_versions.py index c99dca7e5..aad5b8e6f 100644 --- a/test/test_versions.py +++ b/test/test_versions.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from picard import version_to_string, version_from_string +from picard import (version_to_string, + version_from_string, + VersionError) class VersionsTest(unittest.TestCase): @@ -23,8 +25,8 @@ class VersionsTest(unittest.TestCase): def test_version_conv_4(self): l, s = (1, 0, 2, '', 0), '1.0.2' - self.assertRaises(AssertionError, version_to_string, (l)) - self.assertRaises(AttributeError, version_from_string, (s)) + self.assertRaises(VersionError, version_to_string, (l)) + self.assertRaises(VersionError, version_from_string, (s)) def test_version_conv_5(self): l, s = (999, 999, 999, 'dev', 999), '999.999.999dev999' @@ -33,7 +35,7 @@ class VersionsTest(unittest.TestCase): def test_version_conv_6(self): l = (1, 0, 2, 'xx', 0) - self.assertRaises(AssertionError, version_to_string, (l)) + self.assertRaises(VersionError, version_to_string, (l)) def test_version_conv_7(self): l, s = (1, 1, 0, 'final', 0), '1.1'