Coverage for packages/kwai-core/src/kwai_core/db/table_row.py: 100%
62 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1"""Module that defines some dataclasses that can be used as data transfer objects."""
from dataclasses import dataclass, fields
from types import UnionType
from typing import Annotated, Any, ClassVar, Literal, Self, Union, get_args, get_origin

from sql_smith.functions import alias
from sql_smith.functions import field as sql_field
from sql_smith.functions import search as sql_search
from sql_smith.interfaces import ExpressionInterface

from kwai_core.db.database import Record
14def _validate_dataclass(t):
15 """Check if all fields contains data with the correct type.
17 A ValueError will be raised when the data for a given field contains data with
18 an invalid type.
19 The None value will be ignored because sometimes a value can be None when
20 the TableRow dataclass is used in a join that doesn't match.
21 """
22 for f in fields(t):
23 value = getattr(t, f.name)
24 if value is None:
25 continue
26 if not isinstance(value, f.type):
27 raise ValueError(f"{f.name}({value}) of {t} should be of type {f.type}!")
@dataclass(frozen=True, kw_only=True, slots=True)
class TableRow:
    """Data transfer object mirroring one row of a single table.

    The derived class must be a dataclass and must define ``__table_name__``;
    every method below reads it to qualify or alias the columns. Each dataclass
    field corresponds to a column of that table.

    Note:
        The derived class is also the ideal place to act as builder for an entity.
    """

    __table_name__: ClassVar[str]

    @classmethod
    def get_column_alias(cls, name: str, prefix: str | None = None) -> str:
        """Build the alias ``<prefix>_<name>`` for a column.

        A missing (or empty) prefix falls back to the table name.
        """
        return f"{prefix or cls.__table_name__}_{name}"

    @classmethod
    def get_aliases(cls, prefix: str | None = None) -> list[ExpressionInterface]:
        """Alias every field of the dataclass.

        Each ``<table>.<field>`` column is aliased to
        ``get_column_alias(<field>, prefix)``.
        """
        return [
            alias(
                f"{cls.__table_name__}.{column.name}",
                cls.get_column_alias(column.name, prefix),
            )
            for column in fields(cls)
        ]

    @classmethod
    def column(cls, column_name: str) -> str:
        """Qualify a column name with the table name (``<table>.<column>``)."""
        return f"{cls.__table_name__}.{column_name}"

    @classmethod
    def field(cls, column_name: str):
        """Shortcut for sql-smith ``field`` on the table-qualified column.

        Equivalent to: field(table.table_name + '.' + column_name)
        """
        qualified = cls.column(column_name)
        return sql_field(qualified)

    @classmethod
    def search(cls, column_name: str):
        """Shortcut for sql-smith ``search`` on the table-qualified column.

        Equivalent to: search(table.table_name + '.' + column_name)
        """
        qualified = cls.column(column_name)
        return sql_search(qualified)

    @classmethod
    def map(cls, row: Record, prefix: str | None = None) -> Self:
        """Build an instance from a record, reading each field by its column alias.

        Values are fetched with ``row.get(alias)``; for a dict-like Record a
        missing alias therefore becomes None (which the type check tolerates).

        Raises:
            ValueError: When a field contains data with the wrong type.
        """
        kwargs = {
            column.name: row.get(cls.get_column_alias(column.name, prefix))
            for column in fields(cls)
        }
        instance = cls(**kwargs)  # noqa
        _validate_dataclass(instance)
        return instance
@dataclass(frozen=True, kw_only=True, slots=True)
class JoinedTableRow:
    """Data transfer object combining the rows of several joined tables.

    Every field of the derived dataclass stands for one table. Its type is
    expected to offer ``get_aliases(prefix)`` and ``map(row, prefix)`` (as
    TableRow does), and the field name is used as the prefix of the column
    aliases of that table.

    The derived class must be a dataclass.

    Note:
        The derived class is also the ideal place to act as builder for an entity.
    """

    @classmethod
    def get_aliases(cls) -> list[ExpressionInterface]:
        """Collect the column aliases of all table fields.

        The name of each field is used as the prefix for its table's aliases.
        """
        table_fields = fields(cls)
        assert len(table_fields) > 0, "There are no fields. Is this a dataclass?"
        return [
            column_alias
            for table_field in table_fields
            for column_alias in table_field.type.get_aliases(table_field.name)
        ]

    @classmethod
    def map(cls, row: Record) -> Self:
        """Map one record onto every table field, using the field name as prefix."""
        per_table = {
            table_field.name: table_field.type.map(row, table_field.name)
            for table_field in fields(cls)
        }
        return cls(**per_table)  # noqa
135def unwrap[T](val: T | None) -> T:
136 """Assert when the value is None."""
137 assert val is not None
138 return val