
Commit 05bb28c

Aman committed
Add Spark support for TLP Oracle
1 parent 8a5ea5d commit 05bb28c

29 files changed: +1460 -2 lines

.github/workflows/main.yml

Lines changed: 38 additions & 0 deletions
@@ -333,6 +333,44 @@ jobs:
       - name: Run Tests
         run: HIVE_AVAILABLE=true mvn -Dtest=TestHiveTLP test
 
+  spark:
+    name: DBMS Tests (Spark)
+    runs-on: ubuntu-latest
+
+    services:
+      spark:
+        image: apache/spark:3.5.1
+        ports:
+          - 10000:10000
+
+        command: >-
+          /opt/spark/bin/spark-submit
+          --class org.apache.spark.sql.hive.thriftserver.HiveThriftServer2
+          --name "Thrift JDBC/ODBC Server"
+          --master local[*]
+          --driver-memory 4g
+          --conf spark.hive.server2.thrift.port=10000
+          --conf spark.sql.warehouse.dir=/tmp/spark-warehouse
+          spark-internal
+
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+
+      - name: Set up JDK 11
+        uses: actions/setup-java@v3
+        with:
+          distribution: 'temurin'
+          java-version: '11'
+          cache: 'maven'
+
+      - name: Build SQLancer
+        run: mvn -B package -DskipTests=true
+
+      - name: Run Tests
+        run: SPARK_AVAILABLE=true mvn -Dtest=TestSparkTLP test
+
   hsqldb:
     name: DBMS Tests (HSQLDB)
     runs-on: ubuntu-latest
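
The job's final step runs a TestSparkTLP class that is referenced by the CI job but not shown on this page. A minimal sketch of what such a test plausibly contains, mirroring how the Hive job above gates its test on an environment variable and drives SQLancer end-to-end via Main.executeMain; the class body and the option values here are assumptions, not the committed code.

package sqlancer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import org.junit.jupiter.api.Test;

// Hypothetical sketch of TestSparkTLP (not the committed version).
public class TestSparkTLP {

    @Test
    public void testSparkTLPWhere() {
        // Skip unless the workflow has started the Spark Thrift server.
        assumeTrue(System.getenv("SPARK_AVAILABLE") != null);
        // Exit code 0 means no bug-inducing query was found.
        assertEquals(0, Main.executeMain("--random-seed", "0", "--timeout-seconds", "60",
                "--num-queries", "100", "spark", "--oracle", "TLPWhere"));
    }
}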

pom.xml

Lines changed: 7 additions & 2 deletions
@@ -329,7 +329,7 @@
     </dependency>
     <dependency>
       <groupId>org.slf4j</groupId>
-      <artifactId> slf4j-simple</artifactId>
+      <artifactId>slf4j-simple</artifactId>
       <version>2.0.6</version>
     </dependency>
     <dependency>
@@ -381,7 +381,7 @@
     <dependency>
       <groupId>org.apache.hive</groupId>
       <artifactId>hive-jdbc</artifactId>
-      <version>4.0.1</version>
+      <version>3.1.2</version>
     </dependency>
     <dependency>
       <groupId>org.apache.hive</groupId>
@@ -393,6 +393,11 @@
       <artifactId>hive-cli</artifactId>
       <version>4.0.1</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <version>3.2.4</version>
+    </dependency>
   </dependencies>
   <reporting>
     <plugins>

src/sqlancer/Main.java

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@
 import sqlancer.tidb.TiDBProvider;
 import sqlancer.yugabyte.ycql.YCQLProvider;
 import sqlancer.yugabyte.ysql.YSQLProvider;
+import sqlancer.spark.SparkProvider;
 
 public final class Main {
 
@@ -756,6 +757,7 @@ private static void checkForIssue799(List<DatabaseProvider<?, ?, ?>> providers)
         providers.add(new DuckDBProvider());
         providers.add(new H2Provider());
         providers.add(new HiveProvider());
+        providers.add(new SparkProvider());
         providers.add(new HSQLDBProvider());
         providers.add(new MariaDBProvider());
         providers.add(new MaterializeProvider());

src/sqlancer/spark/SparkErrors.java

Lines changed: 67 additions & 0 deletions

@@ -0,0 +1,67 @@
+package sqlancer.spark;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import sqlancer.common.query.ExpectedErrors;
+
+public final class SparkErrors {
+
+    private SparkErrors() {
+    }
+
+    public static List<String> getExpressionErrors() {
+        ArrayList<String> errors = new ArrayList<>();
+
+        errors.add("cannot resolve");
+        errors.add("AnalysisException");
+        errors.add("data type mismatch");
+        errors.add("undefined function");
+        errors.add("mismatched input");
+        errors.add("due to data type mismatch");
+
+        // --- Invalid literals ---
+        errors.add("The value of the typed literal");
+
+        errors.add("DATATYPE_MISMATCH");
+        errors.add("cannot be cast to");
+
+        errors.add("Overflow");
+        errors.add("Divide by zero"); // common if spark.sql.ansi.enabled is true
+        errors.add("division by zero");
+
+        // --- Group By / aggregation errors ---
+        errors.add("grouping expressions");
+        errors.add("expression is neither present in the group by");
+        errors.add("is not a valid grouping expression");
+        errors.add("is not contained in either an aggregate function or the GROUP BY clause");
+        errors.add("PARSE_SYNTAX_ERROR");
+        errors.add("Syntax error");
+
+        return errors;
+    }
+
+    public static void addExpressionErrors(ExpectedErrors errors) {
+        errors.addAll(getExpressionErrors());
+    }
+
+    public static List<String> getInsertErrors() {
+        ArrayList<String> errors = new ArrayList<>();
+
+        errors.add("not enough data columns");
+        errors.add("cannot write to");
+        errors.add("incompatible types");
+        errors.add("too many data columns");
+        errors.add("cannot be cast to");
+        errors.add("Error running query");
+        errors.add("The value of the typed literal");
+        errors.add("Cannot safely cast"); // observed in logs: Decimal -> Date
+        errors.add("AnalysisException"); // Spark throws this for almost all insert failures
+
+        return errors;
+    }
+
+    public static void addInsertErrors(ExpectedErrors errors) {
+        errors.addAll(getInsertErrors());
+    }
+}
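
These substrings act as an allowlist: when a generated query fails and the error message contains one of them, SQLancer discards the query instead of reporting a bug. A small hypothetical demo of that matching, assuming ExpectedErrors exposes a no-argument constructor and an errorIsExpected(String) substring check as in SQLancer's common package (the sample message is made up):

package sqlancer.spark;

import sqlancer.common.query.ExpectedErrors;

// Hypothetical demo, not part of the commit.
public final class SparkErrorsDemo {

    public static void main(String[] args) {
        ExpectedErrors errors = new ExpectedErrors();
        SparkErrors.addExpressionErrors(errors);

        // A failure whose message contains a registered substring is treated as
        // expected, so the query is discarded rather than reported as a bug.
        System.out.println(errors.errorIsExpected(
                "org.apache.spark.sql.AnalysisException: cannot resolve 'c0'")); // prints true
    }
}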

src/sqlancer/spark/SparkGlobalState.java

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+package sqlancer.spark;
+
+import sqlancer.SQLGlobalState;
+
+public class SparkGlobalState extends SQLGlobalState<SparkOptions, SparkSchema> {
+
+    @Override
+    protected SparkSchema readSchema() throws Exception {
+        return SparkSchema.fromConnection(getConnection(), getDatabaseName());
+    }
+}
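
SparkSchema is likewise added by this commit but not displayed on this page. Purely as a hypothetical illustration (the class and method below are invented for this sketch), fromConnection would typically enumerate tables through the JDBC connection; in Spark 3.x, SHOW TABLES returns the columns namespace, tableName, and isTemporary:

package sqlancer.spark;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

// Invented sketch of the table-discovery step a schema reader performs;
// the committed SparkSchema.fromConnection is not shown on this page.
public final class SchemaReadSketch {

    public static List<String> readTableNames(Connection con) throws SQLException {
        List<String> tables = new ArrayList<>();
        try (Statement s = con.createStatement(); ResultSet rs = s.executeQuery("SHOW TABLES")) {
            while (rs.next()) {
                tables.add(rs.getString("tableName")); // second column of SHOW TABLES
            }
        }
        return tables;
    }
}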

src/sqlancer/spark/SparkOptions.java

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+package sqlancer.spark;
+
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.List;
+
+import com.beust.jcommander.Parameter;
+import com.beust.jcommander.Parameters;
+
+import sqlancer.DBMSSpecificOptions;
+import sqlancer.OracleFactory;
+import sqlancer.common.oracle.TLPWhereOracle;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.spark.gen.SparkExpressionGenerator;
+
+@Parameters(separators = "=", commandDescription = "Spark SQL (default port: " + SparkOptions.DEFAULT_PORT
+        + ", default host: " + SparkOptions.DEFAULT_HOST + ")")
+public class SparkOptions implements DBMSSpecificOptions<SparkOptions.SparkOracleFactory> {
+    public static final String DEFAULT_HOST = "localhost";
+    public static final int DEFAULT_PORT = 10000;
+
+    @Parameter(names = "--oracle")
+    public List<SparkOracleFactory> oracle = Arrays.asList(SparkOracleFactory.TLPWhere);
+
+    public enum SparkOracleFactory implements OracleFactory<SparkGlobalState> {
+        TLPWhere {
+            @Override
+            public TestOracle<SparkGlobalState> create(SparkGlobalState globalState) throws SQLException {
+                SparkExpressionGenerator gen = new SparkExpressionGenerator(globalState);
+                ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(SparkErrors.getExpressionErrors())
+                        .build();
+
+                return new TLPWhereOracle<>(globalState, gen, expectedErrors);
+            }
+        };
+    }
+
+    @Override
+    public List<SparkOracleFactory> getTestOracleFactory() {
+        return oracle;
+    }
+}
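
TLPWhereOracle implements ternary logic partitioning: it generates a random predicate p, and since SQL's three-valued logic evaluates p on every row to exactly one of TRUE, FALSE, or NULL, the original query must return the same multiset of rows as the union of its three partitions. A self-contained illustration of the query pair being compared, with a made-up table t0 and predicate (the real oracle generates both):

package sqlancer.spark;

// Illustration only: the query pair that a TLP WHERE check compares.
public final class TlpWhereIllustration {

    public static void main(String[] args) {
        String predicate = "(c0 > 5)"; // randomly generated in practice

        String original = "SELECT * FROM t0";
        // Every row makes the predicate TRUE, FALSE, or NULL, so the three
        // partitions together must reproduce the original result exactly;
        // any difference indicates a DBMS bug.
        String partitioned = "SELECT * FROM t0 WHERE " + predicate
                + " UNION ALL SELECT * FROM t0 WHERE NOT " + predicate
                + " UNION ALL SELECT * FROM t0 WHERE (" + predicate + ") IS NULL";

        System.out.println(original);
        System.out.println(partitioned);
    }
}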

src/sqlancer/spark/SparkProvider.java

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+package sqlancer.spark;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import com.google.auto.service.AutoService;
+
+import sqlancer.AbstractAction;
+import sqlancer.DatabaseProvider;
+import sqlancer.IgnoreMeException;
+import sqlancer.MainOptions;
+import sqlancer.Randomly;
+import sqlancer.SQLConnection;
+import sqlancer.SQLProviderAdapter;
+import sqlancer.StatementExecutor;
+import sqlancer.common.query.SQLQueryAdapter;
+import sqlancer.common.query.SQLQueryProvider;
+import sqlancer.spark.gen.SparkInsertGenerator;
+import sqlancer.spark.gen.SparkTableGenerator;
+
+@AutoService(DatabaseProvider.class)
+public class SparkProvider extends SQLProviderAdapter<SparkGlobalState, SparkOptions> {
+
+    public SparkProvider() {
+        super(SparkGlobalState.class, SparkOptions.class);
+    }
+
+    public enum Action implements AbstractAction<SparkGlobalState> {
+        INSERT(SparkInsertGenerator::getQuery);
+
+        private final SQLQueryProvider<SparkGlobalState> sqlQueryProvider;
+
+        Action(SQLQueryProvider<SparkGlobalState> sqlQueryProvider) {
+            this.sqlQueryProvider = sqlQueryProvider;
+        }
+
+        @Override
+        public SQLQueryAdapter getQuery(SparkGlobalState state) throws Exception {
+            return sqlQueryProvider.getQuery(state);
+        }
+    }
+
+    private static int mapActions(SparkGlobalState globalState, Action a) {
+        Randomly r = globalState.getRandomly();
+        switch (a) {
+        case INSERT:
+            return r.getInteger(0, globalState.getOptions().getMaxNumberInserts());
+        default:
+            throw new AssertionError(a);
+        }
+    }
+
+    @Override
+    public void generateDatabase(SparkGlobalState globalState) throws Exception {
+        for (int i = 0; i < Randomly.fromOptions(1, 2); i++) {
+            boolean success;
+            do {
+                String tableName = globalState.getSchema().getFreeTableName();
+                SQLQueryAdapter qt = SparkTableGenerator.generate(globalState, tableName);
+                success = globalState.executeStatement(qt);
+            } while (!success);
+        }
+
+        if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+            throw new IgnoreMeException();
+        }
+
+        StatementExecutor<SparkGlobalState, Action> se = new StatementExecutor<>(globalState, Action.values(),
+                SparkProvider::mapActions, (q) -> {
+                    if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+                        throw new IgnoreMeException();
+                    }
+                });
+        se.executeStatements();
+    }
+
+    @Override
+    public SQLConnection createDatabase(SparkGlobalState globalState) throws SQLException {
+        String username = globalState.getOptions().getUserName();
+        String password = globalState.getOptions().getPassword();
+        String host = globalState.getOptions().getHost();
+        int port = globalState.getOptions().getPort();
+
+        if (host == null) {
+            host = SparkOptions.DEFAULT_HOST;
+        }
+        if (port == MainOptions.NO_SET_PORT) {
+            port = SparkOptions.DEFAULT_PORT;
+        }
+
+        String databaseName = globalState.getDatabaseName();
+
+        // Spark's Thrift server speaks the HiveServer2 protocol, so we connect with the Hive JDBC driver.
+        String url = String.format("jdbc:hive2://%s:%d/%s", host, port, "default");
+
+        // Connect to the default database to (re)create the fuzzing database.
+        Connection con = DriverManager.getConnection(url, username, password);
+        try (Statement s = con.createStatement()) {
+            s.execute("DROP DATABASE IF EXISTS " + databaseName + " CASCADE");
+        }
+        try (Statement s = con.createStatement()) {
+            s.execute("CREATE DATABASE " + databaseName);
+        }
+        con.close();
+
+        // Connect to the newly created fuzzing database.
+        con = DriverManager.getConnection(String.format("jdbc:hive2://%s:%d/%s", host, port, databaseName), username,
+                password);
+        try (Statement s = con.createStatement()) {
+            // Disabling ANSI mode permits lenient casts (e.g., BOOLEAN to DATE/TIMESTAMP)
+            // that the expression generator frequently produces.
+            s.execute("SET spark.sql.ansi.enabled=false");
+        }
+        return new SQLConnection(con);
+    }
+
+    @Override
+    public String getDBMSName() {
+        return "spark";
+    }
+}
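
Since the provider reaches Spark through the HiveServer2 wire protocol, connectivity can be sanity-checked outside SQLancer with nothing but the hive-jdbc dependency from the pom.xml change above. A minimal standalone check, assuming the Thrift server from the CI job is listening on localhost:10000 with no authentication:

package sqlancer.spark;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;

// Standalone sanity check, not part of the commit: connects the same way
// createDatabase does and asks Spark for its version string.
public final class SparkConnectionCheck {

    public static void main(String[] args) throws Exception {
        String url = "jdbc:hive2://localhost:10000/default";
        try (Connection con = DriverManager.getConnection(url, "", "");
                Statement s = con.createStatement();
                ResultSet rs = s.executeQuery("SELECT version()")) {
            while (rs.next()) {
                System.out.println("Connected to Spark " + rs.getString(1));
            }
        }
    }
}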
