1
0
mirror of synced 2024-12-04 18:17:55 +01:00
artemis/core/data/schema/user.py

111 lines
3.5 KiB
Python
Raw Normal View History

from typing import Optional, List
from sqlalchemy import Table, Column
from sqlalchemy.types import Integer, String, TIMESTAMP
from sqlalchemy.sql import func
2023-02-19 21:40:25 +01:00
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.sql import func, select
2023-02-19 21:40:25 +01:00
from sqlalchemy.engine import Row
import bcrypt
from core.data.schema.base import BaseData, metadata
# Account table backing UserData below. A row may exist without web UI
# credentials: get_unregistered_users() selects rows whose password is NULL.
aime_user = Table(
    "aime_user",
    metadata,
    Column("id", Integer, nullable=False, primary_key=True, autoincrement=True),
    Column("username", String(25), unique=True),
    Column("email", String(255), unique=True),
    # Compared with bcrypt.checkpw() in UserData.check_password after
    # .encode(), so presumably a bcrypt hash stored as text -- TODO confirm
    # that every writer hashes before storing.
    Column("password", String(255)),
    Column("permissions", Integer),
    Column("created_date", TIMESTAMP, server_default=func.now()),
    # NOTE(review): onupdate applies to UPDATE statements in general, not
    # only to logins -- confirm that is the intended meaning of this column.
    Column("last_login_date", TIMESTAMP, onupdate=func.now()),
    Column("suspend_expire_time", TIMESTAMP),
    mysql_charset="utf8mb4",
)
class UserData(BaseData):
    """Data-access helpers for the ``aime_user`` table.

    Every statement is run through ``self.execute`` (not defined in this
    class; presumably provided by ``BaseData``). A ``None`` result from it is
    treated as "the statement failed" by all methods here, and the lookup
    methods then return ``None`` as well.
    """

    async def create_user(
        self,
        id: Optional[int] = None,
        username: Optional[str] = None,
        email: Optional[str] = None,
        password: Optional[str] = None,
        permission: int = 1,
    ) -> Optional[int]:
        """Insert a user row, updating the existing row on a key conflict.

        Args:
            id: Explicit primary key. Left out of the INSERT when ``None``.
            username: Value for the ``username`` column.
            email: Value for the ``email`` column.
            password: Value for the ``password`` column, stored as given
                (presumably already hashed by the caller -- TODO confirm).
            permission: Value for the ``permissions`` column.

        Returns:
            ``lastrowid`` of the executed statement, or ``None`` if
            execution failed.
        """
        # The same four fields are both inserted and, on a duplicate key,
        # written over the existing row.
        fields = {
            "username": username,
            "email": email,
            "password": password,
            "permissions": permission,
        }
        values = fields if id is None else {"id": id, **fields}

        sql = insert(aime_user).values(**values)
        conflict = sql.on_duplicate_key_update(**fields)

        result = await self.execute(conflict)
        if result is None:
            return None
        return result.lastrowid

    async def get_user(self, user_id: int) -> Optional[Row]:
        """Fetch one user by primary key.

        Returns:
            The first matching row, or ``None`` when nothing matches or the
            query failed.
        """
        sql = select(aime_user).where(aime_user.c.id == user_id)

        result = await self.execute(sql)
        if result is None:
            # Honour the Optional[Row] contract; callers test for None.
            return None
        return result.fetchone()

    async def check_password(
        self, user_id: int, passwd: Optional[bytes] = None
    ) -> bool:
        """Check a candidate password against the user's stored hash.

        Args:
            user_id: Primary key of the user to check.
            passwd: Candidate password as bytes.

        Returns:
            ``True`` only if the user exists, has a stored password, and
            ``bcrypt.checkpw`` accepts the candidate. ``False`` otherwise,
            including for an empty or missing candidate.
        """
        # An empty/missing candidate can never match; skip the query.
        if not passwd:
            return False

        usr = await self.get_user(user_id)
        if not usr:
            return False

        stored = usr["password"]
        if stored is None:
            return False

        return bcrypt.checkpw(passwd, stored.encode())

    async def delete_user(self, user_id: int) -> None:
        """Delete the user with the given primary key.

        A failed statement is logged through ``self.logger``; nothing is
        raised or returned.
        """
        sql = aime_user.delete().where(aime_user.c.id == user_id)

        result = await self.execute(sql)
        if result is None:
            self.logger.error(f"Failed to delete user with id {user_id}")

    async def get_unregistered_users(self) -> Optional[List[Row]]:
        """
        Returns a list of users who have not registered with the webui. They may or may not have cards.

        Returns ``None`` if the query failed.
        """
        # "Unregistered" is expressed as a NULL password column.
        sql = select(aime_user).where(aime_user.c.password.is_(None))

        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchall()

    async def find_user_by_email(self, email: str) -> Optional[Row]:
        """Fetch the user whose email equals ``email`` exactly.

        Returns:
            The first matching row, or ``None`` when nothing matches or the
            query failed.
        """
        sql = select(aime_user).where(aime_user.c.email == email)

        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchone()

    async def find_user_by_username(self, username: str) -> Optional[List[Row]]:
        """Fetch all users whose username contains ``username``.

        The match is a SQL ``LIKE '%<username>%'``, so ``%`` and ``_`` inside
        ``username`` act as wildcards.

        Returns:
            All matching rows (possibly an empty list), or ``None`` if the
            query failed.
        """
        sql = select(aime_user).where(aime_user.c.username.like(f"%{username}%"))

        result = await self.execute(sql)
        if result is None:
            return None
        return result.fetchall()