Coverage for packages/kwai-core/src/kwai_core/db/database.py: 91%

112 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1"""Module for database classes/functions.""" 

2 

3import dataclasses 

4 

5from collections import namedtuple 

6from typing import Any, AsyncIterator, TypeAlias 

7 

8import asyncmy 

9 

10from loguru import logger 

11from sql_smith import QueryFactory 

12from sql_smith.engine import MysqlEngine 

13from sql_smith.functions import field 

14from sql_smith.query import AbstractQuery, SelectQuery 

15 

16from kwai_core.db.exceptions import DatabaseException, QueryException 

17from kwai_core.settings import DatabaseSettings 

18 

19 

20Record: TypeAlias = dict[str, Any] 

21ExecuteResult = namedtuple("ExecuteResult", ("rowcount", "last_insert_id")) 

22 

23 

class Database:
    """Class for communicating with a database.

    A single asyncmy connection is opened lazily (see `check_connection`) and
    reused for every query executed through this instance.

    Attributes:
        _connection: The asyncmy connection, or None when not connected.
        _settings (DatabaseSettings): The settings for this database connection.
    """

    def __init__(self, settings: DatabaseSettings):
        self._connection: asyncmy.Connection | None = None
        self._settings = settings

    async def setup(self):
        """Open the connection to the database.

        Raises:
            (DatabaseException): Raised when the connection cannot be set up.
        """
        try:
            self._connection = await asyncmy.connect(
                host=self._settings.host,
                database=self._settings.name,
                user=self._settings.user,
                password=self._settings.password,
            )
        except Exception as exc:
            raise DatabaseException(
                f"Setting up connection for database {self._settings.name} "
                f"failed: {exc}"
            ) from exc

    async def check_connection(self):
        """Check if the connection is set, if not it will try to connect."""
        if self._connection is None:
            await self.setup()

    async def close(self):
        """Close the connection.

        The connection is forgotten even when closing it raises, so that the
        next `check_connection` opens a fresh connection instead of reusing a
        half-closed one.
        """
        if self._connection:
            try:
                await self._connection.ensure_closed()
            finally:
                self._connection = None

    @classmethod
    def create_query_factory(cls) -> QueryFactory:
        """Return a query factory for the current database engine.

        The query factory is used to start creating a SELECT, INSERT, UPDATE or
        DELETE query.

        Returns:
            (QueryFactory): The query factory from sql-smith.
                Currently, it returns a query factory for the mysql engine. In the
                future it can provide other engines.
        """
        return QueryFactory(MysqlEngine())

    async def commit(self):
        """Commit all changes."""
        await self.check_connection()
        await self._connection.commit()

    async def execute(self, query: AbstractQuery) -> ExecuteResult:
        """Execute a query.

        The parameters are bound with `cursor.mogrify` before execution, so the
        logged SQL is the statement that is sent to the server.

        Args:
            query (AbstractQuery): The query to execute.

        Returns:
            (ExecuteResult): The cursor's `rowcount` (number of affected rows)
                and `lastrowid` (on insert, the generated id of the new row).

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        async with self._connection.cursor() as cursor:
            try:
                await self._execute_compiled(cursor, compiled_query)
                return ExecuteResult(cursor.rowcount, cursor.lastrowid)
            except Exception as exc:
                raise QueryException(compiled_query.sql) from exc

    async def fetch_one(self, query: SelectQuery) -> Record | None:
        """Execute a query and return the first row.

        Args:
            query (SelectQuery): The query to execute.

        Returns:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.
            (None): The query resulted in no rows found.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await self._execute_compiled(cursor, compiled_query)
                column_names = self._column_names(cursor)
                if row := await cursor.fetchone():
                    return self._to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

        return None  # Nothing found

    async def fetch(self, query: SelectQuery) -> AsyncIterator[Record]:
        """Execute a query and yield each row.

        Like `execute` and `fetch_one`, the parameters are bound with
        `cursor.mogrify` first, so the logged SQL contains the actual values.

        Args:
            query (SelectQuery): The query to execute.

        Yields:
            (Record): A row is a dictionary using the column names
                as key and the column values as value.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        compiled_query = query.compile()

        await self.check_connection()
        try:
            async with self._connection.cursor() as cursor:
                await self._execute_compiled(cursor, compiled_query)
                column_names = self._column_names(cursor)
                while row := await cursor.fetchone():
                    yield self._to_record(column_names, row)
        except Exception as exc:
            raise QueryException(compiled_query.sql) from exc

    async def insert(
        self, table_name: str, *table_data: Any, id_column: str = "id"
    ) -> int:
        """Insert one or more instances of a dataclass into the given table.

        All instances are inserted with one multi-row INSERT statement. The
        column list is taken from the first instance, so every instance is
        expected to have the same fields in the same order. The id column is
        left out, so that the database can generate it.

        Args:
            table_name: The name of the table
            table_data: One or more instances of a dataclass containing the values
            id_column: The name of the id column (default is 'id')

        Returns:
            (int): The id generated by the insert (the cursor's lastrowid).
                For a multi-row INSERT, MySQL reports the id generated for the
                *first* inserted row.

        Raises:
            (QueryException): Raised when the query contains an error.
        """
        assert len(table_data) > 0, "There should be at least one row to insert."

        records = [self._dataclass_to_record(data, id_column) for data in table_data]

        query = self.create_query_factory().insert(table_name).columns(*records[0].keys())
        for record in records:
            query = query.values(*record.values())

        execute_result = await self.execute(query)
        return execute_result.last_insert_id

    async def update(
        self, id_: Any, table_name: str, table_data: Any, id_column: str = "id"
    ) -> int:
        """Update a dataclass in the given table.

        The id column itself is never part of the SET clause, and it may be
        absent from the dataclass.

        Args:
            id_: The id of the data to update.
            table_name: The name of the table.
            table_data: The dataclass containing the data.
            id_column: The name of the id column (default is 'id').

        Raises:
            (QueryException): Raised when the query contains an error.

        Returns:
            The number of rows affected.
        """
        record = self._dataclass_to_record(table_data, id_column)
        query = (
            self.create_query_factory()
            .update(table_name)
            .set(record)
            .where(field(id_column).eq(id_))
        )
        execute_result = await self.execute(query)
        return execute_result.rowcount

    async def delete(self, id_: Any, table_name: str, id_column: str = "id"):
        """Delete a row from the table using the id field.

        Args:
            id_ (Any): The id of the row to delete.
            table_name (str): The name of the table.
            id_column (str): The name of the id column (default is 'id')

        Raises:
            (QueryException): Raised when the query results in an error.
        """
        query = (
            self.create_query_factory()
            .delete(table_name)
            .where(field(id_column).eq(id_))
        )
        await self.execute(query)

    def log_query(self, query: str):
        """Log a query.

        Args:
            query (str): The query to log.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Query: {query}", database=self._settings.name, query=query
        )

    def log_affected_rows(self, rowcount: int):
        """Log the number of affected rows of the last executed query.

        Args:
            rowcount: The number of affected rows.
        """
        db_logger = logger.bind(database=self._settings.name)
        db_logger.info(
            "DB: {database} - Affected rows: {rowcount}",
            database=self._settings.name,
            rowcount=rowcount,
        )

    @property
    def settings(self) -> DatabaseSettings:
        """Return the database settings.

        This property is immutable: the returned value is a copy of the current
        settings.
        """
        return self._settings.model_copy()

    async def begin(self):
        """Start a transaction."""
        await self.check_connection()
        await self._connection.begin()

    async def rollback(self):
        """Rollback a transaction."""
        await self.check_connection()
        await self._connection.rollback()

    # --- private helpers -------------------------------------------------

    async def _execute_compiled(self, cursor: Any, compiled_query: Any) -> None:
        """Bind the parameters of a compiled query, log the SQL and execute it."""
        # mogrify performs the same parameter binding that execute(sql, params)
        # would, but gives us the final SQL string to log first.
        generated_sql = cursor.mogrify(compiled_query.sql, compiled_query.params)
        self.log_query(generated_sql)
        await cursor.execute(generated_sql)

    @staticmethod
    def _column_names(cursor: Any) -> list[str]:
        """Return the column names of the result set of the last executed query."""
        return [column[0] for column in cursor.description]

    @staticmethod
    def _to_record(column_names: list[str], row: Any) -> Record:
        """Combine column names and a row's values into a Record."""
        # strict=True: a length mismatch is a bug and must not be silently truncated.
        return dict(zip(column_names, row, strict=True))

    @staticmethod
    def _dataclass_to_record(data: Any, id_column: str) -> Record:
        """Convert a dataclass instance to a Record without the id column."""
        assert dataclasses.is_dataclass(data), "table_data should be a dataclass"
        record = dataclasses.asdict(data)
        # The id is generated (insert) or used in the WHERE clause (update), so it
        # is never written as a column value; it may be absent from the dataclass.
        record.pop(id_column, None)
        return record