Coverage for packages/sql-smith/src/sql_smith/query/select_query.py: 83%

96 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1from sql_smith.capability import ( 

2 CanUnionMixin, 

3 HasFromMixin, 

4 HasLimitMixin, 

5 HasOffsetMixin, 

6 HasOrderByMixin, 

7 HasWhereMixin, 

8) 

9from sql_smith.functions import express, identify, identify_all, listing 

10 

11from ..interfaces import ExpressionInterface 

12from .abstract_query import AbstractQuery 

13 

14 

class SelectQuery(
    CanUnionMixin,
    HasFromMixin,
    HasOrderByMixin,
    HasWhereMixin,
    HasLimitMixin,
    HasOffsetMixin,
    AbstractQuery,
):
    """Implements a SELECT query.

    Builder methods store state on the instance and return ``self`` so calls
    can be chained; :meth:`as_expression` assembles the final expression.
    """

    def __init__(self, engine: "EngineInterface"):
        super().__init__(engine)
        # Presumably read by the mixins' _apply_from/_apply_where/
        # _apply_order_by/_apply_limit/_apply_offset helpers -- verify there.
        self._from = ()
        self._limit = None
        self._offset = None
        self._order_by = []
        self._where = None

        # SELECT-specific clauses.
        self._distinct = False
        self._columns = []
        self._joins = []
        self._group_by = []
        self._having = None

        # CTEs keyed by name; dict insertion order is the order they are
        # emitted in the WITH clause.
        self._cte = {}
        # True when the WITH clause must be rendered as WITH RECURSIVE.
        self._cte_recursive = False

    def distinct(self, state: bool = True) -> "SelectQuery":
        """Add (``state=True``) or remove (``state=False``) DISTINCT."""
        self._distinct = state
        return self

    def columns(self, *columns) -> "SelectQuery":
        """Replace the selected columns (each passed through ``identify_all``)."""
        self._columns = identify_all(*columns)
        return self

    def add_columns(self, *columns) -> "SelectQuery":
        """Append columns to the current selection."""
        self._columns = (*self._columns, *identify_all(*columns))
        return self

    def join(
        self, table: str, criteria: "CriteriaInterface", join_type: str = ""
    ) -> "SelectQuery":
        """Add a ``<join_type> JOIN table ON criteria`` clause.

        :param table: Table to join; passed through ``identify``.
        :param criteria: The ON condition.
        :param join_type: Prefix such as ``"LEFT"``; upper-cased. An empty
            string produces a plain ``JOIN``.
        """
        # "{{}}" survives .format() as the "{}" placeholders used by express().
        sql = "{} JOIN {{}} ON {{}}".format(join_type.upper()).strip()
        self._joins.append(express(sql, identify(table), criteria))
        return self

    def inner_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add an INNER join."""
        return self.join(table, criteria, "INNER")

    def left_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a LEFT join."""
        return self.join(table, criteria, "LEFT")

    def right_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a RIGHT join."""
        return self.join(table, criteria, "RIGHT")

    def full_join(self, table: str, criteria: "CriteriaInterface") -> "SelectQuery":
        """Add a FULL join."""
        return self.join(table, criteria, "FULL")

    def group_by(self, *columns) -> "SelectQuery":
        """Set the GROUP BY columns (replacing any previous ones)."""
        self._group_by = identify_all(*columns)
        return self

    def having(self, criteria: "CriteriaInterface") -> "SelectQuery":
        """Set the HAVING condition."""
        self._having = criteria
        return self

    def with_(
        self,
        name: str | None = None,
        query: ExpressionInterface | None = None,
        recursive: bool = False,
    ) -> "SelectQuery":
        """Add, replace, or remove a common table expression (CTE).

        :param name: CTE name. When ``None``, all CTEs are removed.
        :param query: CTE body. When ``None``, the CTE called *name* is
            removed (a no-op if it does not exist).
        :param recursive: Render the clause as ``WITH RECURSIVE``. Because the
            keyword applies to the whole WITH clause, the flag is sticky: a
            later non-recursive CTE does not clear it. It is reset once no
            CTEs remain.
        """
        if name is None:
            self._cte = {}
            self._cte_recursive = False
            return self

        if query is None:
            self._cte.pop(name, None)
            if not self._cte:
                self._cte_recursive = False
            return self

        self._cte[name] = query
        self._cte_recursive = self._cte_recursive or recursive
        return self

    def as_expression(self) -> "ExpressionInterface":
        """Build the full SELECT statement as an expression.

        Clauses are appended in SQL order: WITH, SELECT, DISTINCT, columns,
        FROM, joins, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET.
        """
        if self._cte:
            keyword = "WITH RECURSIVE" if self._cte_recursive else "WITH"
            query = self.__apply_with(express(keyword))
            # NOTE: SELECT is appended literally here, so an overridden
            # start_expression() is not used when CTEs are present.
            query = query.append("SELECT")
        else:
            query = self.start_expression()
        query = self.__apply_distinct(query)
        query = self.__apply_columns(query)
        query = self._apply_from(query)
        query = self.__apply_joins(query)
        query = self._apply_where(query)
        query = self.__apply_group_by(query)
        query = self.__apply_having(query)
        query = self._apply_order_by(query)
        query = self._apply_limit(query)
        query = self._apply_offset(query)
        return query

    def start_expression(self) -> "ExpressionInterface":
        """Return the leading ``SELECT`` expression."""
        return express("SELECT")

    def __apply_columns(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append the column list, or ``*`` when no columns were set."""
        if self._columns:
            return query.append("{}", listing(self._columns))
        return query.append("*")

    def __apply_distinct(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append DISTINCT when enabled."""
        if self._distinct:
            return query.append("DISTINCT")
        return query

    def __apply_joins(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append all joins, separated by single spaces."""
        if self._joins:
            return query.append("{}", listing(self._joins, " "))
        return query

    def __apply_group_by(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``GROUP BY`` with the columns listed like the SELECT list."""
        if self._group_by:
            # Same listing as the SELECT column list; a space separator would
            # produce invalid SQL such as "GROUP BY a b".
            return query.append("GROUP BY {}", listing(self._group_by))
        return query

    def __apply_having(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append ``HAVING`` when a condition has been set."""
        if self._having is not None:
            return query.append("HAVING {}", self._having)
        return query

    def __apply_with(self, query: "ExpressionInterface") -> "ExpressionInterface":
        """Append every ``name AS (query)`` CTE definition as one listing."""
        if not self._cte:
            return query
        # Multiple CTEs must be comma-separated: WITH a AS (...), b AS (...).
        definitions = [
            express(name + " AS ({})", cte_query)
            for name, cte_query in self._cte.items()
        ]
        return query.append("{}", listing(definitions))