You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
76 lines
2.5 KiB
76 lines
2.5 KiB
1 year ago
|
from typing import Any, Union

from sqlalchemy import exists, select, text
from sqlalchemy.engine import ChunkedIteratorResult
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import Delete, Select

from db.models.base import BaseModel
|
||
|
|
||
|
|
||
|
class BaseRepository:
    """Thin async data-access helper around an injected SQLAlchemy ``AsyncSession``.

    Query helpers (``execute``, ``one``, ``all`` ...) only run statements.
    ``add_model`` and ``add_models`` commit; ``delete``, ``delete_many`` and
    ``add_model_ignore_exceptions`` leave committing to the caller.
    """

    def __init__(self, session: AsyncSession) -> None:
        # The session is supplied by the caller; this class never creates or closes it.
        self._session = session

    async def execute(self, query: Select) -> ChunkedIteratorResult:
        """Execute *query* on the session and return the raw result object."""
        return await self._session.execute(query)

    async def execute_parametrize(
        self, query: Union[str, Select], params: dict
    ) -> ChunkedIteratorResult:
        """Execute *query* with bound *params* and return the raw result object.

        Args:
            query: A SQLAlchemy statement, or a plain SQL string using
                ``:name`` bind-parameter placeholders.
            params: Values for the statement's bind parameters.
        """
        # SQLAlchemy 2.0 rejects raw strings in execute() (1.4 only deprecates
        # them), so plain SQL must be declared textual explicitly via text().
        statement = text(query) if isinstance(query, str) else query
        return await self._session.execute(statement=statement, params=params)

    async def one(self, query: Select) -> Any:
        """Return the single row produced by *query*.

        Delegates to ``Result.one()``, which raises if the query does not
        yield exactly one row.
        """
        result = await self.execute(query)
        return result.one()

    async def one_or_none(self, query: Select) -> Any:
        """Return the single row produced by *query*, or None if there is none.

        Delegates to ``Result.one_or_none()``, which raises if more than one
        row is returned.
        """
        result = await self.execute(query)
        return result.one_or_none()

    async def one_val(self, query: Select) -> Any:
        """Return the first column of the single row produced by *query*."""
        row = await self.one(query)
        return row[0]

    async def one_or_none_val(self, query: Select) -> Any:
        """Return the first column of the single row, or None if no row matched."""
        row = await self.one_or_none(query)
        # "No row" is signalled by None; test for that explicitly rather than truthiness.
        if row is None:
            return None
        return row[0]

    async def add_model(self, model: BaseModel) -> None:
        """Add *model* to the session and commit immediately."""
        self._session.add(model)
        await self._session.commit()

    async def refresh_model(self, model: BaseModel) -> None:
        """Reload *model*'s attributes from the database."""
        await self._session.refresh(model)

    async def add_model_ignore_exceptions(self, model: BaseModel) -> bool:
        """Add *model* inside a SAVEPOINT, suppressing ``IntegrityError``.

        The nested transaction (``begin_nested``) is flushed when the block
        exits. If that raises ``IntegrityError``, only the savepoint is rolled
        back, so the surrounding transaction stays usable. Nothing is
        committed here; the caller owns the outer transaction.

        Returns:
            True if the model was added, False if an ``IntegrityError`` was
            suppressed. Callers that ignore the result behave as before.
        """
        try:
            async with self._session.begin_nested():
                self._session.add(model)
        except IntegrityError:
            # Deliberate best-effort insert: report failure instead of raising.
            return False
        return True

    async def add_models(self, models: list[BaseModel]) -> None:
        """Add all *models* and commit them together in a single transaction.

        Previously each model was committed separately (via ``add_model``), so a
        failure partway through left earlier models committed. Now the batch
        is committed once, atomically.
        """
        self._session.add_all(models)
        await self._session.commit()

    async def delete(self, model: BaseModel) -> None:
        """Mark *model* for deletion in the session (does not commit)."""
        await self._session.delete(model)

    async def delete_many(self, models: list[BaseModel]) -> None:
        """Mark every model in *models* for deletion (does not commit)."""
        for model in models:
            await self.delete(model)

    async def all(self, query: Select) -> list[Any]:
        """Return every row produced by *query*."""
        result = await self.execute(query)
        return result.all()

    async def all_ones(self, query: Select) -> list[Any]:
        """Return the first column of every row produced by *query*."""
        result = await self.execute(query)
        # scalars() yields the first column of each row.
        return result.scalars().all()

    async def exists(self, query: Select) -> bool:
        """Return whether *query* yields at least one row.

        *query* is wrapped as ``SELECT EXISTS (<query>)``, so only a boolean
        is fetched rather than the matching rows themselves.
        """
        # Inside this method, ``exists`` resolves to the module-level
        # sqlalchemy function, not to this method.
        exists_query = select(exists(query))
        return await self.one_val(exists_query)