Coverage for bc/kwai-bc-club/src/kwai_bc_club/repositories/contact_db_repository.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines a contact repository for a database.""" 

2 

3from dataclasses import dataclass 

4 

5from kwai_core.db.database import Database 

6from kwai_core.db.table_row import JoinedTableRow 

7from sql_smith.functions import on 

8 

9from kwai_bc_club.domain.contact import ContactEntity, ContactIdentifier 

10from kwai_bc_club.repositories._tables import ContactRow, CountryRow 

11from kwai_bc_club.repositories.contact_repository import ( 

12 ContactNotFoundException, 

13 ContactRepository, 

14) 

15 

16 

@dataclass(kw_only=True, frozen=True, slots=True)
class ContactQueryRow(JoinedTableRow):
    """Row object that holds the contact and country columns of one query result.

    Attributes:
        contact: The columns read from the contact table.
        country: The columns read from the country table.
    """

    contact: ContactRow
    country: CountryRow

    def create_entity(self) -> ContactEntity:
        """Turn this joined row into a contact entity.

        The country part is converted first; the resulting country is then
        passed to the contact part, which builds the entity that is returned.
        """
        country = self.country.create_country()
        return self.contact.create_entity(country)

27 

28 

class ContactDbRepository(ContactRepository):
    """ContactRepository implementation that stores contacts in a database.

    Args:
        database: The database used for every statement this repository runs.
    """

    def __init__(self, database: Database):
        self._database = database

    async def create(self, contact: ContactEntity) -> ContactEntity:
        """Insert a contact and return it together with its new identifier.

        Args:
            contact: The contact to store.

        Returns:
            The entity returned by ``set_id`` for the id reported by the insert.
        """
        table_name = ContactRow.__table_name__
        values = ContactRow.persist(contact)
        inserted_id = await self._database.insert(table_name, values)
        return contact.set_id(ContactIdentifier(inserted_id))

    async def delete(self, contact: ContactEntity):
        """Delete the stored row of a contact, matched on its id value.

        Args:
            contact: The contact whose row must be removed.
        """
        identifier = contact.id.value
        await self._database.delete(identifier, ContactRow.__table_name__)

    async def update(self, contact: ContactEntity):
        """Overwrite the stored row of a contact with its persisted values.

        Args:
            contact: The contact whose row must be updated.
        """
        identifier = contact.id.value
        table_name = ContactRow.__table_name__
        values = ContactRow.persist(contact)
        await self._database.update(identifier, table_name, values)

    async def get(self, id_: ContactIdentifier) -> ContactEntity:
        """Load a single contact, including its country.

        Args:
            id_: The identifier of the contact to load.

        Returns:
            The contact entity built from the joined contact/country row.

        Raises:
            ContactNotFoundException: When the query returns no row.
        """
        query = Database.create_query_factory().select()
        # The builder methods modify ``query`` itself; the chain's result
        # is intentionally not stored.
        (
            query.from_(ContactRow.__table_name__)
            .columns(*ContactQueryRow.get_aliases())
            .inner_join(
                CountryRow.__table_name__,
                on(CountryRow.column("id"), ContactRow.column("country_id")),
            )
            .where(ContactRow.field("id").eq(id_.value))
        )

        row = await self._database.fetch_one(query)
        if not row:
            raise ContactNotFoundException(f"Contact with {id_} not found")
        return ContactQueryRow.map(row).create_entity()