1
0
mirror of synced 2024-12-04 19:17:55 +01:00

Split NotePosition into two classes, one for actual button positions (bounds checked) and the other for 2d integer vectors (no bounds check)

This commit is contained in:
Stepland 2021-05-09 17:04:34 +02:00
parent 25b6529245
commit 4d11287be2
3 changed files with 65 additions and 41 deletions

View File

@ -230,7 +230,12 @@ def find_long_note_candidates(
note_candidates = set()
𝛿pos = LONG_DIRECTION[symbol]
candidate = NotePosition(x, y) + 𝛿pos
while 0 <= candidate.x < 4 and 0 <= candidate.y < 4:
while True:
try:
candidate = NotePosition.from_raw_position(candidate)
except ValueError:
break
if candidate not in should_skip:
new_symbol = bloc[candidate.y][candidate.x]
if new_symbol in note_symbols:

View File

@ -141,7 +141,9 @@ def _load_memon_note_v0(
time = jbt.beats_time_from_ticks(ticks=note["t"], resolution=resolution)
if note["l"] > 0:
duration = jbt.beats_time_from_ticks(ticks=note["l"], resolution=resolution)
tail_tip = position + jbt.NotePosition(*P_VALUE_TO_X_Y_OFFSET[note["p"]])
p_value = note["p"]
𝛿x, 𝛿y = P_VALUE_TO_X_Y_OFFSET[p_value]
tail_tip = jbt.NotePosition.from_raw_position(position + jbt.Position(𝛿x, 𝛿y))
return jbt.LongNote(time, position, duration, tail_tip)
else:
return jbt.TapNote(time, position)

View File

@ -1,21 +1,19 @@
"""
Provides the Song class, the central model for chartsets
"""Provides the Song class, the central model for chartsets
Every input format is converted to a Song instance
Every output format is created from a Song instance
Most timing-related info is stored as beat fractions,
otherwise a decimal number of seconds is used
"""
Most timing-related info is stored as beat fractions, otherwise a decimal
number of seconds is used"""
from __future__ import annotations
from dataclasses import astuple, dataclass, field
from decimal import Decimal
from enum import Enum, auto
from fractions import Fraction
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Iterator, List, Mapping, Optional, Type, Union, Tuple
from enum import Enum, auto
from typing import Any, Callable, Iterator, List, Mapping, Optional, Tuple, Union
from multidict import MultiDict
@ -30,27 +28,48 @@ def beats_time_from_ticks(ticks: int, resolution: int) -> BeatsTime:
def convert_other(
f: Callable[[NotePosition, NotePosition], NotePosition]
) -> Callable[[NotePosition, Any], NotePosition]:
f: Callable[[Position, Position], Position]
) -> Callable[[Position, Any], Position]:
@wraps(f)
def wrapped(self: NotePosition, other: Any) -> NotePosition:
if isinstance(other, NotePosition):
other_note = other
def wrapped(self: Position, other: Any) -> Position:
if isinstance(other, Position):
other_pos = other
else:
try:
other_note = NotePosition(*other)
other_pos = Position(*other)
except Exception:
raise ValueError(
f"Invalid type for {f.__name__} with NotePosition : {type(other).__name__}"
)
raise ValueError(f"Could not convert {type(other)} to a Position")
return f(self, other_note)
return f(self, other_pos)
return wrapped
@dataclass(frozen=True, order=True)
class Position:
"""2D integer vector"""
x: int
y: int
def __iter__(self) -> Iterator[int]:
yield from astuple(self)
@convert_other
def __add__(self, other: Position) -> Position:
return Position(self.x + other.x, self.y + other.y)
@convert_other
def __sub__(self, other: Position) -> Position:
return Position(self.x - other.x, self.y - other.y)
def __mul__(self, other: int) -> Position:
return Position(self.x * other, self.y * other)
__rmul__ = __mul__
@dataclass(frozen=True, order=True)
class NotePosition:
class NotePosition(Position):
"""A specific square on the controller. (0, 0) is the top-left button, x
goes right, y goes down.
@ -60,13 +79,15 @@ class NotePosition:
1
2
3
The main difference with Position is that x and y MUST be between 0 and 3
"""
x: int
y: int
def __iter__(self) -> Iterator[int]:
yield from astuple(self)
def __post_init__(self) -> None:
if not 0 <= self.x < 4:
raise ValueError("x out of [0, 3] range")
if not 0 <= self.y < 4:
raise ValueError("y out of [0, 3] range")
@property
def index(self) -> int:
@ -78,14 +99,10 @@ class NotePosition:
raise ValueError(f"Note position index out of range : {index}")
return cls(x=index % 4, y=index // 4)
@convert_other
def __add__(self, other: NotePosition) -> NotePosition:
return NotePosition(self.x + other.x, self.y + other.y)
@convert_other
def __sub__(self, other: NotePosition) -> NotePosition:
return NotePosition(self.x - other.x, self.y - other.y)
@classmethod
def from_raw_position(cls, pos: Position) -> NotePosition:
return cls(x=pos.x, y=pos.y)
@dataclass(frozen=True, unsafe_hash=True)
@ -110,7 +127,7 @@ class LongNote:
def tail_direction(self) -> Direction:
"""Direction in which the tail moves"""
if not self.has_straight_tail():
raise ValueError("Can't get tail direction when it's not straight")
raise ValueError("Can't get cardinal direction of diagonal long note")
if self.tail_tip.x == self.position.x:
if self.tail_tip.y > self.position.y:
@ -125,11 +142,11 @@ class LongNote:
def positions_covered(self) -> Iterator[NotePosition]:
direction = self.tail_direction()
step = TAIL_DIRECTION_TO_NOTE_TO_TAIL_VECTOR[direction]
step = TAIL_DIRECTION_TO_OUTWARDS_VECTOR[direction]
position = self.position
yield position
while position != self.tail_tip:
position = position + step
position = NotePosition.from_raw_position(position + step)
yield position
@ -140,11 +157,11 @@ class Direction(Enum):
RIGHT = auto()
TAIL_DIRECTION_TO_NOTE_TO_TAIL_VECTOR = {
Direction.UP: NotePosition(0, 1),
Direction.DOWN: NotePosition(0, -1),
Direction.LEFT: NotePosition(1, 0),
Direction.RIGHT: NotePosition(-1, 0),
TAIL_DIRECTION_TO_OUTWARDS_VECTOR = {
Direction.UP: Position(0, 1),
Direction.DOWN: Position(0, -1),
Direction.LEFT: Position(1, 0),
Direction.RIGHT: Position(-1, 0),
}