diff --git a/picard/version.py b/picard/version.py index a44201fa5..dad2fc978 100644 --- a/picard/version.py +++ b/picard/version.py @@ -37,40 +37,45 @@ class Version(tuple): 'final': 4 } - def __new__(cls, major, minor, patch, identifier='final', revision=0): + def __new__(cls, major, minor, patch=0, identifier='final', revision=0): if identifier not in cls.valid_identifiers(): raise VersionError("Should be either 'final', 'dev', 'alpha', 'beta' or 'rc'") identifier = {'a': 'alpha', 'b': 'beta'}.get(identifier, identifier) + try: + major = int(major) + minor = int(minor) + patch = int(patch) + revision = int(revision) + except (TypeError, ValueError): + raise VersionError("major, minor, patch and revision must be integer values") return super(Version, cls).__new__(cls, (major, minor, patch, identifier, revision)) @classmethod def from_string(cls, version_str): - m = cls._version_re.search(version_str) - if m: - g = m.groups() - if g[2] is None: - return Version(int(g[0]), int(g[1]), 0, 'final', 0) - if g[3] is None: - return Version(int(g[0]), int(g[1]), int(g[2]), 'final', 0) - return Version(int(g[0]), int(g[1]), int(g[2]), g[3], int(g[4])) + match = cls._version_re.search(version_str) + if match: + (major, minor, patch, identifier, revision) = match.groups() + major = int(major) + minor = int(minor) + if patch is None: + return Version(major, minor) + patch = int(patch) + if identifier is None: + return Version(major, minor, patch) + revision = int(revision) + return Version(major, minor, patch, identifier, revision) raise VersionError("String '%s' does not match regex '%s'" % (version_str, cls._version_re.pattern)) @classmethod def valid_identifiers(cls): - return cls._identifiers.keys() + return set(cls._identifiers.keys()) def to_string(self, short=False): - _version = [] - for p in self: - try: - n = int(p) - except ValueError: - n = p - _version.append(n) - if short and _version[3] in ('alpha', 'beta'): - _version[3] = _version[3][:1] - version = tuple(_version) + if short and self[3] in ('alpha', 'beta'): + version = (self[0], self[1], self[2], self[3][:1], self[4]) + else: + version = self if short and version[3] == 'final': if version[2] == 0: version_str = '%d.%d' % version[:2] diff --git a/test/test_versions.py b/test/test_versions.py index 4178b9c91..5a043266c 100644 --- a/test/test_versions.py +++ b/test/test_versions.py @@ -99,6 +99,12 @@ class VersionsTest(PicardTestCase): b = api_versions_tuple[i+1] self.assertLess(a, b) + def test_version_invalid_new(self): + self.assertRaises(VersionError, Version, '1', 'a') + self.assertRaises(VersionError, Version, None, 0) + self.assertRaises(VersionError, Version, 1, 0, 0, 'final', None) + self.assertRaises(VersionError, Version, 1, 0, 0, 'invalid', 0) + def test_sortkey(self): self.assertEqual((2, 1, 3, 4, 2), Version(2, 1, 3, 'final', 2).sortkey) self.assertEqual((2, 0, 0, 1, 0), Version(2, 0, 0, 'a', 0).sortkey)