Use more builtins from dataclasses
This commit is contained in:
parent
378abcbe01
commit
b64685e73d
@ -3,7 +3,7 @@ import re
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from decimal import Decimal
|
||||
from itertools import product, zip_longest
|
||||
from typing import Dict, Iterator, List, Optional, Set, Tuple
|
||||
@ -226,7 +226,7 @@ def pick_correct_long_note_candidates(
|
||||
|
||||
|
||||
def note_distance(a: NotePosition, b: NotePosition) -> float:
|
||||
return abs(complex(*a.as_tuple()) - complex(*b.as_tuple()))
|
||||
return abs(complex(*astuple(a)) - complex(*astuple(b)))
|
||||
|
||||
|
||||
def long_note_solution_heuristic(
|
||||
|
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from collections import ChainMap
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
from itertools import chain, product, zip_longest
|
||||
@ -220,7 +220,7 @@ class MemoParser(JubeatAnalyserParser):
|
||||
|
||||
# 1/3 : look for ends to unfinished long notes
|
||||
for pos, unfinished_long in unfinished_longs.items():
|
||||
x, y = pos.as_tuple()
|
||||
x, y = astuple(pos)
|
||||
symbol = frame.position_part[y][x]
|
||||
if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS:
|
||||
circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol]
|
||||
|
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from collections import ChainMap
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
from itertools import chain, product, zip_longest
|
||||
@ -214,7 +214,7 @@ class Memo1Parser(JubeatAnalyserParser):
|
||||
|
||||
# 1/3 : look for ends to unfinished long notes
|
||||
for pos, unfinished_long in unfinished_longs.items():
|
||||
x, y = pos.as_tuple()
|
||||
x, y = astuple(pos)
|
||||
symbol = frame.position_part[y][x]
|
||||
if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS:
|
||||
circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol]
|
||||
|
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from collections import ChainMap
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from decimal import Decimal
|
||||
from functools import reduce
|
||||
from itertools import chain, product, zip_longest
|
||||
@ -359,7 +359,7 @@ class Memo2Parser(JubeatAnalyserParser):
|
||||
should_skip: Set[NotePosition] = set()
|
||||
# 1/3 : look for ends to unfinished long notes
|
||||
for pos, unfinished_long in unfinished_longs.items():
|
||||
x, y = pos.as_tuple()
|
||||
x, y = astuple(pos)
|
||||
symbol = frame.position_part[y][x]
|
||||
if self.circle_free and symbol in CIRCLE_FREE_SYMBOLS:
|
||||
circled_symbol = CIRCLE_FREE_TO_NOTE_SYMBOL[symbol]
|
||||
|
@ -2,7 +2,7 @@ import re
|
||||
import warnings
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from functools import reduce
|
||||
@ -200,7 +200,7 @@ class MonoColumnParser(JubeatAnalyserParser):
|
||||
|
||||
# 1/3 : look for ends to unfinished long notes
|
||||
for pos, unfinished_long in unfinished_longs.items():
|
||||
x, y = pos.as_tuple()
|
||||
x, y = astuple(pos)
|
||||
symbol = bloc[y][x]
|
||||
if self.circle_free:
|
||||
if symbol in CIRCLE_FREE_SYMBOLS:
|
||||
|
@ -8,7 +8,7 @@ otherwise a decimal number of seconds is used
|
||||
"""
|
||||
|
||||
from collections import UserList, namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import astuple, dataclass, field
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from functools import wraps
|
||||
@ -45,22 +45,18 @@ def convert_other(f):
|
||||
return wrapped
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, order=True)
|
||||
class NotePosition:
|
||||
x: int
|
||||
y: int
|
||||
|
||||
def __iter__(self):
|
||||
yield self.x
|
||||
yield self.y
|
||||
yield from astuple(self)
|
||||
|
||||
@property
|
||||
def index(self):
|
||||
return self.x + 4 * self.y
|
||||
|
||||
def as_tuple(self):
|
||||
return (self.x, self.y)
|
||||
|
||||
@classmethod
|
||||
def from_index(cls: Type["NotePosition"], index: int) -> "NotePosition":
|
||||
if not (0 <= index < 16):
|
||||
@ -68,10 +64,6 @@ class NotePosition:
|
||||
|
||||
return cls(x=index % 4, y=index // 4)
|
||||
|
||||
@convert_other
|
||||
def __lt__(self, other):
|
||||
return self.as_tuple() < other.as_tuple()
|
||||
|
||||
@convert_other
|
||||
def __add__(self, other):
|
||||
return NotePosition(self.x + other.x, self.y + other.y)
|
||||
@ -81,13 +73,13 @@ class NotePosition:
|
||||
return NotePosition(self.x - other.x, self.y - other.y)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, unsafe_hash=True)
|
||||
class TapNote:
|
||||
time: BeatsTime
|
||||
position: NotePosition
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, unsafe_hash=True)
|
||||
class LongNote:
|
||||
time: BeatsTime
|
||||
position: NotePosition
|
||||
@ -96,9 +88,6 @@ class LongNote:
|
||||
# playfield
|
||||
tail_tip: NotePosition
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.time, self.position))
|
||||
|
||||
def tail_is_straight(self) -> bool:
|
||||
return (self.position.x == self.tail_tip.x) or (
|
||||
self.position.y == self.tail_tip.y
|
||||
@ -107,7 +96,7 @@ class LongNote:
|
||||
def tail_direction(self) -> NotePosition:
|
||||
if not self.tail_is_straight():
|
||||
raise ValueError("Can't get tail direction when it's not straight")
|
||||
x, y = (self.tail_tip - self.position).as_tuple()
|
||||
x, y = astuple(self.tail_tip - self.position)
|
||||
if x == 0:
|
||||
y //= abs(y)
|
||||
else:
|
||||
@ -129,17 +118,11 @@ class BPMEvent:
|
||||
BPM: Decimal
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(unsafe_hash=True)
|
||||
class Timing:
|
||||
events: List[BPMEvent]
|
||||
beat_zero_offset: SecondsTime
|
||||
|
||||
def __hash__(self):
|
||||
return hash((
|
||||
tuple(self.events),
|
||||
self.beat_zero_offset,
|
||||
))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chart:
|
||||
@ -161,6 +144,7 @@ class Metadata:
|
||||
audio: Path
|
||||
cover: Path
|
||||
preview: Optional[Preview] = None
|
||||
preview_file: Optional[Path] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
Loading…
x
Reference in New Issue
Block a user