Skip to content

Commit

Permalink
Implemented spring-projectsGH-921 keyspace customization feature
Browse files Browse the repository at this point in the history
  • Loading branch information
mipo256 committed Jun 22, 2023
1 parent e950084 commit 23aa3c9
Show file tree
Hide file tree
Showing 15 changed files with 309 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,22 @@ public <T> List<T> select(Query query, Class<T> entityClass) throws DataAccessEx
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

return doSelect(query, entityClass, getTableName(entityClass), entityClass);
return doSelect(query, entityClass, getTableName(entityClass), entityClass, getEntityOperations().getCustomKeyspaceName(entityClass));
}

<T> List<T> doSelect(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType) {
return this.doSelect(query, entityClass, tableName, returnType, null);
}

<T> List<T> doSelect(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType, @Nullable CqlIdentifier keyspace) {
CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
EntityProjection<T, ?> projection = entityOperations.introspectProjection(returnType, entityClass);
Columns columns = getStatementFactory().computeColumnsForProjection(projection, query.getColumns(), entity,
returnType);
returnType);

Query queryToUse = query.columns(columns);

StatementBuilder<Select> select = getStatementFactory().select(queryToUse, entity, tableName);
StatementBuilder<Select> select = getStatementFactory().select(queryToUse, entity, tableName, keyspace);
Function<Row, T> mapper = getMapper(projection, tableName);

return doQuery(select.build(), (row, rowNum) -> mapper.apply(row));
Expand Down Expand Up @@ -463,7 +466,7 @@ public <T> Stream<T> stream(Query query, Class<T> entityClass) throws DataAccess
<T> Stream<T> doStream(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType) {

StatementBuilder<Select> select = getStatementFactory().select(query, getRequiredPersistentEntity(entityClass),
tableName);
tableName, getEntityOperations().getCustomKeyspaceName(entityClass));
EntityProjection<T, ?> projection = entityOperations.introspectProjection(returnType, entityClass);

Function<Row, T> mapper = getMapper(projection, tableName);
Expand Down Expand Up @@ -500,16 +503,20 @@ public boolean delete(Query query, Class<?> entityClass) throws DataAccessExcept
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

WriteResult result = doDelete(query, entityClass, getTableName(entityClass));
WriteResult result = doDelete(query, entityClass, getTableName(entityClass), getEntityOperations().getCustomKeyspaceName(entityClass));

return result != null && result.wasApplied();
}

@Nullable
WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName) {
return this.doDelete(query, entityClass, tableName, null);
}

StatementBuilder<Delete> delete = getStatementFactory().delete(query, getRequiredPersistentEntity(entityClass),
tableName);
@Nullable
WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

StatementBuilder<Delete> delete = getStatementFactory().delete(query, getRequiredPersistentEntity(entityClass), tableName, keyspace);
SimpleStatement statement = delete.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand All @@ -521,6 +528,7 @@ WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName)
return writeResult;
}


// -------------------------------------------------------------------------
// Methods dealing with entities
// -------------------------------------------------------------------------
Expand All @@ -539,13 +547,17 @@ public long count(Query query, Class<?> entityClass) throws DataAccessException
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

return doCount(query, entityClass, getTableName(entityClass));
return doCount(query, entityClass, getTableName(entityClass), getEntityOperations().getCustomKeyspaceName(entityClass));
}

long doCount(Query query, Class<?> entityClass, CqlIdentifier tableName) {
return this.doCount(query, entityClass, tableName, null);
}

long doCount(Query query, Class<?> entityClass, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

StatementBuilder<Select> countStatement = getStatementFactory().count(query,
getRequiredPersistentEntity(entityClass), tableName);
getRequiredPersistentEntity(entityClass), tableName, keyspace);

return doQueryForObject(countStatement.build(), Long.class);
}
Expand All @@ -557,7 +569,7 @@ public boolean exists(Object id, Class<?> entityClass) {
Assert.notNull(entityClass, "Entity type must not be null");

CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, entity.getTableName());
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, entity.getTableName(), entity.getCustomKeyspace());

return doQueryForResultSet(select.build()).one() != null;
}
Expand Down Expand Up @@ -587,7 +599,7 @@ public <T> T selectOneById(Object id, Class<T> entityClass) {

CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
CqlIdentifier tableName = entity.getTableName();
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, tableName);
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, tableName, entity.getCustomKeyspace());
Function<Row, T> mapper = getMapper(EntityProjection.nonProjecting(entityClass), tableName);
List<T> result = doQuery(select.build(), (row, rowNum) -> mapper.apply(row));

Expand All @@ -605,18 +617,21 @@ public <T> EntityWriteResult<T> insert(T entity, InsertOptions options) {
Assert.notNull(entity, "Entity must not be null");
Assert.notNull(options, "InsertOptions must not be null");

return doInsert(entity, options, getTableName(entity.getClass()));
return doInsert(entity, options, getTableName(entity.getClass()), getEntityOperations().getCustomKeyspaceName(entity.getClass()));
}

<T> EntityWriteResult<T> doInsert(T entity, WriteOptions options, CqlIdentifier tableName) {
return this.doInsert(entity, options, tableName, null);
}
<T> EntityWriteResult<T> doInsert(T entity, WriteOptions options, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

AdaptibleEntity<T> source = getEntityOperations().forEntity(maybeCallBeforeConvert(entity, tableName),
getConverter().getConversionService());

T entityToUse = source.isVersionedEntity() ? source.initializeVersionProperty() : source.getBean();

StatementBuilder<RegularInsert> builder = getStatementFactory().insert(entityToUse, options,
source.getPersistentEntity(), tableName);
source.getPersistentEntity(), tableName, keyspace);

if (source.isVersionedEntity()) {

Expand Down Expand Up @@ -709,7 +724,7 @@ public WriteResult delete(Object entity, QueryOptions options) {
CassandraPersistentEntity<?> persistentEntity = getRequiredPersistentEntity(entity.getClass());
CqlIdentifier tableName = persistentEntity.getTableName();

StatementBuilder<Delete> builder = getStatementFactory().delete(entity, options, getConverter(), tableName);
StatementBuilder<Delete> builder = getStatementFactory().delete(entity, options, getConverter(), tableName, persistentEntity.getCustomKeyspace());

return source.isVersionedEntity()
? doDeleteVersioned(source.appendVersionCondition(builder).build(), entity, source, tableName)
Expand Down Expand Up @@ -743,7 +758,7 @@ public boolean deleteById(Object id, Class<?> entityClass) {
CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
CqlIdentifier tableName = entity.getTableName();

StatementBuilder<Delete> delete = getStatementFactory().deleteById(id, entity, tableName);
StatementBuilder<Delete> delete = getStatementFactory().deleteById(id, entity, tableName, getEntityOperations().getCustomKeyspaceName(entityClass));
SimpleStatement statement = delete.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand All @@ -761,7 +776,7 @@ public void truncate(Class<?> entityClass) {
Assert.notNull(entityClass, "Entity type must not be null");

CqlIdentifier tableName = getTableName(entityClass);
Truncate truncate = QueryBuilder.truncate(tableName);
Truncate truncate = QueryBuilder.truncate(getEntityOperations().getCustomKeyspaceName(entityClass), tableName);
SimpleStatement statement = truncate.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand Down Expand Up @@ -924,7 +939,6 @@ public String getCql() {
return getCqlOperations().execute(new GetConfiguredPageSize());
}

@SuppressWarnings("unchecked")
private <T> Function<Row, T> getMapper(EntityProjection<T, ?> projection, CqlIdentifier tableName) {

Class<T> targetType = projection.getMappedType().getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ CqlIdentifier getTableName(Class<?> entityClass) {
return getRequiredPersistentEntity(entityClass).getTableName();
}

/**
* Returns custom keyspace defined (if any) where the table for entity {@code entityClass} should be persisted.
* If the keyspace is not overridden in {@link org.springframework.data.cassandra.core.mapping.Table} annotation,
* then {@code null} is returned, signaling that default keyspace of {@link com.datastax.oss.driver.api.core.CqlSession}
* should be used
*
* @param entityClass entity class, must not be {@literal null}.
* @return custom keyspace defined (if any)
*/
@Nullable
CqlIdentifier getCustomKeyspaceName(Class<?> entityClass) {
return getRequiredPersistentEntity(entityClass).getCustomKeyspace();
}

/**
* Introspect the given {@link Class result type} in the context of the {@link Class entity type} whether the returned
* type is a projection and what property paths are participating in the projection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ static class ExecutableInsertSupport<T> implements ExecutableInsert<T> {
@Nullable private final CqlIdentifier tableName;

public ExecutableInsertSupport(CassandraTemplate template, Class<T> domainType, InsertOptions insertOptions,
CqlIdentifier tableName) {
@Nullable CqlIdentifier tableName) {
this.template = template;
this.domainType = domainType;
this.insertOptions = insertOptions;
Expand Down
Loading

0 comments on commit 23aa3c9

Please sign in to comment.