Coverage for packages/sql-smith/src/sql_smith/functions.py: 96%

49 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2024-01-01 00:00 +0000

1from sql_smith.interfaces import ExpressionInterface, StatementInterface 

2from sql_smith.partial import ( 

3 Criteria, 

4 Expression, 

5 Identifier, 

6 Listing, 

7 Literal, 

8 QualifiedIdentifier, 

9) 

10from sql_smith.partial.parameter import Parameter 

11 

12 

def __is_statement(value) -> bool:
    """Return True when *value* already implements StatementInterface.

    The helpers in this module use this check to pass ready-made statements
    through unchanged instead of wrapping them a second time.
    """
    result = isinstance(value, StatementInterface)
    return result

15 

16 

def param(value) -> "StatementInterface":
    """Create a parameter.

    Values that are already statements are returned unchanged; anything
    else is wrapped with ``Parameter.create``.

    >>> func('POINT', param(1), param(2))  # POINT(?, ?)
    """
    return value if __is_statement(value) else Parameter.create(value)

25 

26 

def param_all(*values):
    """Apply :func:`param` to every value and return the results as a tuple."""
    return tuple(param(value) for value in values)

29 

30 

def express(pattern: str, *args):
    """Create an expression.

    Each positional argument is converted with :func:`param`, so raw values
    become parameters while existing statements are embedded as-is.

    >>> express('{} + 1', identify('visit'))  # "visit" + 1
    """
    params = param_all(*args)
    return Expression(pattern, *params)

37 

38 

def alias(field_name, field_alias: str) -> "ExpressionInterface":
    """Create an alias for a column or a function.

    Both the aliased value and the alias name go through :func:`identify`.

    >>> alias('users.id', 'uid')  # "users"."id" AS "uid"
    """
    source = identify(field_name)
    target = identify(field_alias)
    return express("{} AS {}", source, target)

45 

46 

def listing(values: tuple | list, separator: str = ", ") -> Listing:
    """Create a listing.

    Every value is converted with :func:`param` and handed to ``Listing``
    together with *separator*.

    >>> listing((1, 1, 2, 3, 5))  # ?, ?, ?, ?, ?
    >>> listing(identify_all('id', 'username', 'email'))  # "id", "username", "email"
    """
    items = param_all(*values)
    return Listing(separator, *items)

54 

55 

def func(function: str, *args) -> "ExpressionInterface":
    """Create a function call expression.

    Each argument goes through :func:`identify`, so plain strings are treated
    as (possibly dotted) column names, while existing statements such as the
    result of :func:`param` or :func:`literal` are passed through unchanged.

    >>> func('COUNT', 'users.id')  # COUNT("users"."id")
    >>> func('POINT', param(1), param(2))  # POINT(?, ?)

    :param function: SQL function name. It is inserted into the pattern
        verbatim (not quoted, not parameterized), so it must not come from
        untrusted input.
    :param args: function arguments, each passed through :func:`identify`.
    """
    # The doubled braces survive this formatting step as a single "{}"
    # placeholder, which express() then pairs with the argument listing.
    return express(f"{function}({{}})", listing(identify_all(*args)))

62 

63 

def literal(value) -> "StatementInterface":
    """Create a literal.

    Statements are returned unchanged; any other value is wrapped in
    ``Literal``.
    """
    if not __is_statement(value):
        return Literal(value)
    return value

69 

70 

def criteria(pattern: str, *args) -> "CriteriaInterface":
    """Create a criteria.

    The pattern and arguments are combined with :func:`express` and the
    resulting expression is wrapped in ``Criteria``.

    >>> c = criteria(
    ...     "{} = {}",
    ...     func('YEAR', identify('start_date')),
    ...     literal(2021),
    ... )  # YEAR("start_date") = 2021
    """
    expression = express(pattern, *args)
    return Criteria(expression)

84 

85 

def on(left: str, right: str):
    """Create an on clause that compares two identified names for equality.

    >>> on('users.id', 'orders.user_id')  # "users"."id" = "orders"."user_id"
    """
    lhs, rhs = identify(left), identify(right)
    return criteria("{} = {}", lhs, rhs)

89 

90 

def order(column, direction: str | None = None) -> "StatementInterface":
    """Create an order clause.

    >>> order('users.id')  # "users"."id"
    >>> order('users.id', 'desc')  # "users"."id" DESC

    :param column: column name or statement, passed through :func:`identify`.
    :param direction: sort direction such as ``ASC`` or ``DESC``. Surrounding
        whitespace is stripped and the text is upper-cased. ``None`` or a
        blank string yields just the column, with no trailing space.
    """
    # Normalize first so "" / "  " / " desc " cannot leave stray spaces
    # in the generated SQL.
    keyword = direction.strip().upper() if direction else ""
    if not keyword:
        return identify(column)
    # NOTE(security): the direction is spliced into the SQL text itself rather
    # than bound as a parameter; never pass unvalidated user input here.
    return express(f"{{}} {keyword}", identify(column))

96 

97 

def group(c: "CriteriaInterface") -> "CriteriaInterface":
    """Create a group of criteria by wrapping it in parentheses.

    >>> group(
    ...     field('username').eq('tom')
    ...     .or_(field('first_name').eq('Tom'))
    ... ).and_(
    ...     field('is_active').eq(1)
    ... )
    >>> # ("username" = ? OR "first_name" = ?) AND "is_active" = ?
    """
    parenthesized = "({})"
    return criteria(parenthesized, c)

110 

111 

def field(name):
    """Start a criteria for a column. Use it to create a condition.

    The name is passed through :func:`identify` and handed to a
    ``CriteriaBuilder``.

    >>> field('users.id').eq(100)  # "users"."id" = ?
    """
    # Imported inside the function, presumably to avoid a circular import
    # with sql_smith.builder — confirm.
    from sql_smith.builder import CriteriaBuilder

    column = identify(name)
    return CriteriaBuilder(column)

120 

121 

def search(name):
    """Start a LIKE clause for a column.

    The name is passed through :func:`identify` and handed to a
    ``LikeBuilder``.

    >>> search('username').contains('admin')  # "username" LIKE '%admin%'
    """
    # Imported inside the function, presumably to avoid a circular import
    # with sql_smith.builder — confirm.
    from sql_smith.builder import LikeBuilder as _LikeBuilder

    return _LikeBuilder(identify(name))

130 

131 

def identify_all(*names) -> tuple:
    """Apply :func:`identify` to each name and return the results as a tuple.

    >>> identify_all('id', 'username')  # ("id", "username")
    """
    return tuple(identify(name) for name in names)

138 

139 

def identify(name) -> "StatementInterface":
    """Identify a name.

    - Statements are returned unchanged.
    - Dotted names are split on ``.`` and every part is identified on its
      own, producing a ``QualifiedIdentifier`` (so ``users.*`` turns its
      ``*`` part into a ``Literal``).
    - A bare ``*`` becomes a ``Literal``.
    - Anything else becomes an ``Identifier``.

    >>> identify('users.id')  # "users"."id"
    """
    if __is_statement(name):
        return name

    is_qualified = name.find(".") != -1
    if is_qualified:
        segments = name.split(".")
        return QualifiedIdentifier(*identify_all(*segments))

    return Literal(name) if name == "*" else Identifier(name)