diff --git a/jubeatools/formats/jubeat_analyser/load_tools.py b/jubeatools/formats/jubeat_analyser/load_tools.py index 3844d8b..9630e34 100644 --- a/jubeatools/formats/jubeat_analyser/load_tools.py +++ b/jubeatools/formats/jubeat_analyser/load_tools.py @@ -3,7 +3,7 @@ import re import warnings from collections import Counter from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from decimal import Decimal from itertools import product, zip_longest from typing import Dict, Iterator, List, Optional, Set, Tuple @@ -226,7 +226,7 @@ def pick_correct_long_note_candidates( def note_distance(a: NotePosition, b: NotePosition) -> float: - return abs(complex(*a.as_tuple()) - complex(*b.as_tuple())) + return abs(complex(*astuple(a)) - complex(*astuple(b))) def long_note_solution_heuristic( diff --git a/jubeatools/formats/jubeat_analyser/memo/load.py b/jubeatools/formats/jubeat_analyser/memo/load.py index 3e168bb..61ab32a 100644 --- a/jubeatools/formats/jubeat_analyser/memo/load.py +++ b/jubeatools/formats/jubeat_analyser/memo/load.py @@ -1,7 +1,7 @@ import warnings from collections import ChainMap from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from decimal import Decimal from functools import reduce from itertools import chain, product, zip_longest @@ -220,7 +220,7 @@ class MemoParser(JubeatAnalyserParser): # 1/3 : look for ends to unfinished long notes for pos, unfinished_long in unfinished_longs.items(): - x, y = pos.as_tuple() + x, y = astuple(pos) symbol = frame.position_part[y][x] if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS: circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol] diff --git a/jubeatools/formats/jubeat_analyser/memo1/load.py b/jubeatools/formats/jubeat_analyser/memo1/load.py index 9b38ac3..6cf924a 100644 --- a/jubeatools/formats/jubeat_analyser/memo1/load.py +++ b/jubeatools/formats/jubeat_analyser/memo1/load.py @@ -1,7 +1,7 @@ import warnings from collections import ChainMap from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from decimal import Decimal from functools import reduce from itertools import chain, product, zip_longest @@ -214,7 +214,7 @@ class Memo1Parser(JubeatAnalyserParser): # 1/3 : look for ends to unfinished long notes for pos, unfinished_long in unfinished_longs.items(): - x, y = pos.as_tuple() + x, y = astuple(pos) symbol = frame.position_part[y][x] if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS: circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol] diff --git a/jubeatools/formats/jubeat_analyser/memo2/load.py b/jubeatools/formats/jubeat_analyser/memo2/load.py index b377c0b..b1beef0 100644 --- a/jubeatools/formats/jubeat_analyser/memo2/load.py +++ b/jubeatools/formats/jubeat_analyser/memo2/load.py @@ -1,7 +1,7 @@ import warnings from collections import ChainMap from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from decimal import Decimal from functools import reduce from itertools import chain, product, zip_longest @@ -359,7 +359,7 @@ class Memo2Parser(JubeatAnalyserParser): should_skip: Set[NotePosition] = set() # 1/3 : look for ends to unfinished long notes for pos, unfinished_long in unfinished_longs.items(): - x, y = pos.as_tuple() + x, y = astuple(pos) symbol = frame.position_part[y][x] if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS: circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol] diff --git a/jubeatools/formats/jubeat_analyser/mono_column/load.py b/jubeatools/formats/jubeat_analyser/mono_column/load.py index cd381aa..1de5010 100644 --- a/jubeatools/formats/jubeat_analyser/mono_column/load.py +++ b/jubeatools/formats/jubeat_analyser/mono_column/load.py @@ -2,7 +2,7 @@ import re import warnings from collections import Counter from copy import deepcopy -from dataclasses import dataclass +from dataclasses import astuple, dataclass from decimal import Decimal from enum import Enum from functools import reduce @@ -200,7 +200,7 @@ class MonoColumnParser(JubeatAnalyserParser): # 1/3 : look for ends to unfinished long notes for pos, unfinished_long in unfinished_longs.items(): - x, y = pos.as_tuple() + x, y = astuple(pos) symbol = bloc[y][x] if self.circle_free: if symbol in CIRCLE_FREE_SYMBOLS: diff --git a/jubeatools/song.py b/jubeatools/song.py index 0e35a0c..95632b0 100644 --- a/jubeatools/song.py +++ b/jubeatools/song.py @@ -8,7 +8,7 @@ otherwise a decimal number of seconds is used """ from collections import UserList, namedtuple -from dataclasses import dataclass, field +from dataclasses import astuple, dataclass, field from decimal import Decimal from fractions import Fraction from functools import wraps @@ -45,22 +45,18 @@ def convert_other(f): return wrapped -@dataclass(frozen=True) +@dataclass(frozen=True, order=True) class NotePosition: x: int y: int def __iter__(self): - yield self.x - yield self.y + yield from astuple(self) @property def index(self): return self.x + 4 * self.y - def as_tuple(self): - return (self.x, self.y) - @classmethod def from_index(cls: Type["NotePosition"], index: int) -> "NotePosition": if not (0 <= index < 16): @@ -68,10 +64,6 @@ class NotePosition: return cls(x=index % 4, y=index // 4) - @convert_other - def __lt__(self, other): - return self.as_tuple() < other.as_tuple() - @convert_other def __add__(self, other): return NotePosition(self.x + other.x, self.y + other.y) @@ -81,13 +73,13 @@ class NotePosition: return NotePosition(self.x - other.x, self.y - other.y) -@dataclass(frozen=True) +@dataclass(frozen=True, unsafe_hash=True) class TapNote: time: BeatsTime position: NotePosition -@dataclass(frozen=True) +@dataclass(frozen=True, unsafe_hash=True) class LongNote: time: BeatsTime position: NotePosition @@ -96,9 +88,6 @@ class LongNote: # playfield tail_tip: NotePosition - def __hash__(self): - return hash((self.time, self.position)) - def tail_is_straight(self) -> bool: return (self.position.x == self.tail_tip.x) or ( self.position.y == self.tail_tip.y @@ -107,7 +96,7 @@ class LongNote: def tail_direction(self) -> NotePosition: if not self.tail_is_straight(): raise ValueError("Can't get tail direction when it's not straight") - x, y = (self.tail_tip - self.position).as_tuple() + x, y = astuple(self.tail_tip - self.position) if x == 0: y //= abs(y) else: @@ -129,17 +118,11 @@ class BPMEvent: BPM: Decimal -@dataclass +@dataclass(unsafe_hash=True) class Timing: events: List[BPMEvent] beat_zero_offset: SecondsTime - def __hash__(self): - return hash(( - tuple(self.events), - self.beat_zero_offset, - )) - @dataclass class Chart: @@ -161,6 +144,7 @@ class Metadata: audio: Path cover: Path preview: Optional[Preview] = None + preview_file: Optional[Path] = None @dataclass