Coverage for packages/kwai-core/src/kwai_core/db/table_row.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module that defines some dataclasses that can be used as data transfer objects.""" 

2 

3from dataclasses import dataclass, fields 

4from typing import ClassVar, Self 

5 

6from sql_smith.functions import alias 

7from sql_smith.functions import field as sql_field 

8from sql_smith.functions import search as sql_search 

9from sql_smith.interfaces import ExpressionInterface 

10 

11from kwai_core.db.database import Record 

12 

13 

14def _validate_dataclass(t): 

15 """Check if all fields contains data with the correct type. 

16 

17 A ValueError will be raised when the data for a given field contains data with 

18 an invalid type. 

19 The None value will be ignored because sometimes a value can be None when 

20 the TableRow dataclass is used in a join that doesn't match. 

21 """ 

22 for f in fields(t): 

23 value = getattr(t, f.name) 

24 if value is None: 

25 continue 

26 if not isinstance(value, f.type): 

27 raise ValueError(f"{f.name}({value}) of {t} should be of type {f.type}!") 

28 

29 

@dataclass(frozen=True, kw_only=True, slots=True)
class TableRow:
    """Base data transfer object describing one row of a single table.

    A derived class must itself be a dataclass and must set
    ``__table_name__``. Every dataclass field is treated as a column of
    that table.

    Note:
        The derived class is also a good place to act as builder for an entity.
    """

    # Table name: qualifies column names and is the default alias prefix.
    __table_name__: ClassVar[str]

    @classmethod
    def get_column_alias(cls, name: str, prefix: str | None = None) -> str:
        """Build the alias ``<prefix>_<name>`` for a column.

        Args:
            name: The name of the column.
            prefix: The prefix of the alias. The table name is used when no
                (or an empty) prefix is given.
        """
        effective_prefix = prefix if prefix else cls.__table_name__
        return f"{effective_prefix}_{name}"

    @classmethod
    def get_aliases(cls, prefix: str | None = None) -> list[ExpressionInterface]:
        """Create one sql-smith alias per dataclass field.

        Each field is selected as ``<table>.<field>`` and aliased with the
        result of `get_column_alias`.
        """
        return [
            alias(
                f"{cls.__table_name__}.{column.name}",
                cls.get_column_alias(column.name, prefix),
            )
            for column in fields(cls)
        ]

    @classmethod
    def column(cls, column_name: str) -> str:
        """Qualify a column name with the table name (``<table>.<column>``)."""
        table = cls.__table_name__
        return f"{table}.{column_name}"

    @classmethod
    def field(cls, column_name: str):
        """Create a sql-smith field for the table qualified column.

        Equivalent to: field(table.table_name + '.' + column_name)
        """
        qualified = cls.column(column_name)
        return sql_field(qualified)

    @classmethod
    def search(cls, column_name: str):
        """Create a sql-smith search for the table qualified column.

        Equivalent to: search(table.table_name + '.' + column_name)
        """
        qualified = cls.column(column_name)
        return sql_search(qualified)

    @classmethod
    def map(cls, row: Record, prefix: str | None = None) -> Self:
        """Create an instance of the dataclass from a row.

        The value of each field is looked up in the row with the alias of
        the column. A column that is missing from the row results in None.

        Raises:
            ValueError: When a field contains data with the wrong type.
        """
        values = {
            column.name: row.get(cls.get_column_alias(column.name, prefix))
            for column in fields(cls)
        }

        instance = cls(**values)  # noqa
        _validate_dataclass(instance)
        return instance

97 

98 

@dataclass(frozen=True, kw_only=True, slots=True)
class JoinedTableRow:
    """Base data transfer object for a row that combines multiple tables.

    Every field of the dataclass stands for one table. The type of the field
    must provide ``get_aliases`` and ``map`` (like a TableRow does) and the
    name of the field is used as prefix for the alias of each column of that
    table.

    A derived class must itself be a dataclass.

    Note:
        The derived class is also a good place to act as builder for an entity.
    """

    @classmethod
    def get_aliases(cls) -> list[ExpressionInterface]:
        """Collect the aliases of all the tables in one flat list.

        The name of each field is passed as prefix to ``get_aliases`` of the
        type of that field.
        """
        table_fields = fields(cls)
        assert len(table_fields) > 0, "There are no fields. Is this a dataclass?"

        return [
            column_alias
            for table_field in table_fields
            for column_alias in table_field.type.get_aliases(table_field.name)
        ]

    @classmethod
    def map(cls, row: Record) -> Self:
        """Create an instance by mapping the row onto each table field.

        For every field, ``map`` of the field type is called with the row and
        the name of the field as prefix.
        """
        return cls(
            **{
                table_field.name: table_field.type.map(row, table_field.name)
                for table_field in fields(cls)
            }
        )  # noqa

133 

134 

135def unwrap[T](val: T | None) -> T: 

136 """Assert when the value is None.""" 

137 assert val is not None 

138 return val