package sqlancer.spark;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;

import com.google.auto.service.AutoService;

import sqlancer.AbstractAction;
import sqlancer.DatabaseProvider;
import sqlancer.IgnoreMeException;
import sqlancer.MainOptions;
import sqlancer.Randomly;
import sqlancer.SQLConnection;
import sqlancer.SQLProviderAdapter;
import sqlancer.StatementExecutor;
import sqlancer.common.query.SQLQueryAdapter;
import sqlancer.common.query.SQLQueryProvider;
import sqlancer.spark.gen.SparkInsertGenerator;
import sqlancer.spark.gen.SparkTableGenerator;
23+ @ AutoService (DatabaseProvider .class )
24+ public class SparkProvider extends SQLProviderAdapter <SparkGlobalState , SparkOptions > {
25+
26+ public SparkProvider () {
27+ super (SparkGlobalState .class , SparkOptions .class );
28+ }
29+
30+ public enum Action implements AbstractAction <SparkGlobalState > {
31+ INSERT (SparkInsertGenerator ::getQuery ); // You will need to create this class
32+
33+ private final SQLQueryProvider <SparkGlobalState > sqlQueryProvider ;
34+
35+ Action (SQLQueryProvider <SparkGlobalState > sqlQueryProvider ) {
36+ this .sqlQueryProvider = sqlQueryProvider ;
37+ }
38+
39+ @ Override
40+ public SQLQueryAdapter getQuery (SparkGlobalState state ) throws Exception {
41+ return sqlQueryProvider .getQuery (state );
42+ }
43+ }
44+
45+ private static int mapActions (SparkGlobalState globalState , Action a ) {
46+ Randomly r = globalState .getRandomly ();
47+ switch (a ) {
48+ case INSERT :
49+ return r .getInteger (0 , globalState .getOptions ().getMaxNumberInserts ());
50+ default :
51+ throw new AssertionError (a );
52+ }
53+ }
54+
55+ @ Override
56+ public void generateDatabase (SparkGlobalState globalState ) throws Exception {
57+ for (int i = 0 ; i < Randomly .fromOptions (1 , 2 ); i ++) {
58+ boolean success ;
59+ do {
60+ String tableName = globalState .getSchema ().getFreeTableName ();
61+ SQLQueryAdapter qt = SparkTableGenerator .generate (globalState , tableName );
62+ success = globalState .executeStatement (qt );
63+ } while (!success );
64+ }
65+
66+ if (globalState .getSchema ().getDatabaseTables ().isEmpty ()) {
67+ throw new IgnoreMeException ();
68+ }
69+
70+ StatementExecutor <SparkGlobalState , Action > se = new StatementExecutor <>(globalState , Action .values (),
71+ SparkProvider ::mapActions , (q ) -> {
72+ if (globalState .getSchema ().getDatabaseTables ().isEmpty ()) {
73+ throw new IgnoreMeException ();
74+ }
75+ });
76+ se .executeStatements ();
77+ }
78+
79+ @ Override
80+ public SQLConnection createDatabase (SparkGlobalState globalState ) throws SQLException {
81+ String username = globalState .getOptions ().getUserName ();
82+ String password = globalState .getOptions ().getPassword ();
83+ String host = globalState .getOptions ().getHost ();
84+ int port = globalState .getOptions ().getPort ();
85+
86+ if (host == null ) {
87+ host = SparkOptions .DEFAULT_HOST ;
88+ }
89+ if (port == MainOptions .NO_SET_PORT ) {
90+ port = SparkOptions .DEFAULT_PORT ;
91+ }
92+
93+ String databaseName = globalState .getDatabaseName ();
94+
95+ // Spark uses the Hive driver for JDBC usually
96+ String url = String .format ("jdbc:hive2://%s:%d/%s" , host , port , "default" );
97+
98+ // Connect to default to create the fuzzing DB
99+ Connection con = DriverManager .getConnection (url , username , password );
100+ try (Statement s = con .createStatement ()) {
101+ s .execute ("DROP DATABASE IF EXISTS " + databaseName + " CASCADE" );
102+ }
103+ try (Statement s = con .createStatement ()) {
104+ s .execute ("CREATE DATABASE " + databaseName );
105+ }
106+ con .close ();
107+
108+ // Connect to the specific fuzzing DB
109+ con = DriverManager .getConnection (String .format ("jdbc:hive2://%s:%d/%s" , host , port , databaseName ), username ,
110+ password );
111+ try (Statement s = con .createStatement ()) {
112+ // This allows casting things like BOOLEAN to DATE/TIMESTAMP, which the generator loves to do.
113+ s .execute ("SET spark.sql.ansi.enabled=false" );
114+ }
115+ return new SQLConnection (con );
116+ }
117+
118+ @ Override
119+ public String getDBMSName () {
120+ return "spark" ;
121+ }
122+ }