close dangerous api methods under api auth (#78)

* close dangerous api methods under api auth

* rename access_token method
This commit is contained in:
Dmitry Afanasyev
2024-01-07 20:06:02 +03:00
committed by GitHub
parent 8266342214
commit de55d873f9
12 changed files with 210 additions and 18 deletions

View File

@@ -1,3 +1,4 @@
import uuid
from datetime import datetime
from sqlalchemy import INTEGER, TIMESTAMP, VARCHAR, Boolean, ForeignKey, String
@@ -26,7 +27,7 @@ class User(Base):
"UserQuestionCount",
primaryjoin="UserQuestionCount.user_id == User.id",
backref="user",
lazy="selectin",
lazy="noload",
uselist=False,
cascade="delete",
)
@@ -68,11 +69,25 @@ class AccessToken(Base):
__tablename__ = "access_token" # type: ignore[assignment]
user_id = mapped_column(INTEGER, ForeignKey("users.id", ondelete="cascade"), nullable=False)
token: Mapped[str] = mapped_column(String(length=42), primary_key=True)
token: Mapped[str] = mapped_column(String(length=42), primary_key=True, default=lambda: str(uuid.uuid4()))
created_at: Mapped[datetime] = mapped_column(
TIMESTAMP(timezone=True), index=True, nullable=False, default=datetime.now
)
user: Mapped["User"] = relationship(
"User",
backref="access_token",
lazy="selectin",
uselist=False,
cascade="expunge",
)
@property
def username(self) -> str:
if self.user:
return self.user.username
return ""
class UserQuestionCount(Base):
__tablename__ = "user_question_count" # type: ignore[assignment]

View File

@@ -5,7 +5,7 @@ from sqlalchemy.dialects.sqlite import insert
from sqlalchemy.orm import load_only
from core.auth.dto import UserIsBannedDTO
from core.auth.models.users import User, UserQuestionCount
from core.auth.models.users import AccessToken, User, UserQuestionCount
from infra.database.db_adapter import Database
@@ -76,3 +76,10 @@ class UserRepository:
async with self.db.session() as session:
await session.execute(query)
async def get_user_access_token(self, username: str | None) -> str | None:
query = select(AccessToken.token).join(AccessToken.user).where(User.username == username)
async with self.db.session() as session:
result = await session.execute(query)
return result.scalar()

View File

@@ -62,6 +62,9 @@ class UserService:
async def check_user_is_banned(self, user_id: int) -> UserIsBannedDTO:
return await self.repository.check_user_is_banned(user_id)
async def get_user_access_token_by_username(self, username: str | None) -> str | None:
return await self.repository.get_user_access_token(username)
def check_user_is_banned(func: Any) -> Any:
@wraps(func)