Coverage for packages/sql-smith/src/sql_smith/query/select_query.py: 83%
96 statements
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2024-01-01 00:00 +0000
1from sql_smith.capability import (
2 CanUnionMixin,
3 HasFromMixin,
4 HasLimitMixin,
5 HasOffsetMixin,
6 HasOrderByMixin,
7 HasWhereMixin,
8)
9from sql_smith.functions import express, identify, identify_all, listing
11from ..interfaces import ExpressionInterface
12from .abstract_query import AbstractQuery
class SelectQuery(
    CanUnionMixin,
    HasFromMixin,
    HasOrderByMixin,
    HasWhereMixin,
    HasLimitMixin,
    HasOffsetMixin,
    AbstractQuery,
):
    """Implements a SELECT query.

    ``as_expression`` renders the clauses in this order:
    ``[WITH [RECURSIVE] ...] SELECT [DISTINCT] <columns|*> <from> <joins>
    <where> [GROUP BY ...] [HAVING ...] <order by> <limit> <offset>``.
    The builder methods defined here mutate the query and return ``self``
    so calls can be chained.
    """

    def __init__(self, engine: "EngineInterface"):
        super().__init__(engine)
        # Initial state for the clauses supplied by the Has*Mixin bases
        # (presumably read by _apply_from/_apply_where/... — defined there).
        self._from = ()
        self._limit = None
        self._offset = None
        self._order_by = []
        self._where = None

        # Clauses rendered by this class.
        self._distinct = False
        self._columns = []
        self._joins = []
        self._group_by = []
        self._having = None

        # Common table expressions: name -> query, kept in insertion order.
        self._cte = {}
        # Names of the CTEs registered with recursive=True. Any entry makes
        # the WITH clause render as "WITH RECURSIVE".
        self._cte_recursive_names: set[str] = set()

    def distinct(self, state: bool = True) -> "SelectQuery":
        """Add (``state=True``) or remove (``state=False``) DISTINCT."""
        self._distinct = state
        return self

    def columns(self, *columns) -> "SelectQuery":
        """Replace the selected columns; no columns means ``SELECT *``."""
        self._columns = identify_all(*columns)
        return self

    def add_columns(self, *columns) -> "SelectQuery":
        """Append columns to the current selection."""
        self._columns = (*self._columns, *identify_all(*columns))
        return self

    def join(
        self, table: str, criteria: "CriteriaInterface", join_type: str = ""
    ) -> "SelectQuery":
        """Add a ``<join_type> JOIN table ON criteria`` clause.

        An empty ``join_type`` yields a plain ``JOIN``. Joins are rendered in
        the order they were added.
        """
        # Doubled braces survive the f-string as "{}" placeholders for express().
        sql = f"{join_type.upper()} JOIN {{}} ON {{}}".strip()
        self._joins.append(express(sql, identify(table), criteria))
        return self

    def inner_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add an INNER join."""
        return self.join(table, criteria, "INNER")

    def left_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a LEFT join."""
        return self.join(table, criteria, "LEFT")

    def right_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a RIGHT join."""
        return self.join(table, criteria, "RIGHT")

    def full_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a FULL join."""
        return self.join(table, criteria, "FULL")

    def group_by(self, *columns) -> "SelectQuery":
        """Replace the GROUP BY columns; no columns removes the clause."""
        self._group_by = identify_all(*columns)
        return self

    def having(self, criteria: "CriteriaInterface") -> "SelectQuery":
        """Set the HAVING clause; ``None`` removes it."""
        self._having = criteria
        return self

    def with_(
        self,
        name: str | None = None,
        query: ExpressionInterface | None = None,
        recursive: bool = False,
    ) -> "SelectQuery":
        """Add, replace or remove a common table expression (CTE).

        - ``with_()`` with no name removes every CTE.
        - ``with_(name)`` with no query removes the CTE called ``name``;
          removing a name that was never added is a no-op.
        - ``with_(name, query, recursive)`` registers ``query`` under
          ``name``, replacing any CTE of the same name. If any registered
          CTE was added with ``recursive=True``, the clause renders as
          ``WITH RECURSIVE``.

        ``name`` is inserted verbatim before ``AS (...)``; it is not quoted.
        """
        if name is None:
            self._cte = {}
            self._cte_recursive_names = set()
            return self

        if query is None:
            # pop()/discard() rather than del: removing an unknown name is a
            # no-op instead of raising KeyError.
            self._cte.pop(name, None)
            self._cte_recursive_names.discard(name)
            return self

        self._cte[name] = query
        # Track recursion per CTE, so adding a non-recursive CTE after a
        # recursive one does not drop the RECURSIVE keyword the latter needs.
        if recursive:
            self._cte_recursive_names.add(name)
        else:
            self._cte_recursive_names.discard(name)
        return self

    def as_expression(self) -> "ExpressionInterface":
        """Render the complete statement as an expression.

        Each private ``__apply_*`` helper below returns the query unchanged
        when its clause is empty.
        """
        if self._cte:
            keyword = "WITH RECURSIVE" if self._cte_recursive_names else "WITH"
            query = self.__apply_with(express(keyword))
            query = query.append("SELECT")
        else:
            query = self.start_expression()
        query = self.__apply_distinct(query)
        query = self.__apply_columns(query)
        query = self._apply_from(query)
        query = self.__apply_joins(query)
        query = self._apply_where(query)
        query = self.__apply_group_by(query)
        query = self.__apply_having(query)
        query = self._apply_order_by(query)
        query = self._apply_limit(query)
        query = self._apply_offset(query)
        return query

    def start_expression(self) -> "ExpressionInterface":
        """Return the leading ``SELECT`` keyword used when there are no CTEs."""
        return express("SELECT")

    def __apply_columns(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append the column list, or ``*`` when none were selected."""
        if self._columns:
            return query.append("{}", listing(self._columns))
        return query.append("*")

    def __apply_distinct(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``DISTINCT`` when enabled."""
        if self._distinct:
            return query.append("DISTINCT")
        return query

    def __apply_joins(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append all joins, space-separated, in insertion order."""
        if not self._joins:
            return query
        return query.append("{}", listing(self._joins, " "))

    def __apply_group_by(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``GROUP BY a, b, ...`` when grouping columns are set."""
        if not self._group_by:
            return query
        # GROUP BY items must be comma-separated. Use listing()'s default
        # separator, the same one the SELECT column list relies on.
        return query.append("GROUP BY {}", listing(self._group_by))

    def __apply_having(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``HAVING <criteria>`` when criteria were set."""
        if self._having is None:
            return query
        return query.append("HAVING {}", self._having)

    def __apply_with(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``name AS (query)`` for each CTE, comma-separated."""
        if not self._cte:
            return query
        definitions = [
            express(name + " AS ({})", cte_query)
            for name, cte_query in self._cte.items()
        ]
        # listing()'s default separator yields "a AS (...), b AS (...)".
        # A single CTE renders exactly as before.
        return query.append("{}", listing(definitions))