1
0
mirror of synced 2024-11-27 15:40:48 +01:00

Finally upgrade to SQLAlchemy 2.0, fix a bunch of type issues relating to that from Python 3.12.

This commit is contained in:
Jennifer Taylor 2024-09-26 00:04:23 +00:00
parent eb8264c553
commit da5e3fe52f
12 changed files with 97 additions and 94 deletions

View File

@ -97,7 +97,6 @@ class Data:
session_factory = sessionmaker(
bind=config.database.engine,
autoflush=True,
autocommit=True,
)
self.__config = config
self.__session = scoped_session(session_factory)

View File

@ -55,7 +55,7 @@ class APIData(APIProviderInterface, BaseData):
result["name"],
result["token"],
)
for result in cursor
for result in cursor.mappings()
]
def validate_client(self, token: str) -> bool:
@ -70,7 +70,7 @@ class APIData(APIProviderInterface, BaseData):
"""
sql = "SELECT count(*) AS count FROM client WHERE token = :token"
cursor = self.execute(sql, {"token": token})
return cursor.fetchone()["count"] == 1
return cursor.mappings().fetchone()["count"] == 1 # type: ignore
def create_client(self, name: str) -> int:
"""
@ -109,7 +109,7 @@ class APIData(APIProviderInterface, BaseData):
# Couldn't find an entry with this ID
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return Client(
clientid,
result["timestamp"],
@ -159,7 +159,7 @@ class APIData(APIProviderInterface, BaseData):
sql = "SELECT id, timestamp, uri, token, config FROM server ORDER BY timestamp ASC"
cursor = self.execute(sql)
return [format_result(result) for result in cursor]
return [format_result(result) for result in cursor.mappings()]
def create_server(self, uri: str, token: str) -> int:
"""
@ -199,7 +199,7 @@ class APIData(APIProviderInterface, BaseData):
# Couldn't find an entry with this ID
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
allow_stats = (result["config"] & 0x1) == 0
allow_scores = (result["config"] & 0x2) == 0
return Server(

View File

@ -6,7 +6,7 @@ from typing_extensions import Final
from bemani.common import Time
from bemani.data.config import Config
from sqlalchemy.engine import CursorResult # type: ignore
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import scoped_session
from sqlalchemy.sql import text
from sqlalchemy.types import String, Integer
@ -82,10 +82,12 @@ class BaseData:
includes = all(s in lowered for s in write_statement_group)
if includes and not safe_write_operation:
raise Exception("Read-only mode is active!")
return self.__conn.execute(
result = self.__conn.execute(
text(sql),
params if params is not None else {},
)
self.__conn.commit()
return result
def serialize(self, data: Dict[str, Any]) -> str:
"""
@ -141,7 +143,7 @@ class BaseData:
# Couldn't find a user with this session
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["id"]
def _create_session(self, opid: int, optype: str, expiration: int = (30 * 86400)) -> str:

View File

@ -96,7 +96,7 @@ class GameData(BaseData):
# Settings doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def put_settings(self, game: GameConstants, userid: UserID, settings: Dict[str, Any]) -> None:
@ -158,7 +158,7 @@ class GameData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def get_achievements(self, game: GameConstants, userid: UserID) -> List[Achievement]:
@ -182,7 +182,7 @@ class GameData(BaseData):
None,
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def put_achievement(
@ -251,7 +251,7 @@ class GameData(BaseData):
# setting doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
retval = ValidatedDict(self.deserialize(result["data"]))
retval["start_time"] = result["start_time"]
retval["end_time"] = result["end_time"]
@ -288,7 +288,7 @@ class GameData(BaseData):
"end_time": result["end_time"],
}
)
for result in cursor
for result in cursor.mappings()
]
def put_time_sensitive_settings(
@ -336,7 +336,7 @@ class GameData(BaseData):
"end_time": end_time,
},
)
for result in cursor:
for result in cursor.mappings():
if result["start_time"] == start_time and result["end_time"] == end_time:
# This is just this event being updated, that's fine.
continue
@ -387,7 +387,7 @@ class GameData(BaseData):
# entry doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def get_items(self, game: GameConstants, version: int) -> List[Item]:
@ -410,5 +410,5 @@ class GameData(BaseData):
result["id"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]

View File

@ -82,7 +82,7 @@ class LobbyData(BaseData):
# Settings doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
data = ValidatedDict(self.deserialize(result["data"]))
data["id"] = result["id"]
data["time"] = result["time"]
@ -119,7 +119,7 @@ class LobbyData(BaseData):
data["time"] = result["time"]
return data
return [(UserID(result["userid"]), format_result(result)) for result in cursor]
return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
def put_play_session_info(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
"""
@ -214,7 +214,7 @@ class LobbyData(BaseData):
# Settings doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
data = ValidatedDict(self.deserialize(result["data"]))
data["id"] = result["id"]
data["time"] = result["time"]
@ -252,7 +252,7 @@ class LobbyData(BaseData):
data["time"] = result["time"]
return data
return [(UserID(result["userid"]), format_result(result)) for result in cursor]
return [(UserID(result["userid"]), format_result(result)) for result in cursor.mappings()]
def put_lobby(self, game: GameConstants, version: int, userid: UserID, data: Dict[str, Any]) -> None:
"""

View File

@ -102,7 +102,7 @@ class MachineData(BaseData):
# Machine doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["pcbid"]
def from_machine_id(self, machine_id: int) -> Optional[str]:
@ -122,7 +122,7 @@ class MachineData(BaseData):
# Machine doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["pcbid"]
def from_userid(self, userid: UserID) -> List[ArcadeID]:
@ -137,7 +137,7 @@ class MachineData(BaseData):
"""
sql = "SELECT arcadeid FROM arcade_owner WHERE userid = :userid"
cursor = self.execute(sql, {"userid": userid})
return [ArcadeID(result["arcadeid"]) for result in cursor]
return [ArcadeID(result["arcadeid"]) for result in cursor.mappings()]
def from_session(self, session: str) -> Optional[ArcadeID]:
"""
@ -173,7 +173,7 @@ class MachineData(BaseData):
# Machine doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return Machine(
result["id"],
pcbid,
@ -212,7 +212,7 @@ class MachineData(BaseData):
result["version"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def put_machine(self, machine: Machine) -> None:
@ -279,7 +279,7 @@ class MachineData(BaseData):
port = None
else:
# Grab highest port
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
port = result["port"]
if port is not None:
port = port + 1
@ -383,7 +383,7 @@ class MachineData(BaseData):
# Arcade doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
sql = "SELECT userid FROM arcade_owner WHERE arcadeid = :id"
cursor = self.execute(sql, {"id": arcadeid})
@ -396,7 +396,7 @@ class MachineData(BaseData):
result["pref"],
result["area"] or None,
self.deserialize(result["data"]),
[owner["userid"] for owner in cursor],
[owner["userid"] for owner in cursor.mappings()],
)
def put_arcade(self, arcade: Arcade) -> None:
@ -464,7 +464,7 @@ class MachineData(BaseData):
sql = "SELECT userid, arcadeid FROM arcade_owner"
cursor = self.execute(sql)
arcade_to_owners: Dict[int, List[UserID]] = {}
for row in cursor:
for row in cursor.mappings():
arcade = row["arcadeid"]
owner = UserID(row["userid"])
if arcade not in arcade_to_owners:
@ -484,7 +484,7 @@ class MachineData(BaseData):
self.deserialize(result["data"]),
arcade_to_owners.get(result["id"], []),
)
for result in cursor
for result in cursor.mappings()
]
def get_settings(
@ -512,7 +512,7 @@ class MachineData(BaseData):
# Settings doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def put_settings(
@ -566,7 +566,7 @@ class MachineData(BaseData):
UserID(entry["userid"]),
entry["balance"],
)
for entry in cursor
for entry in cursor.mappings()
]
def create_session(self, arcadeid: ArcadeID, expiration: int = (30 * 86400)) -> str:

View File

@ -107,7 +107,7 @@ class MusicData(BaseData):
if cursor.rowcount != 1:
# music doesn't exist
raise Exception(f"Song {songid} chart {songchart} doesn't exist for game {game} version {version}")
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["id"]
def put_score(
@ -296,7 +296,7 @@ class MusicData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return Score(
result["scorekey"],
result["songid"],
@ -356,7 +356,7 @@ class MusicData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return (
UserID(result["userid"]),
Score(
@ -439,7 +439,7 @@ class MusicData(BaseData):
result["plays"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def get_most_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
@ -472,7 +472,7 @@ class MusicData(BaseData):
{"userid": userid, "game": game.value, "version": version, "count": count},
)
return [(result["songid"], result["plays"]) for result in cursor]
return [(result["songid"], result["plays"]) for result in cursor.mappings()]
def get_last_played(self, game: GameConstants, version: int, userid: UserID, count: int) -> List[Tuple[int, int]]:
"""
@ -504,7 +504,7 @@ class MusicData(BaseData):
{"userid": userid, "game": game.value, "version": version, "count": count},
)
return [(result["songid"], result["timestamp"]) for result in cursor]
return [(result["songid"], result["timestamp"]) for result in cursor.mappings()]
def get_hit_chart(
self,
@ -551,7 +551,7 @@ class MusicData(BaseData):
},
)
return [(result["songid"], result["plays"]) for result in cursor]
return [(result["songid"], result["plays"]) for result in cursor.mappings()]
def get_song(
self,
@ -597,7 +597,7 @@ class MusicData(BaseData):
if cursor.rowcount != 1:
# music doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return Song(
game,
version,
@ -647,7 +647,7 @@ class MusicData(BaseData):
result["genre"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def get_all_scores(
@ -752,7 +752,7 @@ class MusicData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]
def get_all_records(
@ -863,7 +863,7 @@ class MusicData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]
def get_attempt_by_key(self, game: GameConstants, version: int, key: int) -> Optional[Tuple[UserID, Attempt]]:
@ -908,7 +908,7 @@ class MusicData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return (
UserID(result["userid"]),
Attempt(
@ -1019,5 +1019,5 @@ class MusicData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]

View File

@ -73,7 +73,7 @@ class NetworkData(BaseData):
result["title"],
result["body"],
)
for result in cursor
for result in cursor.mappings()
]
def create_news(self, title: str, body: str) -> int:
@ -107,7 +107,7 @@ class NetworkData(BaseData):
# Couldn't find an entry with this ID
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return News(
newsid,
result["timestamp"],
@ -181,7 +181,7 @@ class NetworkData(BaseData):
# No scheduled work was registered, so time to get going!
return True
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
if schedule == "daily":
# Just look at the day and year, make sure it matches
@ -317,7 +317,7 @@ class NetworkData(BaseData):
result["type"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def delete_events(self, oldest_event_ts: int) -> None:

View File

@ -194,7 +194,7 @@ class UserData(BaseData):
# Couldn't find a user with this card
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"])
def from_username(self, username: str) -> Optional[UserID]:
@ -213,7 +213,7 @@ class UserData(BaseData):
# Couldn't find this username
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return UserID(result["id"])
def from_refid(self, game: GameConstants, version: int, refid: str) -> Optional[UserID]:
@ -238,7 +238,7 @@ class UserData(BaseData):
# Couldn't find a user with this refid
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"])
def from_extid(self, game: GameConstants, version: int, extid: int) -> Optional[UserID]:
@ -263,7 +263,7 @@ class UserData(BaseData):
# Couldn't find a user with this refid
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return UserID(result["userid"])
def from_session(self, session: str) -> Optional[UserID]:
@ -297,7 +297,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference?
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return User(userid, result["username"], result["email"], result["admin"] == 1)
def get_all_users(self) -> List[User]:
@ -316,7 +316,7 @@ class UserData(BaseData):
result["email"],
result["admin"] == 1,
)
for result in cursor
for result in cursor.mappings()
]
def get_all_usernames(self) -> List[str]:
@ -331,7 +331,7 @@ class UserData(BaseData):
"""
sql = "SELECT username FROM user WHERE username is not null"
cursor = self.execute(sql)
return [res["username"] for res in cursor]
return [res["username"] for res in cursor.mappings()]
def get_all_cards(self) -> List[Tuple[str, UserID]]:
"""
@ -342,7 +342,7 @@ class UserData(BaseData):
"""
sql = "SELECT id, userid FROM card"
cursor = self.execute(sql)
return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor]
return [(str(res["id"]).upper(), UserID(res["userid"])) for res in cursor.mappings()]
def get_cards(self, userid: UserID) -> List[str]:
"""
@ -356,7 +356,7 @@ class UserData(BaseData):
"""
sql = "SELECT id FROM card WHERE userid = :userid"
cursor = self.execute(sql, {"userid": userid})
return [str(res["id"]).upper() for res in cursor]
return [str(res["id"]).upper() for res in cursor.mappings()]
def add_card(self, userid: UserID, cardid: str) -> None:
"""
@ -437,7 +437,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference?
return False
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return pin == result["pin"]
def update_pin(self, userid: UserID, pin: str) -> None:
@ -468,7 +468,7 @@ class UserData(BaseData):
# User doesn't exist, but we have a reference?
return False
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
passhash = result["password"]
try:
@ -517,7 +517,7 @@ class UserData(BaseData):
# Profile doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return Profile(
game,
version,
@ -577,7 +577,7 @@ class UserData(BaseData):
cursor = self.execute(sql, {"game": game.value, "userids": userids})
profilever: Dict[UserID, int] = {}
for result in cursor:
for result in cursor.mappings():
tuid = UserID(result["userid"])
tver = result["version"]
@ -626,7 +626,7 @@ class UserData(BaseData):
vals["game"] = game.value
cursor = self.execute(sql, vals)
return [(GameConstants(result["game"]), result["version"]) for result in cursor]
return [(GameConstants(result["game"]), result["version"]) for result in cursor.mappings()]
def get_all_profiles(self, game: GameConstants, version: int) -> List[Tuple[UserID, Profile]]:
"""
@ -662,7 +662,7 @@ class UserData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]
def get_all_players(self, game: GameConstants, version: int) -> List[UserID]:
@ -682,7 +682,7 @@ class UserData(BaseData):
"""
cursor = self.execute(sql, {"game": game.value, "version": version})
return [UserID(result["userid"]) for result in cursor]
return [UserID(result["userid"]) for result in cursor.mappings()]
def get_all_achievements(
self,
@ -732,7 +732,7 @@ class UserData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]
def put_profile(self, game: GameConstants, version: int, userid: UserID, profile: Profile) -> None:
@ -808,7 +808,7 @@ class UserData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def get_achievements(self, game: GameConstants, version: int, userid: UserID) -> List[Achievement]:
@ -834,7 +834,7 @@ class UserData(BaseData):
None,
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def put_achievement(
@ -945,7 +945,7 @@ class UserData(BaseData):
result["timestamp"],
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def put_time_based_achievement(
@ -1023,7 +1023,7 @@ class UserData(BaseData):
self.deserialize(result["data"]),
),
)
for result in cursor
for result in cursor.mappings()
]
def get_link(
@ -1074,7 +1074,7 @@ class UserData(BaseData):
# score doesn't exist
return None
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return ValidatedDict(self.deserialize(result["data"]))
def get_links(self, game: GameConstants, version: int, userid: UserID) -> List[Link]:
@ -1103,7 +1103,7 @@ class UserData(BaseData):
UserID(result["other_userid"]),
self.deserialize(result["data"]),
)
for result in cursor
for result in cursor.mappings()
]
def put_link(
@ -1196,7 +1196,7 @@ class UserData(BaseData):
sql = "SELECT balance FROM balance WHERE userid = :userid AND arcadeid = :arcadeid"
cursor = self.execute(sql, {"userid": userid, "arcadeid": arcadeid})
if cursor.rowcount == 1:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["balance"]
else:
return 0
@ -1242,7 +1242,7 @@ class UserData(BaseData):
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
if cursor.rowcount == 1:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["refid"]
else:
return self.create_refid(game, version, userid)
@ -1265,7 +1265,7 @@ class UserData(BaseData):
sql = "SELECT extid FROM extid WHERE userid = :userid AND game = :game"
cursor = self.execute(sql, {"userid": userid, "game": game.value})
if cursor.rowcount == 1:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["extid"]
else:
return None
@ -1371,7 +1371,7 @@ class UserData(BaseData):
sql = "SELECT refid FROM refid WHERE userid = :userid AND game = :game AND version = :version"
cursor = self.execute(sql, {"userid": userid, "game": game.value, "version": version})
if cursor.rowcount == 1:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["refid"]
# Shouldn't be possible, but here we are
raise AccountCreationException("Failed to recover lost race refid!")

View File

@ -26,6 +26,9 @@ class FakeCursor:
self.rowcount = len(rows)
self.pos = -1
def mappings(self) -> "FakeCursor":
return self
def fetchone(self) -> Dict[str, Any]:
if len(self.__rows) != 1:
raise Exception(f"Tried to fetch one row and there are {len(self.__rows)} rows!")

View File

@ -4,14 +4,14 @@ import csv
import argparse
import copy
import io
import jaconv # type: ignore
import jaconv
import json
import os
import struct
import xml.etree.ElementTree as ET
from pathlib import Path
from sqlalchemy.engine import CursorResult # type: ignore
from sqlalchemy.orm import sessionmaker
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql import text
from sqlalchemy.exc import IntegrityError
from typing import Any, Callable, Dict, List, Optional, Tuple
@ -70,17 +70,18 @@ class ImportBase:
self.update = update
self.no_combine = no_combine
self.__config = config
self.__engine = self.__config.database.engine
self.__sessionmanager = sessionmaker(self.__engine)
self.__conn = self.__engine.connect()
self.__session = self.__sessionmanager(bind=self.__conn)
self.__batch = False
# Set up DB connection stuff.
self.__engine = self.__config.database.engine
session_factory = sessionmaker(self.__engine)
self.__conn = scoped_session(session_factory)
def start_batch(self) -> None:
self.__batch = True
def finish_batch(self) -> None:
self.__session.commit()
self.__conn.commit()
self.__batch = False
def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> CursorResult:
@ -96,12 +97,12 @@ class ImportBase:
]:
if write_statement in sql.lower():
raise Exception("Read-only mode is active!")
return self.__session.execute(text(sql), params if params is not None else {})
return self.__conn.execute(text(sql), params if params is not None else {})
def remote_music(self, server: str, token: str) -> GlobalMusicData:
api = ReadAPI(server, token)
user = UserData(self.__config, self.__session)
music = MusicData(self.__config, self.__session)
user = UserData(self.__config, self.__conn)
music = MusicData(self.__config, self.__conn)
return GlobalMusicData(api, user, music)
def remote_game(self, server: str, token: str) -> GlobalGameData:
@ -110,7 +111,7 @@ class ImportBase:
def get_next_music_id(self) -> int:
cursor = self.execute("SELECT MAX(id) AS next_id FROM `music`")
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
try:
return result["next_id"] + 1
except TypeError:
@ -138,7 +139,7 @@ class ImportBase:
},
)
if cursor.rowcount != 0:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["id"]
else:
return None
@ -183,7 +184,7 @@ class ImportBase:
},
)
if cursor.rowcount != 0:
result = cursor.fetchone()
result = cursor.mappings().fetchone() # type: ignore
return result["id"]
else:
return None
@ -369,8 +370,6 @@ class ImportBase:
# Make sure we don't leak connections after finising insertion.
if self.__batch:
raise Exception("Logic error, opened a batch without closing!")
if self.__session is not None:
self.__session.close()
if self.__conn is not None:
self.__conn.close()
self.__conn = None

View File

@ -1,6 +1,6 @@
setuptools
Cython
SQLAlchemy<2.0.0
SQLAlchemy
alembic
PyYAML
Flask