Skip to content

Commit

Permalink
Merge pull request #44 from edward3h/sqlarray_to_container
Browse files Browse the repository at this point in the history
Convert from SQL Array to container result
  • Loading branch information
edward3h authored Sep 17, 2024
2 parents 7e55999 + b178ae7 commit 6c01eea
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,4 @@ public boolean singleResult() {
}
return true;
}

public String fromList() {
if (signature.returnType() instanceof ContainerType containerType) {
var container = containerType.type();
var template = container.fromListTemplate();
if (template.contains("%s")) { // hacky
return template.formatted(containerType.containedType());
} else {
return template;
}
}
return """
l.empty() ? null: l.get(0)
""";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public static SqlTypeMapping get(ColumnMetaData columnMetaData) {
public KiwiType kiwiType() {
if (jdbcType == JDBCType.ARRAY) {
assert componentType != null;
return new SqlArrayType(componentType.kiwiType(), componentType.jdbcType, componentDbType);
return new SqlArrayType(componentType.kiwiType(), componentType, componentDbType);
}
if (CoreTypes.primitiveToBoxed.containsKey(baseType)) {
return new PrimitiveKiwiType(baseType().getSimpleName(), isNullable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public boolean validateReturn(List<ColumnMetaData> columnMetaData, KiwiType retu
return reportError(
"Missing component type for column %s type %s".formatted(cmd.name(), columnType));
}
return componentType.type().isSimple() && validateCompatible(columnType, componentType.type())
return validateCompatible(columnType, componentType.type())
|| reportError("Incompatible component type %s for column %s type %s"
.formatted(componentType, cmd.name(), columnType));
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

import com.karuslabs.utilitary.Logger;
import com.palantir.javapoet.*;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import javax.lang.model.element.Modifier;
import org.ethelred.kiwiproc.processor.*;
import org.ethelred.kiwiproc.processor.types.ContainerType;
import org.ethelred.kiwiproc.processor.types.KiwiType;

public class InstanceGenerator {

Expand All @@ -16,6 +19,7 @@ public class InstanceGenerator {
private final CoreTypes coreTypes;
private final Set<String> parameterNames = new HashSet<>();
private final Map<String, String> patchedNames = new HashMap<>();
private int patchedNameCount = 0;

public InstanceGenerator(Logger logger, KiwiTypeConverter kiwiTypeConverter, CoreTypes coreTypes) {
this.logger = logger;
Expand Down Expand Up @@ -50,6 +54,7 @@ public JavaFile generate(DAOClassInfo classInfo) {
private MethodSpec buildMethod(DAOMethodInfo methodInfo) {
parameterNames.clear();
patchedNames.clear();
patchedNameCount = 0;
var methodSpecBuilder = MethodSpec.overriding(methodInfo.methodElement());
methodSpecBuilder.addStatement("var connection = context.getConnection()");
methodSpecBuilder.beginControlFlow(
Expand Down Expand Up @@ -83,7 +88,8 @@ private CodeBlock queryMethodBody(DAOMethodInfo methodInfo) {
methodInfo.parameterMapping().forEach(parameterInfo -> {
var name = "param" + parameterInfo.index();
var conversion = lookupConversion(parameterInfo::element, parameterInfo.mapper());
buildConversion(builder, conversion, name, parameterInfo.javaAccessor(), true);
buildConversion(
builder, conversion, parameterInfo.mapper().target(), name, parameterInfo.javaAccessor(), true);
var nullableSource = parameterInfo.mapper().source().isNullable();
// if (nullableSource) {
// builder.beginControlFlow("if ($L == null)", name)
Expand All @@ -97,11 +103,10 @@ private CodeBlock queryMethodBody(DAOMethodInfo methodInfo) {
}
parameterNames.add(parameterInfo.javaAccessor());
});
var listVariable = patchName("l");
TypeName componentClass = kiwiTypeConverter.fromKiwiType(methodInfo.resultComponentType());
builder.addStatement("var rs = statement.executeQuery()")
.addStatement(
"List<$T> l = new $T<>()",
kiwiTypeConverter.fromKiwiType(methodInfo.resultComponentType()),
ArrayList.class)
.addStatement("List<$T> $L = new $T<>()", componentClass, listVariable, ArrayList.class)
.beginControlFlow("$L (rs.next())", methodInfo.singleResult() ? "if" : "while");
var singleColumn = methodInfo.singleColumn();
var multipleColumns = methodInfo.multipleColumns();
Expand All @@ -110,16 +115,15 @@ private CodeBlock queryMethodBody(DAOMethodInfo methodInfo) {
var conversion = lookupConversion(methodInfo::methodElement, mapping);
builder.addStatement(
"var rawValue = rs.get$L($S)", singleColumn.sqlTypeMapping().accessorSuffix(), singleColumn.name());
buildConversion(builder, conversion, "value", "rawValue", true);
buildConversion(builder, conversion, mapping.target(), "value", "rawValue", true);
} else if (!multipleColumns.isEmpty()) {
multipleColumns.forEach(daoResultColumn -> {
var conversion = lookupConversion(methodInfo::methodElement, daoResultColumn.asTypeMapping());
String rawName = daoResultColumn.name() + "Raw";
builder.addStatement(
"$T $L = rs.get$L($S)",
ClassName.get(
daoResultColumn.targetType().packageName(),
daoResultColumn.targetType().className()),
kiwiTypeConverter.fromKiwiType(
daoResultColumn.sqlTypeMapping().kiwiType()),
rawName,
daoResultColumn.sqlTypeMapping().accessorSuffix(),
daoResultColumn.name());
Expand All @@ -129,7 +133,8 @@ private CodeBlock queryMethodBody(DAOMethodInfo methodInfo) {
.endControlFlow();
}
var varName = patchName(daoResultColumn.name());
buildConversion(builder, conversion, varName, rawName, true);
buildConversion(
builder, conversion, daoResultColumn.asTypeMapping().target(), varName, rawName, true);
});
var params = multipleColumns.stream()
.map(p -> CodeBlock.of("$L", patchedNames.get(p.name())))
Expand All @@ -141,19 +146,31 @@ private CodeBlock queryMethodBody(DAOMethodInfo methodInfo) {
$L
);
""",
kiwiTypeConverter.fromKiwiType(methodInfo.resultComponentType()),
componentClass,
params);
} else {
throw new IllegalStateException("Expected singleColumn or multipleColumns");
}
builder.addStatement("l.add(value)")
.endControlFlow() // end while
.addStatement("return $L", methodInfo.fromList());
builder.addStatement("$L.add(value)", listVariable).endControlFlow(); // end while
if (methodInfo.signature().returnType() instanceof ContainerType containerType) {
builder.add("return ")
.addNamed(
containerType.type().fromListTemplate(),
Map.of("componentClass", componentClass, "listVariable", listVariable))
.addStatement("");
} else {
builder.addStatement("return $1L.isEmpty() ? null : $1L.get(0)", listVariable);
}
return builder.build();
}

private void buildConversion(
CodeBlock.Builder builder, Conversion conversion, String assignee, String accessor, boolean withVar) {
CodeBlock.Builder builder,
Conversion conversion,
KiwiType targetType,
String assignee,
String accessor,
boolean withVar) {
var insertVar = withVar ? "var " : "";
if (conversion instanceof AssignmentConversion) {
/* e.g.
Expand All @@ -168,20 +185,68 @@ private void buildConversion(
"$L$L = $L", insertVar, assignee, sfc.conversionFormat().formatted(accessor));
} else if (conversion instanceof ToSqlArrayConversion sac) {
/* e.g.
Object[] elementObjects = listParam.toArray();
var param1 = connection.createArrayOf("_int4", elementObjects);
Object[] elementObjects = listParam.stream()
.map(x -> (int) x)
.toArray();
var param1 = connection.createArrayOf("int4", elementObjects);
*/
Conversion elementConversion = sac.elementConversion();
String elementObjects = patchName("elementObjects");
builder.addStatement(
"Object[] $L = $L",
elementObjects,
String.format(sac.ct().type().toObjectArrayTemplate(), accessor));
String lambdaValue = patchName("value");
builder.add("Object[] $L = ", elementObjects)
.addNamed(sac.ct().type().toStreamTemplate(), Map.of("containerVariable", accessor))
.indent()
.add("\n.map($L -> {\n", lambdaValue)
.indent();
buildConversion(builder, elementConversion, sac.sat().containedType(), "tmp", lambdaValue, true);
builder.addStatement("return tmp")
.unindent()
.add("})\n.toArray();\n")
.unindent();
builder.addStatement(
"$L$L = connection.createArrayOf($S, $L)",
insertVar,
assignee,
sac.sat().dbType(),
sac.sat().componentDbType(),
elementObjects);
} else if (conversion instanceof FromSqlArrayConversion sac) {
/* e.g.
ResultSet arrayRS = rawValue.getResultSet();
List<String> arrayList = new ArrayList<>();
while (arrayRs.next()) {
var rawItemValue = arrayRs.getString(2);
var itemValue = rawItemValue;
arrayList.add(itemValue);
}
var value = List.copyOf(arrayList);
*/
var arrayRS = patchName("arrayRS");
var arrayList = patchName("arrayList");
var rawItemValue = patchName("rawItemValue");
var itemValue = patchName("itemValue");
TypeName componentClass = kiwiTypeConverter.fromKiwiType(sac.ct().containedType());
builder.addStatement("$T $L = $L.getResultSet()", ResultSet.class, arrayRS, accessor)
.addStatement("List<$T> $L = new $T<>()", componentClass, arrayList, ArrayList.class)
.beginControlFlow("while ($L.next())", arrayRS)
// Array.getResultSet() returns 2 columns: 1 is the index, 2 is the value
.addStatement(
"var $L = $L.get$L(2)",
rawItemValue,
arrayRS,
sac.sat().componentType().accessorSuffix());
buildConversion(builder, sac.elementConversion(), sac.ct().containedType(), itemValue, rawItemValue, true);
builder.addStatement("$L.add($L)", arrayList, itemValue)
.endControlFlow()
.add("$L$L = ", insertVar, assignee)
.addNamed(
sac.ct().type().fromListTemplate(),
Map.of("componentClass", componentClass, "listVariable", arrayList))
.addStatement("");
} else if (conversion instanceof NullableSourceConversion nsc) {
builder.addStatement("$T $L = null", kiwiTypeConverter.fromKiwiType(targetType), assignee)
.beginControlFlow("if ($L != null)", accessor);
buildConversion(builder, nsc.conversion(), targetType, assignee, accessor, false);
builder.endControlFlow();
} else {
logger.error(null, "Unsupported Conversion %s".formatted(conversion)); // TODO add Element
}
Expand All @@ -191,7 +256,7 @@ private String patchName(String name) {
return patchedNames.computeIfAbsent(name, k -> {
var newName = k;
while (parameterNames.contains(newName)) {
newName = "_" + newName;
newName = k + (++patchedNameCount);
}
return newName;
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package org.ethelred.kiwiproc.processor.types;

import java.sql.JDBCType;
import org.ethelred.kiwiproc.processor.SqlTypeMapping;

public record SqlArrayType(KiwiType containedType, JDBCType componentType, String dbType) implements KiwiType {
public record SqlArrayType(KiwiType containedType, SqlTypeMapping componentType, String componentDbType)
implements KiwiType {
@Override
public String packageName() {
return "";
return "java.sql";
}

@Override
public String className() {
return "ARRAY";
return "Array";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,28 @@ public enum ValidContainerType {
ARRAY(
Array.class,
"""
l.toArray(new %s[l.size()])
l.toArray(new $componentClass:T[$listVariable:L.size()])
""",
"""
java.util.Arrays.copyOf(%s, %<s.length, Object[].class)
java.util.stream.Stream.of($containerVariable:L)
"""),
ITERABLE(
Iterable.class,
"""
List.copyOf($listVariable:L)""",
"""
java.util.stream.StreamSupport.stream($containerVariable:L.spliterator(), false)
"""),
ITERABLE(Iterable.class),
COLLECTION(Collection.class),
LIST(List.class),
SET(Set.class, """
new java.util.LinkedHashSet<>(l)
new java.util.LinkedHashSet<>($listVariable:L)
"""),
OPTIONAL(
Optional.class,
"""
l.isEmpty() ? Optional.empty() : Optional.of(l.get(0))
""",
"""
%s.stream().toArray()""");
$listVariable:L.isEmpty() ? Optional.empty() : Optional.of($listVariable:L.get(0))
""");

private final Class<?> javaType;

Expand All @@ -36,20 +40,20 @@ public String fromListTemplate() {
}

private final String fromListTemplate;
private final String toObjectArrayTemplate;
private final String toStreamTemplate;

ValidContainerType(Class<?> javaType, String fromListTemplate, String toObjectArrayTemplate) {
ValidContainerType(Class<?> javaType, String fromListTemplate, String toStreamTemplate) {
this.javaType = javaType;
this.fromListTemplate = fromListTemplate;
this.toObjectArrayTemplate = toObjectArrayTemplate;
this.toStreamTemplate = toStreamTemplate;
}

ValidContainerType(Class<?> javaType) {
this(javaType, "List.copyOf(l)", "%s.toArray()");
this(javaType, "List.copyOf($listVariable:L)", "$containerVariable:L.stream()");
}

ValidContainerType(Class<?> javaType, String fromListTemplate) {
this(javaType, fromListTemplate, "%s.toArray()");
this(javaType, fromListTemplate, "$containerVariable:L.stream()");
}

public boolean isMultiValued() {
Expand All @@ -65,7 +69,7 @@ public String toString() {
return javaType().getName();
}

public String toObjectArrayTemplate() {
return toObjectArrayTemplate;
public String toStreamTemplate() {
return toStreamTemplate;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ public static Stream<Arguments> testConversions() {
arguments(ofClass(String.class), ofClass(Integer.class, true), true, true, "Integer.valueOf(value)"),
arguments(
new ContainerType(ValidContainerType.LIST, ofClass(Integer.class, true)),
new SqlArrayType(ofClass(int.class), JDBCType.INTEGER, "ignored"),
new SqlArrayType(
ofClass(int.class), new SqlTypeMapping(JDBCType.INTEGER, int.class, "Int"), "ignored"),
true,
false,
"fail"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,20 @@ ValidContainerType.LIST, recordType("TestRecord", "test1", ofClass(int.class))),
false,
"Missing component type for column test2 type String/non-null",
col(false, JDBCType.INTEGER),
col(false, JDBCType.VARCHAR)));
col(false, JDBCType.VARCHAR)),
testCase(
new ContainerType(
ValidContainerType.LIST,
recordType(
"TestRecord",
"test1",
ofClass(String.class),
"test2",
new ContainerType(ValidContainerType.LIST, ofClass(String.class)))),
true,
null,
col(false, JDBCType.VARCHAR),
col(false, JDBCType.ARRAY, new ArrayComponent(JDBCType.VARCHAR, "ignored"))));
}

private static KiwiType recordType(String className, String componentName, KiwiType componentType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package org.ethelred.kiwiproc.test;

import java.util.List;

public record OwnerPets(String owner_first_name, List<String> pet_names) {}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ record PetTypeWithCount(

@SqlQuery(
"""
SELECT t.id, t.name, count(*) FROM types t JOIN pets p ON t.id = p.type_id GROUP BY 1,2""")
SELECT t.id, t.name, count(*)
FROM types t JOIN pets p ON t.id = p.type_id GROUP BY 1,2""")
List<PetTypeWithCount> getPetTypesWithCountList();

default Map<PetType, Long> getPetTypesWithCount() {
Expand All @@ -41,4 +42,11 @@ default Map<PetType, Long> getPetTypesWithCount() {
@SqlQuery("""
SELECT id, first_name, last_name FROM owners WHERE id = ANY(:ids)""")
List<Owner> findOwnersByIds(List<Integer> ids);

@SqlQuery(
"""
SELECT o.first_name AS owner_first_name, array_agg(p.name) as pet_names
FROM owners o JOIN pets p ON o.id = p.owner_id
GROUP BY 1""")
List<OwnerPets> findOwnersAndPets();
}
Loading

0 comments on commit 6c01eea

Please sign in to comment.