Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing keyspace qualification for particular entities #1400

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.util.CollectionUtils;

import com.datastax.oss.driver.api.core.metadata.Node;

Expand All @@ -32,14 +37,26 @@ public class CassandraConnectionFailureException extends DataAccessResourceFailu

private static final long serialVersionUID = 6299912054261646552L;

private final Map<Node, Throwable> messagesByHost = new HashMap<>();
private final Map<Node, List<Throwable>> messagesByHost = new HashMap<>();

public CassandraConnectionFailureException(Map<Node, Throwable> map, String msg, Throwable cause) {
super(msg, cause);
map.forEach((node, throwable) -> messagesByHost.put(node, Collections.singletonList(throwable)));
}

public CassandraConnectionFailureException(String msg, Map<Node, List<Throwable>> map, Throwable cause) {
super(msg, cause);
this.messagesByHost.putAll(map);
}

@Deprecated(forRemoval = true)
public Map<Node, Throwable> getMessagesByHost() {
HashMap<Node, Throwable> singleMessageByHost = new HashMap<>();
this.messagesByHost.forEach((node, throwables) -> singleMessageByHost.put(node, CollectionUtils.firstElement(throwables)));
return singleMessageByHost;
}

public Map<Node, List<Throwable>> getAllMessagesByHost() {
return Collections.unmodifiableMap(messagesByHost);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,22 @@ public <T> List<T> select(Query query, Class<T> entityClass) throws DataAccessEx
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

return doSelect(query, entityClass, getTableName(entityClass), entityClass);
return doSelect(query, entityClass, getTableName(entityClass), entityClass, getEntityOperations().getCustomKeyspaceName(entityClass));
}

<T> List<T> doSelect(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType) {
return this.doSelect(query, entityClass, tableName, returnType, null);
}

<T> List<T> doSelect(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType, @Nullable CqlIdentifier keyspace) {
CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
EntityProjection<T, ?> projection = entityOperations.introspectProjection(returnType, entityClass);
Columns columns = getStatementFactory().computeColumnsForProjection(projection, query.getColumns(), entity,
returnType);
returnType);

Query queryToUse = query.columns(columns);

StatementBuilder<Select> select = getStatementFactory().select(queryToUse, entity, tableName);
StatementBuilder<Select> select = getStatementFactory().select(queryToUse, entity, tableName, keyspace);
Function<Row, T> mapper = getMapper(projection, tableName);

return doQuery(select.build(), (row, rowNum) -> mapper.apply(row));
Expand Down Expand Up @@ -463,7 +466,7 @@ public <T> Stream<T> stream(Query query, Class<T> entityClass) throws DataAccess
<T> Stream<T> doStream(Query query, Class<?> entityClass, CqlIdentifier tableName, Class<T> returnType) {

StatementBuilder<Select> select = getStatementFactory().select(query, getRequiredPersistentEntity(entityClass),
tableName);
tableName, getEntityOperations().getCustomKeyspaceName(entityClass));
EntityProjection<T, ?> projection = entityOperations.introspectProjection(returnType, entityClass);

Function<Row, T> mapper = getMapper(projection, tableName);
Expand Down Expand Up @@ -500,16 +503,20 @@ public boolean delete(Query query, Class<?> entityClass) throws DataAccessExcept
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

WriteResult result = doDelete(query, entityClass, getTableName(entityClass));
WriteResult result = doDelete(query, entityClass, getTableName(entityClass), getEntityOperations().getCustomKeyspaceName(entityClass));

return result != null && result.wasApplied();
}

@Nullable
WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName) {
return this.doDelete(query, entityClass, tableName, null);
}

@Nullable
WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

StatementBuilder<Delete> delete = getStatementFactory().delete(query, getRequiredPersistentEntity(entityClass),
tableName);
StatementBuilder<Delete> delete = getStatementFactory().delete(query, getRequiredPersistentEntity(entityClass), tableName, keyspace);
SimpleStatement statement = delete.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand All @@ -521,6 +528,7 @@ WriteResult doDelete(Query query, Class<?> entityClass, CqlIdentifier tableName)
return writeResult;
}


// -------------------------------------------------------------------------
// Methods dealing with entities
// -------------------------------------------------------------------------
Expand All @@ -539,13 +547,17 @@ public long count(Query query, Class<?> entityClass) throws DataAccessException
Assert.notNull(query, "Query must not be null");
Assert.notNull(entityClass, "Entity type must not be null");

return doCount(query, entityClass, getTableName(entityClass));
return doCount(query, entityClass, getTableName(entityClass), getEntityOperations().getCustomKeyspaceName(entityClass));
}

long doCount(Query query, Class<?> entityClass, CqlIdentifier tableName) {
return this.doCount(query, entityClass, tableName, null);
}

long doCount(Query query, Class<?> entityClass, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

StatementBuilder<Select> countStatement = getStatementFactory().count(query,
getRequiredPersistentEntity(entityClass), tableName);
getRequiredPersistentEntity(entityClass), tableName, keyspace);

return doQueryForObject(countStatement.build(), Long.class);
}
Expand All @@ -557,7 +569,7 @@ public boolean exists(Object id, Class<?> entityClass) {
Assert.notNull(entityClass, "Entity type must not be null");

CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, entity.getTableName());
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, entity.getTableName(), entity.getCustomKeyspace());

return doQueryForResultSet(select.build()).one() != null;
}
Expand All @@ -574,7 +586,7 @@ public boolean exists(Query query, Class<?> entityClass) throws DataAccessExcept
boolean doExists(Query query, Class<?> entityClass, CqlIdentifier tableName) {

StatementBuilder<Select> select = getStatementFactory().select(query.limit(1),
getRequiredPersistentEntity(entityClass), tableName);
getRequiredPersistentEntity(entityClass), tableName, getEntityOperations().getCustomKeyspaceName(entityClass));

return doQueryForResultSet(select.build()).one() != null;
}
Expand All @@ -587,7 +599,7 @@ public <T> T selectOneById(Object id, Class<T> entityClass) {

CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
CqlIdentifier tableName = entity.getTableName();
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, tableName);
StatementBuilder<Select> select = getStatementFactory().selectOneById(id, entity, tableName, entity.getCustomKeyspace());
Function<Row, T> mapper = getMapper(EntityProjection.nonProjecting(entityClass), tableName);
List<T> result = doQuery(select.build(), (row, rowNum) -> mapper.apply(row));

Expand All @@ -605,18 +617,21 @@ public <T> EntityWriteResult<T> insert(T entity, InsertOptions options) {
Assert.notNull(entity, "Entity must not be null");
Assert.notNull(options, "InsertOptions must not be null");

return doInsert(entity, options, getTableName(entity.getClass()));
return doInsert(entity, options, getTableName(entity.getClass()), getEntityOperations().getCustomKeyspaceName(entity.getClass()));
}

<T> EntityWriteResult<T> doInsert(T entity, WriteOptions options, CqlIdentifier tableName) {
return this.doInsert(entity, options, tableName, null);
}
<T> EntityWriteResult<T> doInsert(T entity, WriteOptions options, CqlIdentifier tableName, @Nullable CqlIdentifier keyspace) {

AdaptibleEntity<T> source = getEntityOperations().forEntity(maybeCallBeforeConvert(entity, tableName),
getConverter().getConversionService());

T entityToUse = source.isVersionedEntity() ? source.initializeVersionProperty() : source.getBean();

StatementBuilder<RegularInsert> builder = getStatementFactory().insert(entityToUse, options,
source.getPersistentEntity(), tableName);
source.getPersistentEntity(), tableName, keyspace);

if (source.isVersionedEntity()) {

Expand Down Expand Up @@ -709,7 +724,7 @@ public WriteResult delete(Object entity, QueryOptions options) {
CassandraPersistentEntity<?> persistentEntity = getRequiredPersistentEntity(entity.getClass());
CqlIdentifier tableName = persistentEntity.getTableName();

StatementBuilder<Delete> builder = getStatementFactory().delete(entity, options, getConverter(), tableName);
StatementBuilder<Delete> builder = getStatementFactory().delete(entity, options, getConverter(), tableName, persistentEntity.getCustomKeyspace());

return source.isVersionedEntity()
? doDeleteVersioned(source.appendVersionCondition(builder).build(), entity, source, tableName)
Expand Down Expand Up @@ -743,7 +758,7 @@ public boolean deleteById(Object id, Class<?> entityClass) {
CassandraPersistentEntity<?> entity = getRequiredPersistentEntity(entityClass);
CqlIdentifier tableName = entity.getTableName();

StatementBuilder<Delete> delete = getStatementFactory().deleteById(id, entity, tableName);
StatementBuilder<Delete> delete = getStatementFactory().deleteById(id, entity, tableName, getEntityOperations().getCustomKeyspaceName(entityClass));
SimpleStatement statement = delete.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand All @@ -761,7 +776,7 @@ public void truncate(Class<?> entityClass) {
Assert.notNull(entityClass, "Entity type must not be null");

CqlIdentifier tableName = getTableName(entityClass);
Truncate truncate = QueryBuilder.truncate(tableName);
Truncate truncate = QueryBuilder.truncate(getEntityOperations().getCustomKeyspaceName(entityClass), tableName);
SimpleStatement statement = truncate.build();

maybeEmitEvent(() -> new BeforeDeleteEvent<>(statement, entityClass, tableName));
Expand Down Expand Up @@ -900,12 +915,10 @@ private int getEffectivePageSize(Statement<?> statement) {
return statement.getPageSize();
}

if (getCqlOperations() instanceof CassandraAccessor) {

CassandraAccessor accessor = (CassandraAccessor) getCqlOperations();
if (getCqlOperations() instanceof CassandraAccessor cassandraAccessor) {

if (accessor.getFetchSize() != -1) {
return accessor.getFetchSize();
if (cassandraAccessor.getPageSize() != -1) {
return cassandraAccessor.getPageSize();
}
}

Expand All @@ -924,7 +937,6 @@ public String getCql() {
return getCqlOperations().execute(new GetConfiguredPageSize());
}

@SuppressWarnings("unchecked")
private <T> Function<Row, T> getMapper(EntityProjection<T, ?> projection, CqlIdentifier tableName) {

Class<T> targetType = projection.getMappedType().getType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ CqlIdentifier getTableName(Class<?> entityClass) {
return getRequiredPersistentEntity(entityClass).getTableName();
}

/**
* Returns custom keyspace defined (if any) where the table for entity {@code entityClass} should be persisted.
* If the keyspace is not overridden in {@link org.springframework.data.cassandra.core.mapping.Table} annotation,
* then {@code null} is returned, signaling that default keyspace of {@link com.datastax.oss.driver.api.core.CqlSession}
* should be used
*
* @param entityClass entity class, must not be {@literal null}.
* @return custom keyspace defined (if any)
*/
@Nullable
CqlIdentifier getCustomKeyspaceName(Class<?> entityClass) {
return getRequiredPersistentEntity(entityClass).getCustomKeyspace();
}

/**
* Introspect the given {@link Class result type} in the context of the {@link Class entity type} whether the returned
* type is a projection and what property paths are participating in the projection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ static class ExecutableInsertSupport<T> implements ExecutableInsert<T> {
@Nullable private final CqlIdentifier tableName;

public ExecutableInsertSupport(CassandraTemplate template, Class<T> domainType, InsertOptions insertOptions,
CqlIdentifier tableName) {
@Nullable CqlIdentifier tableName) {
this.template = template;
this.domainType = domainType;
this.insertOptions = insertOptions;
Expand Down
Loading