1
0
mirror of synced 2024-12-05 03:27:54 +01:00

Split NotePosition into two classes, one for actual button positions (bounds checked) and the other for 2d integer vectors (no bounds check)

This commit is contained in:
Stepland 2021-05-09 17:04:34 +02:00
parent 25b6529245
commit 4d11287be2
3 changed files with 65 additions and 41 deletions

View File

@ -230,7 +230,12 @@ def find_long_note_candidates(
note_candidates = set() note_candidates = set()
𝛿pos = LONG_DIRECTION[symbol] 𝛿pos = LONG_DIRECTION[symbol]
candidate = NotePosition(x, y) + 𝛿pos candidate = NotePosition(x, y) + 𝛿pos
while 0 <= candidate.x < 4 and 0 <= candidate.y < 4: while True:
try:
candidate = NotePosition.from_raw_position(candidate)
except ValueError:
break
if candidate not in should_skip: if candidate not in should_skip:
new_symbol = bloc[candidate.y][candidate.x] new_symbol = bloc[candidate.y][candidate.x]
if new_symbol in note_symbols: if new_symbol in note_symbols:

View File

@ -141,7 +141,9 @@ def _load_memon_note_v0(
time = jbt.beats_time_from_ticks(ticks=note["t"], resolution=resolution) time = jbt.beats_time_from_ticks(ticks=note["t"], resolution=resolution)
if note["l"] > 0: if note["l"] > 0:
duration = jbt.beats_time_from_ticks(ticks=note["l"], resolution=resolution) duration = jbt.beats_time_from_ticks(ticks=note["l"], resolution=resolution)
tail_tip = position + jbt.NotePosition(*P_VALUE_TO_X_Y_OFFSET[note["p"]]) p_value = note["p"]
𝛿x, 𝛿y = P_VALUE_TO_X_Y_OFFSET[p_value]
tail_tip = jbt.NotePosition.from_raw_position(position + jbt.Position(𝛿x, 𝛿y))
return jbt.LongNote(time, position, duration, tail_tip) return jbt.LongNote(time, position, duration, tail_tip)
else: else:
return jbt.TapNote(time, position) return jbt.TapNote(time, position)

View File

@ -1,21 +1,19 @@
""" """Provides the Song class, the central model for chartsets
Provides the Song class, the central model for chartsets
Every input format is converted to a Song instance Every input format is converted to a Song instance
Every output format is created from a Song instance Every output format is created from a Song instance
Most timing-related info is stored as beat fractions, Most timing-related info is stored as beat fractions, otherwise a decimal
otherwise a decimal number of seconds is used number of seconds is used"""
"""
from __future__ import annotations from __future__ import annotations
from dataclasses import astuple, dataclass, field from dataclasses import astuple, dataclass, field
from decimal import Decimal from decimal import Decimal
from enum import Enum, auto
from fractions import Fraction from fractions import Fraction
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterator, List, Mapping, Optional, Type, Union, Tuple from typing import Any, Callable, Iterator, List, Mapping, Optional, Tuple, Union
from enum import Enum, auto
from multidict import MultiDict from multidict import MultiDict
@ -30,27 +28,48 @@ def beats_time_from_ticks(ticks: int, resolution: int) -> BeatsTime:
def convert_other( def convert_other(
f: Callable[[NotePosition, NotePosition], NotePosition] f: Callable[[Position, Position], Position]
) -> Callable[[NotePosition, Any], NotePosition]: ) -> Callable[[Position, Any], Position]:
@wraps(f) @wraps(f)
def wrapped(self: NotePosition, other: Any) -> NotePosition: def wrapped(self: Position, other: Any) -> Position:
if isinstance(other, NotePosition): if isinstance(other, Position):
other_note = other other_pos = other
else: else:
try: try:
other_note = NotePosition(*other) other_pos = Position(*other)
except Exception: except Exception:
raise ValueError( raise ValueError(f"Could not convert {type(other)} to a Position")
f"Invalid type for {f.__name__} with NotePosition : {type(other).__name__}"
)
return f(self, other_note) return f(self, other_pos)
return wrapped return wrapped
@dataclass(frozen=True, order=True)
class Position:
"""2D integer vector"""
x: int
y: int
def __iter__(self) -> Iterator[int]:
yield from astuple(self)
@convert_other
def __add__(self, other: Position) -> Position:
return Position(self.x + other.x, self.y + other.y)
@convert_other
def __sub__(self, other: Position) -> Position:
return Position(self.x - other.x, self.y - other.y)
def __mul__(self, other: int) -> Position:
return Position(self.x * other, self.y * other)
__rmul__ = __mul__
@dataclass(frozen=True, order=True) @dataclass(frozen=True, order=True)
class NotePosition: class NotePosition(Position):
"""A specific square on the controller. (0, 0) is the top-left button, x """A specific square on the controller. (0, 0) is the top-left button, x
goes right, y goes down. goes right, y goes down.
@ -60,13 +79,15 @@ class NotePosition:
1 1
2 2
3 3
The main difference with Position is that x and y MUST be between 0 and 3
""" """
x: int def __post_init__(self) -> None:
y: int if not 0 <= self.x < 4:
raise ValueError("x out of [0, 3] range")
def __iter__(self) -> Iterator[int]: if not 0 <= self.y < 4:
yield from astuple(self) raise ValueError("y out of [0, 3] range")
@property @property
def index(self) -> int: def index(self) -> int:
@ -78,14 +99,10 @@ class NotePosition:
raise ValueError(f"Note position index out of range : {index}") raise ValueError(f"Note position index out of range : {index}")
return cls(x=index % 4, y=index // 4) return cls(x=index % 4, y=index // 4)
@convert_other @classmethod
def __add__(self, other: NotePosition) -> NotePosition: def from_raw_position(cls, pos: Position) -> NotePosition:
return NotePosition(self.x + other.x, self.y + other.y) return cls(x=pos.x, y=pos.y)
@convert_other
def __sub__(self, other: NotePosition) -> NotePosition:
return NotePosition(self.x - other.x, self.y - other.y)
@dataclass(frozen=True, unsafe_hash=True) @dataclass(frozen=True, unsafe_hash=True)
@ -110,7 +127,7 @@ class LongNote:
def tail_direction(self) -> Direction: def tail_direction(self) -> Direction:
"""Direction in which the tail moves""" """Direction in which the tail moves"""
if not self.has_straight_tail(): if not self.has_straight_tail():
raise ValueError("Can't get tail direction when it's not straight") raise ValueError("Can't get cardinal direction of diagonal long note")
if self.tail_tip.x == self.position.x: if self.tail_tip.x == self.position.x:
if self.tail_tip.y > self.position.y: if self.tail_tip.y > self.position.y:
@ -125,11 +142,11 @@ class LongNote:
def positions_covered(self) -> Iterator[NotePosition]: def positions_covered(self) -> Iterator[NotePosition]:
direction = self.tail_direction() direction = self.tail_direction()
step = TAIL_DIRECTION_TO_NOTE_TO_TAIL_VECTOR[direction] step = TAIL_DIRECTION_TO_OUTWARDS_VECTOR[direction]
position = self.position position = self.position
yield position yield position
while position != self.tail_tip: while position != self.tail_tip:
position = position + step position = NotePosition.from_raw_position(position + step)
yield position yield position
@ -140,11 +157,11 @@ class Direction(Enum):
RIGHT = auto() RIGHT = auto()
TAIL_DIRECTION_TO_NOTE_TO_TAIL_VECTOR = { TAIL_DIRECTION_TO_OUTWARDS_VECTOR = {
Direction.UP: NotePosition(0, 1), Direction.UP: Position(0, 1),
Direction.DOWN: NotePosition(0, -1), Direction.DOWN: Position(0, -1),
Direction.LEFT: NotePosition(1, 0), Direction.LEFT: Position(1, 0),
Direction.RIGHT: NotePosition(-1, 0), Direction.RIGHT: Position(-1, 0),
} }