1
0
mirror of synced 2024-12-04 19:17:55 +01:00

Holy fuck it's working

This commit is contained in:
Stepland 2020-07-03 01:21:37 +02:00
parent 5217651551
commit cfc23c2f90
7 changed files with 401 additions and 152 deletions

View File

@ -18,7 +18,7 @@ command_grammar = Grammar(
equals_value = ws "=" ws value
value = value_in_quotes / number
value_in_quotes = '"' quoted_value '"'
quoted_value = ~r"[^\"]+"
quoted_value = ~r"[^\"]*"
number = ~r"\d+(\.\d+)?"
ws = ~r"[\t ]*"
comment = ~r"//.*"

View File

@ -0,0 +1,75 @@
NOTE_SYMBOLS = [
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
]
CIRCLE_FREE_SYMBOLS = [
"", # ⎫
"", # ⎪
"", # ⎪
"", # ⎪
"", # ⎬ FULLWIDTH
"", # ⎪
"", # ⎪
"", # ⎪
"", # ⎭
"10", # ⎫
"11", # ⎪
"12", # ⎪
"13", # ⎪
"14", # ⎪
"15", # ⎬ HALFWIDTH
"16", # ⎪
"17", # ⎪
"18", # ⎪
"19", # ⎪
"20", # ⎭
]

View File

@ -1,4 +1,5 @@
from collections import ChainMap, defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from decimal import Decimal
from fractions import Fraction
@ -6,46 +7,40 @@ from io import StringIO
from itertools import chain
from typing import IO, Dict, Iterator, List, Optional, Tuple
from more_itertools import collapse, intersperse, windowed
from sortedcontainers import SortedKeyList, SortedSet
from more_itertools import collapse, intersperse, mark_ends, windowed
from sortedcontainers import SortedDict, SortedKeyList, SortedSet
from jubeatools import __version__
from jubeatools.song import BeatsTime, Chart, LongNote, Metadata, Song, TapNote, Timing
from jubeatools.song import (BeatsTime, Chart, LongNote, Metadata,
NotePosition, Song, TapNote, Timing)
from ..command import dump_command
def dump_mono_column(song: Song) -> Dict[str, IO]:
files = {}
for difname, chart in song.charts.items():
filename = f"{song.metadata.title} [{difname}].txt"
files[filename] = _dump_mono_column_chart(
difname, chart, song.metadata, chart.timing or song.global_timing,
)
return files
def _raise_if_unfit_for_mono_column(chart: Chart, timing: Timing):
if any(isinstance(note, LongNote) for note in chart.notes):
raise ValueError(
"Long notes aren't currently supported when dumping to mono-column"
)
if len(timing.event) < 1:
raise ValueError("No BPM found in file") from None
first_bpm = min(timing.events, key=lambda e: e.time)
if first_bpm.time != 0:
raise ValueError("First BPM event does not happen on beat zero")
from .commons import NOTE_SYMBOLS, CIRCLE_FREE_SYMBOLS
COMMAND_ORDER = SortedSet(
["b", "t", "m", "o", "r", "title", "artist", "lev", "dif", "jacket", "prevpos"]
)
BEATS_TIME_TO_SYMBOL = {
BeatsTime(1, 4) * index: symbol for index, symbol in enumerate("①②③④⑤⑥⑦⑧⑨⑩⑪⑫⑬⑭⑮⑯")
BeatsTime(1, 4) * index: symbol for index, symbol in enumerate(NOTE_SYMBOLS)
}
BEATS_TIME_TO_CIRCLE_FREE = {
BeatsTime(1, 4) * index: symbol for index, symbol in enumerate(CIRCLE_FREE_SYMBOLS)
}
DIRECTION_TO_ARROW = {
NotePosition(-1, 0): "", # U+FF1E : FULLWIDTH GREATER-THAN SIGN
NotePosition( 1, 0): "", # U+FF1C : FULLWIDTH LESS-THAN SIGN
NotePosition( 0, -1): "", # U+2228 : LOGICAL OR
NotePosition( 0, 1): "", # U+2227 : LOGICAL AND
}
DIRECTION_TO_LINE = {
NotePosition(-1, 0): "", # U+2015 : HORIZONTAL BAR
NotePosition( 1, 0): "",
NotePosition( 0, -1): "", # U+FF5C : FULLWIDTH VERTICAL LINE
NotePosition( 0, 1): "",
}
@ -57,18 +52,23 @@ def fraction_to_decimal(frac: Fraction):
@dataclass
class MonoColumnDumpedSection:
commands: Dict[str, Optional[str]] = field(default_factory=dict)
extra_symbols: Dict[BeatsTime, str] = field(default_factory=dict)
symbol_definitions: Dict[BeatsTime, str] = field(default_factory=dict)
notes: List[TapNote] = field(default_factory=list)
def __str__(self) -> str:
def render(
self,
current_beat: BeatsTime,
extra_symbols: Dict[BeatsTime, str],
circle_free: bool = False
) -> str:
blocs = []
commands = list(self._dump_commands())
if commands:
blocs.append(commands)
symbols = list(self._dump_extra_symbols())
symbols = list(self._dump_symbol_definitions())
if symbols:
blocs.append(symbols)
notes = list(self._dump_notes())
notes = list(self._dump_notes(current_beat, extra_symbols, circle_free))
if notes:
blocs.append(notes)
return "\n".join(collapse([intersperse("", blocs), "--"]))
@ -82,50 +82,136 @@ class MonoColumnDumpedSection:
continue
yield dump_command(key, value)
def _dump_extra_symbols(self) -> Iterator[str]:
for time, symbol in self.extra_symbols.items():
def _dump_symbol_definitions(self) -> Iterator[str]:
for time, symbol in self.symbol_definitions.items():
decimal_time = fraction_to_decimal(time)
yield f"*{symbol}:{decimal_time:.6f}"
def _dump_notes(self) -> Iterator[str]:
frames: List[Dict[Tuple[int, int], str]] = []
frame: Dict[Tuple[int, int], str] = {}
def _dump_notes(
self,
current_beat: BeatsTime,
extra_symbols: Dict[BeatsTime, str],
circle_free: bool = False
) -> Iterator[str]:
frames: List[Dict[NotePosition, str]] = []
frame: Dict[NotePosition, str] = {}
symbols: Dict[BeatsTime, str] = ChainMap(
self.extra_symbols, BEATS_TIME_TO_SYMBOL
extra_symbols, BEATS_TIME_TO_SYMBOL
)
for note in self.notes:
pos = note.position.as_tuple()
if pos in frame:
frames.append(frame)
frame = {}
symbol = symbols[note.time]
frame[pos] = symbol
if isinstance(note, LongNote):
needed_positions = set(note.positions_covered())
if needed_positions & frame.keys():
frames.append(frame)
frame = {}
direction = note.tail_direction()
arrow = DIRECTION_TO_ARROW[direction]
line = DIRECTION_TO_LINE[direction]
for is_first, is_last, pos in note.positions_covered():
if is_first:
time_in_section = note.time - current_beat
symbol = symbols[time_in_section]
frame[pos] = symbol
if is_last:
frame[pos] = arrow
else:
frame[pos] = line
elif isinstance(note, TapNote):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - current_beat
symbol = symbols[time_in_section]
frame[note.position] = symbol
elif isinstance(note, LongNoteEnd):
if note.position in frame:
frames.append(frame)
frame = {}
time_in_section = note.time - current_beat
if circle_free:
symbol = CIRCLE_FREE_SYMBOLS[time_in_section]
else:
symbol = symbols[time_in_section]
frame[note.position] = symbol
frames.append(frame)
dumped_frames = map(self._dump_frame, frames)
yield from collapse(intersperse("", dumped_frames))
@staticmethod
def _dump_frame(frame: Dict[Tuple[int, int], str]) -> Iterator[str]:
def _dump_frame(frame: Dict[NotePosition, str]) -> Iterator[str]:
for y in range(4):
yield "".join(frame.get((x, y), "") for x in range(4))
yield "".join(frame.get(NotePosition(x, y), "") for x in range(4))
DIFFICULTIES = {"BSC": 1, "ADV": 2, "EXT": 3}
DEFAULT_EXTRA_SYMBOLS = "" ""
# I put a FUCKTON of extra characters just in case some insane chart uses
# loads of unusual beat divisions
DEFAULT_EXTRA_SYMBOLS = (
""
""
"あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほまみむめもやゆよらりるれろわをん"
"アイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲン"
)
def _raise_if_unfit_for_mono_column(chart: Chart, timing: Timing, circle_free: bool = False):
if len(timing.events) < 1:
raise ValueError("No BPM found in file") from None
first_bpm = min(timing.events, key=lambda e: e.time)
if first_bpm.time != 0:
raise ValueError("First BPM event does not happen on beat zero")
if any(not note.tail_is_straight() for note in chart.notes if isinstance(note, LongNote)):
raise ValueError(
"Chart contains diagonal long notes, reprensenting these in"
" mono_column format is not supported by jubeatools"
)
if circle_free and any(
(note.time + note.duration) % BeatsTime(1, 4) != 0
for note in chart.notes if isinstance(note, LongNote)
):
raise ValueError(
"Chart contains long notes whose ending timing aren't"
" representable in #circlefree mode"
)
class SortedDefaultDict(SortedDict):
def __init__(self, default_factory, *args, **kwargs):
super().__init__(*args, **kwargs)
self.default_factory = default_factory
def __missing__(self, key):
value = self.default_factory()
self.__setitem__(key, value)
return value
@dataclass(frozen=True)
class LongNoteEnd:
time: BeatsTime
position: NotePosition
def _dump_mono_column_chart(
difficulty: str, chart: Chart, metadata: Metadata, timing: Timing
) -> IO:
) -> StringIO:
_raise_if_unfit_for_mono_column(chart, timing)
timing_events = sorted(timing.events, key=lambda e: e.time)
notes = SortedKeyList(set(chart.notes), key=lambda n: n.time)
# Add long note ends
for note in chart.notes:
if isinstance(note, LongNote):
notes.add(LongNoteEnd(note.time+note.duration, note.position))
last_event_time = max(timing_events[-1].time, notes[-1].time)
last_measure = last_event_time // 4
sections = defaultdict(
sections = SortedDefaultDict(
MonoColumnDumpedSection,
{BeatsTime(4) * i: MonoColumnDumpedSection() for i in range(last_measure + 1)},
)
@ -142,45 +228,53 @@ def _dump_mono_column_chart(
# Potentially create sub-sections for bpm changes
for event in timing_events:
sections[event.time].commands["t"] = event.BPM
# Frist, Set every single b=... value
section_starts = sorted(sections.keys())
for key, next_key in windowed(section_starts + [None], 2):
# First, Set every single b=... value
for key, next_key in windowed(chain(sections.keys(),[None]), 2):
if next_key is None:
sections[key].commands["b"] = 4
else:
sections[key].commands["b"] = fraction_to_decimal(next_key - key)
# Then, trim all the redundant b=...
last_b = 4
for key in section_starts:
current_b = sections[key].commands["b"]
for section in sections.values():
current_b = section.commands["b"]
if current_b == last_b:
del sections[key].commands["b"]
del section.commands["b"]
else:
last_b = current_b
# Fill sections with notes
for key, next_key in windowed(section_starts + [None], 2):
for key, next_key in windowed(chain(sections.keys(),[None]), 2):
sections[key].notes = list(
notes.irange_key(min_key=key, max_key=next_key, inclusive=(True, False))
)
# Define extra symbols
existing_symbols = {
BeatsTime(1, 4) * index: symbol
for index, symbol in enumerate("①②③④⑤⑥⑦⑧⑨⑩⑪⑫⑬⑭⑮⑯")
}
existing_symbols = deepcopy(BEATS_TIME_TO_SYMBOL)
extra_symbols = iter(DEFAULT_EXTRA_SYMBOLS)
for section_start in section_starts:
section = sections[section_start]
all_extra_symbols = {}
for section_start, section in sections.items():
for note in section.notes:
time_in_section = note.time - section_start
if time_in_section not in existing_symbols:
section.extra_symbols[note.time] = next(extra_symbols)
new_symbol = next(extra_symbols)
section.symbol_definitions[time_in_section] = new_symbol
all_extra_symbols[time_in_section] = new_symbol
existing_symbols[time_in_section] = new_symbol
# Actual output to file
file = StringIO()
file.write(f"// Converted using jubeatools {__version__}\n")
file.write(f"// https://github.com/Stepland/jubeatools\n\n")
for key in section_starts:
section = sections[key]
file.write(str(section) + "\n")
for section_start, section in sections.items():
file.write(section.render(section_start, all_extra_symbols) + "\n")
return file
def dump_mono_column(song: Song) -> Dict[str, IO]:
files = {}
for difname, chart in song.charts.items():
filename = f"{song.metadata.title} [{difname}].txt"
files[filename] = _dump_mono_column_chart(
difname, chart, song.metadata, chart.timing or song.global_timing,
)
return files

View File

@ -1,4 +1,5 @@
import re
import warnings
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
@ -6,7 +7,6 @@ from decimal import Decimal
from enum import Enum
from itertools import product
from typing import Dict, Iterator, List, Set, Tuple
import warnings
import constraint
from parsimonious import Grammar, NodeVisitor, ParseError
@ -28,6 +28,7 @@ from jubeatools.song import (
from ..command import is_command, parse_command
from ..symbol import is_symbol_definition, parse_symbol_definition
from .commons import NOTE_SYMBOLS, CIRCLE_FREE_SYMBOLS
mono_column_chart_line_grammar = Grammar(
r"""
@ -79,8 +80,6 @@ def is_empty_line(line: str) -> bool:
DIFFICULTIES = {1: "BSC", 2: "ADV", 3: "EXT"}
NOTE_SYMBOLS = "①②③④⑤⑥⑦⑧⑨⑩⑪⑫⑬⑭⑮⑯⑳㉑㉒㉓㉔㉕㉖㉗㉘㉙㉚㉛㉜㉝㉞㉟㊱㊲㊳㊴㊵㊶㊷㊸㊹㊺㊻㊼㊽㊾㊿"
SYMBOL_TO_DECIMAL_TIME = {
symbol: Decimal("0.25") * index for index, symbol in enumerate(NOTE_SYMBOLS)
}
@ -184,50 +183,34 @@ LONG_DIRECTION = {
**{c: (0, -1) for c in LONG_ARROW_UP},
}
CIRCLE_FREE_SYMBOLS = {
"", # ⎫
"", # ⎪
"", # ⎪
"", # ⎪
"", # ⎬ FULLWIDTH
"", # ⎪
"", # ⎪
"", # ⎪
"", # ⎭
"10", # ⎫
"11", # ⎪
"12", # ⎪
"13", # ⎪
"14", # ⎪
"15", # ⎬ HALFWIDTH
"16", # ⎪
"17", # ⎪
"18", # ⎪
"19", # ⎪
"20", # ⎭
}
CIRCLE_FREE_TO_DECIMAL_TIME = {
c: Decimal("0.25") * i for i, c in enumerate(CIRCLE_FREE_SYMBOLS)
}
def _distance(a: NotePosition, b: NotePosition) -> float:
return abs(complex(*a.as_tuple())-complex(*b.as_tuple()))
def _long_note_solution_heuristic(solution: Dict[NotePosition, NotePosition]) -> Tuple[int, int, int]:
c = Counter(int(_distance(k,v)) for k,v in solution.items())
def _distance(a: NotePosition, b: NotePosition) -> float:
return abs(complex(*a.as_tuple()) - complex(*b.as_tuple()))
def _long_note_solution_heuristic(
solution: Dict[NotePosition, NotePosition]
) -> Tuple[int, int, int]:
c = Counter(int(_distance(k, v)) for k, v in solution.items())
return (c[3], c[2], c[1])
def _is_simple_solution(solution, domains) -> bool:
return all(
solution[v] == min(domains[v], key=lambda e: _distance(e,v))
solution[v] == min(domains[v], key=lambda e: _distance(e, v))
for v in solution.keys()
)
def decimal_to_beats(current_beat: Decimal, symbol_timing: Decimal) -> BeatsTime:
decimal_time = current_beat + symbol_timing
return BeatsTime(decimal_time).limit_denominator(240)
class MonoColumnParser:
def __init__(self):
self.music = None
@ -393,8 +376,10 @@ class MonoColumnParser:
yield from self._iter_notes()
else:
yield from self._iter_notes_without_longs()
def _iter_blocs(self) -> Iterator[Tuple[Decimal, MonoColumnLoadedSection, List[List[str]]]]:
def _iter_blocs(
self,
) -> Iterator[Tuple[Decimal, MonoColumnLoadedSection, List[List[str]]]]:
current_beat = Decimal(0)
for section in self.sections:
for bloc in section.blocs():
@ -405,7 +390,7 @@ class MonoColumnParser:
unfinished_longs: Dict[NotePosition, UnfinishedLongNote] = {}
for current_beat, section, bloc in self._iter_blocs():
should_skip: Set[NotePosition] = set()
# 1/3 : look for ends to unfinished long notes
for pos, unfinished_long in unfinished_longs.items():
x, y = pos.as_tuple()
@ -427,10 +412,9 @@ class MonoColumnParser:
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(current_beat, symbol_time)
yield unfinished_long.ends_at(note_time)
unfinished_longs = {
k: unfinished_longs[k]
for k in unfinished_longs.keys() - should_skip
k: unfinished_longs[k] for k in unfinished_longs.keys() - should_skip
}
# 2/3 : look for new long notes starting on this bloc
@ -468,14 +452,17 @@ class MonoColumnParser:
solutions = problem.getSolutions()
if not solutions:
raise SyntaxError(
"Invalid long note arrow pattern in bloc :\n"+
"\n".join(''.join(line) for line in bloc)
"Invalid long note arrow pattern in bloc :\n"
+ "\n".join("".join(line) for line in bloc)
)
solution = min(solutions, key=_long_note_solution_heuristic)
if len(solutions) > 1 and not _is_simple_solution(solution, arrow_to_note_candidates):
if len(solutions) > 1 and not _is_simple_solution(
solution, arrow_to_note_candidates
):
warnings.warn(
"Ambiguous arrow pattern in bloc :\n"+
"\n".join(''.join(line) for line in bloc)+"\n"
"Ambiguous arrow pattern in bloc :\n"
+ "\n".join("".join(line) for line in bloc)
+ "\n"
"The resulting long notes might not be what you expect"
)
for arrow_pos, note_pos in solution.items():
@ -485,11 +472,9 @@ class MonoColumnParser:
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(current_beat, symbol_time)
unfinished_longs[note_pos] = UnfinishedLongNote(
time=note_time,
position=note_pos,
tail_tip=arrow_pos,
time=note_time, position=note_pos, tail_tip=arrow_pos,
)
# 3/3 : find regular notes
for y, x in product(range(4), range(4)):
position = NotePosition(x, y)
@ -500,7 +485,6 @@ class MonoColumnParser:
symbol_time = section.symbols[symbol]
note_time = decimal_to_beats(current_beat, symbol_time)
yield TapNote(note_time, position)
def _iter_notes_without_longs(self) -> Iterator[TapNote]:
current_beat = Decimal(0)
@ -544,7 +528,9 @@ def load_mono_column(path: Path) -> Song:
)
charts = {
state.difficulty: Chart(
level=state.level, timing=timing, notes=list(state.notes())
level=state.level,
timing=timing,
notes=sorted(state.notes(), key=lambda n: (n.time, n.position)),
)
}
return Song(metadata=metadata, charts=charts)

View File

@ -221,6 +221,7 @@ def test_irregular_beats_per_frame_2():
]
compare_chart_notes(chart, expected)
def test_long_notes():
chart = """
#holdbyarrow=1
@ -240,11 +241,12 @@ def test_long_notes():
time=BeatsTime(0),
position=NotePosition(0, 0),
duration=BeatsTime(4),
tail_tip=NotePosition(3,0),
tail_tip=NotePosition(3, 0),
)
]
compare_chart_notes(chart, expected)
def test_long_notes_ambiguous_case():
chart = """
#holdbyarrow=1
@ -261,14 +263,12 @@ def test_long_notes_ambiguous_case():
"""
expected = [
LongNote(BeatsTime(0), NotePosition(x, y), BeatsTime(4), NotePosition(tx, ty))
for (x, y), (tx, ty) in [
((0, 0), (2, 0)),
((1, 0), (3, 0)),
]
for (x, y), (tx, ty) in [((0, 0), (2, 0)), ((1, 0), (3, 0)),]
]
with pytest.warns(UserWarning):
compare_chart_notes(chart, expected)
@pytest.mark.filterwarnings("error")
def test_long_notes_simple_solution_no_warning():
chart = """
@ -286,10 +286,7 @@ def test_long_notes_simple_solution_no_warning():
"""
expected = [
LongNote(BeatsTime(0), NotePosition(x, y), BeatsTime(4), NotePosition(tx, ty))
for (x, y), (tx, ty) in [
((1, 1), (0, 1)),
((2, 1), (3, 1)),
]
for (x, y), (tx, ty) in [((1, 1), (0, 1)), ((2, 1), (3, 1)),]
]
compare_chart_notes(chart, expected)
@ -310,10 +307,27 @@ def test_long_notes_complex_case():
"""
expected = [
LongNote(BeatsTime(0), NotePosition(x, y), BeatsTime(4), NotePosition(tx, ty))
for (x, y), (tx, ty) in [
((1, 3), (1, 2)),
((2, 3), (2, 1)),
((3, 3), (0, 3)),
]
for (x, y), (tx, ty) in [((1, 3), (1, 2)), ((2, 3), (2, 1)), ((3, 3), (0, 3)),]
]
compare_chart_notes(chart, expected)
def test_circle_free():
chart = """
#holdbyarrow=1
#circlefree=1
--
13
--
"""
expected = [
LongNote(BeatsTime(0), NotePosition(3, 3), BeatsTime(7), NotePosition(0, 3))
]
compare_chart_notes(chart, expected)

View File

@ -0,0 +1,47 @@
from decimal import Decimal
from hypothesis import given
import hypothesis.strategies as st
from jubeatools.song import BeatsTime, LongNote, NotePosition, TapNote, Chart, Timing, BPMEvent, SecondsTime, BeatsTime, Metadata
from ..mono_column.load import MonoColumnParser
from ..mono_column.dump import _dump_mono_column_chart
@st.composite
def beat_time(draw):
numerator = draw(st.integers(min_value=0, max_value=240*4*10))
return BeatsTime(numerator, 240)
@st.composite
def note_position(draw):
x = draw(st.integers(min_value=0, max_value=3))
y = draw(st.integers(min_value=0, max_value=3))
return NotePosition(x, y)
@st.composite
def tap_note(draw):
time = draw(beat_time())
position = draw(note_position())
return TapNote(time, position)
@given(st.sets(tap_note(), min_size=1, max_size=2000))
def test_tap_notes(notes):
timing = Timing(
events=[BPMEvent(BeatsTime(0), Decimal(120))],
beat_zero_offset=SecondsTime(0)
)
chart = Chart(
level=0,
timing=timing,
notes=sorted(notes, key=lambda n: (n.time, n.position))
)
metadata = Metadata("", "", "", "")
string_io = _dump_mono_column_chart("", chart, metadata, timing)
chart = string_io.getvalue()
parser = MonoColumnParser()
for line in chart.split("\n"):
parser.load_line(line)
actual = set(parser.notes())
assert notes == actual

View File

@ -11,7 +11,8 @@ from collections import UserList, namedtuple
from dataclasses import dataclass, field
from decimal import Decimal
from fractions import Fraction
from typing import List, Mapping, Optional, Type, Union
from functools import wraps
from typing import Iterator, List, Mapping, Optional, Type, Union
from multidict import MultiDict
from path import Path
@ -26,11 +27,33 @@ def beats_time_from_ticks(ticks: int, resolution: int) -> BeatsTime:
return BeatsTime(ticks, resolution)
def convert_other(f):
@wraps(f)
def wrapped(self, other):
if isinstance(other, NotePosition):
other_note = other
else:
try:
other_note = NotePosition(*other)
except Exception:
raise ValueError(
f"Invalid type for {f.__name__} with NotePosition : {type(other).__name__}"
)
return f(self, other_note)
return wrapped
@dataclass(frozen=True)
class NotePosition:
x: int
y: int
def __iter__(self):
yield self.x
yield self.y
@property
def index(self):
return self.x + 4 * self.y
@ -44,30 +67,18 @@ class NotePosition:
raise ValueError(f"Note position index out of range : {index}")
return cls(x=index % 4, y=index // 4)
@convert_other
def __lt__(self, other):
if not isinstance(other, NotePosition):
try:
x, y = other
except ValueError:
raise ValueError(f"Cannot add NotePosition with {type(other).__name__}")
else:
x = other.x
y = other.y
return self.as_tuple() < (x, y)
return self.as_tuple() < other.as_tuple()
@convert_other
def __add__(self, other):
if not isinstance(other, NotePosition):
try:
x, y = other
except ValueError:
raise ValueError(f"Cannot add NotePosition with {type(other).__name__}")
else:
x = other.x
y = other.y
return NotePosition(self.x + other.x, self.y + other.y)
return NotePosition(self.x+x, self.y+y)
@convert_other
def __sub__(self, other):
return NotePosition(self.x - other.x, self.y - other.y)
@dataclass(frozen=True)
@ -86,6 +97,28 @@ class LongNote:
def __hash__(self):
return hash((self.time, self.position))
def tail_is_straight(self) -> bool:
return (self.position.x == self.tail_tip.x) or (
self.position.y == self.tail_tip.y
)
def tail_direction(self) -> NotePosition:
if not self.tail_is_straight():
raise ValueError("Can't get tail direction when it's not straight")
diff = self.tail_tip - self.position
if diff.x == 0:
diff.y //= abs(diff.y)
else:
diff.x //= abs(diff.x)
return diff
def positions_covered(self) -> Iterator[NotePosition]:
direction = self.tail_direction()
position = self.position
while position != self.tail_tip:
yield position
position = position + direction
@dataclass
class BPMEvent: