From b1bde57bef878478e3503ab07190fd207914ade9 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Wed, 8 Feb 2023 04:11:08 +0100 Subject: [PATCH] [utils] `traverse_obj`: Fix several behavioral problems See #6180 for further info Authored by: Grub4K --- test/test_utils.py | 43 +++++++++----- yt_dlp/utils.py | 141 ++++++++++++++++++++++++++------------------- 2 files changed, 108 insertions(+), 76 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ffe1b729f..190e4ef9b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2000,8 +2000,8 @@ def test_traverse_obj(self): # Test Ellipsis behavior self.assertCountEqual(traverse_obj(_TEST_DATA, ...), - (item for item in _TEST_DATA.values() if item is not None), - msg='`...` should give all values except `None`') + (item for item in _TEST_DATA.values() if item not in (None, [], {})), + msg='`...` should give all non discarded values') self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), msg='`...` selection for dicts should select all values') self.assertEqual(traverse_obj(_TEST_DATA, (..., ..., 'url')), @@ -2084,15 +2084,23 @@ def test_traverse_obj(self): {0: ['https://www.example.com/1', 'https://www.example.com/0']}, msg='tripple nesting in dict path should be treated as branches') self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {}, - msg='remove `None` values when dict key') + msg='remove `None` values when top level dict key fails') self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=...), {0: ...}, - msg='do not remove `None` values if `default`') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}}, - msg='do not remove empty values when dict key') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: {}}, - msg='do not remove empty values when dict key and a default') - self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {0: []}, - msg='if branch in dict key not successful, return `[]`') + msg='use `default` if key fails and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {}, + msg='remove empty values when dict key') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=...), {0: ...}, + msg='use `default` when dict key and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {}, + msg='remove empty values when nested dict key fails') + self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, + msg='default to dict if pruned') + self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {}, + msg='default to dict if pruned and default is given') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}}, + msg='use nested `default` when nested dict key fails and `default`') + self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', ...)}), {}, + msg='remove key if branch in dict key not successful') # Testing default parameter behavior _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []} @@ -2183,14 +2191,17 @@ def test_traverse_obj(self): traverse_string=True), '.', msg='traverse into converted data if `traverse_string`') self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', ...), - traverse_string=True), list('str'), - msg='`...` branching into string should result in list') + traverse_string=True), 'str', + msg='`...` should result in string (same value) if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)), + traverse_string=True), 'sr', + msg='`slice` should result in string if `traverse_string`') + self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"), + traverse_string=True), 'str', + msg='function should result in string if `traverse_string`') self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)), traverse_string=True), ['s', 'r'], - msg='branching into string should result in list') - self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x), - traverse_string=True), list('str'), - msg='function branching into string should result in list') + msg='branching should result in list if `traverse_string`') # Test is_user_input behavior _IS_USER_INPUT_DATA = {'range8': list(range(8))} diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index e1e0f7b25..878b2b6a8 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5420,7 +5420,7 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. - A value of None is treated as the absence of a value. + Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -5446,6 +5446,8 @@ def traverse_obj( @params paths Paths which to traverse by. @param default Value to return if the paths do not match. + If the last key in the path is a `dict`, it will apply to each value inside + the dict instead, depth first. Try to avoid if using nested `dict` keys. @param expected_type If a `type`, only accept final values of this type. If any other callable, try to call the function on each result. If the last key in the path is a `dict`, it will apply to each value inside @@ -5460,12 +5462,15 @@ def traverse_obj( @param traverse_string Whether to traverse into objects as strings. If `True`, any non-compatible object will first be converted into a string and then traversed into. + The return value of that path will be a string instead, + not respecting any further branching. @returns The result of the object traversal. If successful, `get_all=True`, and the path branches at least once, then a list of results is returned instead. - A list is always returned if the last path branches and no `default` is given. + If no `default` is given and the last path branches, a `list` of results + is always returned. If a path ends on a `dict` that result will always be a `dict`. """ is_sequence = lambda x: isinstance(x, collections.abc.Sequence) and not isinstance(x, (str, bytes)) casefold = lambda k: k.casefold() if isinstance(k, str) else k @@ -5475,87 +5480,94 @@ def traverse_obj( else: type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) - def apply_key(key, test_type, obj): + def apply_key(key, obj, is_last): + branching = False + result = None + if obj is None: - return + pass elif key is None: - yield obj + result = obj elif isinstance(key, set): assert len(key) == 1, 'Set should only be used to wrap a single item' item = next(iter(key)) if isinstance(item, type): if isinstance(obj, item): - yield obj + result = obj else: - yield try_call(item, args=(obj,)) + result = try_call(item, args=(obj,)) elif isinstance(key, (list, tuple)): - for branch in key: - _, result = apply_path(obj, branch, test_type) - yield from result + branching = True + result = itertools.chain.from_iterable( + apply_path(obj, branch, is_last)[0] for branch in key) elif key is ...: + branching = True if isinstance(obj, collections.abc.Mapping): - yield from obj.values() + result = obj.values() elif is_sequence(obj): - yield from obj + result = obj elif isinstance(obj, re.Match): - yield from obj.groups() + result = obj.groups() elif traverse_string: - yield from str(obj) + branching = False + result = str(obj) + else: + result = () elif callable(key): - if is_sequence(obj): - iter_obj = enumerate(obj) - elif isinstance(obj, collections.abc.Mapping): + branching = True + if isinstance(obj, collections.abc.Mapping): iter_obj = obj.items() + elif is_sequence(obj): + iter_obj = enumerate(obj) elif isinstance(obj, re.Match): iter_obj = itertools.chain( enumerate((obj.group(), *obj.groups())), obj.groupdict().items()) elif traverse_string: + branching = False iter_obj = enumerate(str(obj)) else: - return - yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) + iter_obj = () + + result = (v for k, v in iter_obj if try_call(key, args=(k, v))) + if not branching: # string traversal + result = ''.join(result) elif isinstance(key, dict): - iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items()) - yield {k: v if v is not None else default for k, v in iter_obj - if v is not None or default is not NO_DEFAULT} + iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items()) + result = { + k: v if v is not None else default for k, v in iter_obj + if v is not None or default is not NO_DEFAULT + } or None elif isinstance(obj, collections.abc.Mapping): - yield (obj.get(key) if casesense or (key in obj) - else next((v for k, v in obj.items() if casefold(k) == key), None)) + result = (obj.get(key) if casesense or (key in obj) else + next((v for k, v in obj.items() if casefold(k) == key), None)) elif isinstance(obj, re.Match): if isinstance(key, int) or casesense: with contextlib.suppress(IndexError): - yield obj.group(key) - return + result = obj.group(key) - if not isinstance(key, str): - return - - yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) - - else: - if is_user_input: - key = (int_or_none(key) if ':' not in key - else slice(*map(int_or_none, key.split(':')))) - - if not isinstance(key, (int, slice)): - return + elif isinstance(key, str): + result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) + elif isinstance(key, (int, slice)): if not is_sequence(obj): - if not traverse_string: - return - obj = str(obj) + if traverse_string: + with contextlib.suppress(IndexError): + result = str(obj)[key] + else: + branching = isinstance(key, slice) + with contextlib.suppress(IndexError): + result = obj[key] - with contextlib.suppress(IndexError): - yield obj[key] + return branching, result if branching else (result,) def lazy_last(iterable): iterator = iter(iterable) @@ -5569,45 +5581,54 @@ def lazy_last(iterable): yield True, prev - def apply_path(start_obj, path, test_type=False): + def apply_path(start_obj, path, test_type): objs = (start_obj,) has_branched = False key = None for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): - if is_user_input and key == ':': - key = ... + if is_user_input and isinstance(key, str): + if key == ':': + key = ... + elif ':' in key: + key = slice(*map(int_or_none, key.split(':'))) + elif int_or_none(key) is not None: + key = int(key) if not casesense and isinstance(key, str): key = key.casefold() - if key is ... or isinstance(key, (list, tuple)) or callable(key): - has_branched = True - if __debug__ and callable(key): # Verify function signature inspect.signature(key).bind(None, None) - key_func = functools.partial(apply_key, key, last) - objs = itertools.chain.from_iterable(map(key_func, objs)) + new_objs = [] + for obj in objs: + branching, results = apply_key(key, obj, last) + has_branched |= branching + new_objs.append(results) + + objs = itertools.chain.from_iterable(new_objs) if test_type and not isinstance(key, (dict, list, tuple)): objs = map(type_test, objs) - return has_branched, objs - - def _traverse_obj(obj, path, use_list=True, test_type=True): - has_branched, results = apply_path(obj, path, test_type) - results = LazyList(x for x in results if x is not None) + return objs, has_branched, isinstance(key, dict) + def _traverse_obj(obj, path, allow_empty, test_type): + results, has_branched, is_dict = apply_path(obj, path, test_type) + results = LazyList(item for item in results if item not in (None, [], {})) if get_all and has_branched: - return results.exhaust() if results or use_list else None + if results: + return results.exhaust() + if allow_empty: + return [] if default is NO_DEFAULT else default + return None - return results[0] if results else None + return results[0] if results else {} if allow_empty and is_dict else None for index, path in enumerate(paths, 1): - use_list = default is NO_DEFAULT and index == len(paths) - result = _traverse_obj(obj, path, use_list) + result = _traverse_obj(obj, path, index == len(paths), True) if result is not None: return result