Skip to content

Commit

Permalink
fix: stop propagating request to backend if not valid
Browse files Browse the repository at this point in the history
  • Loading branch information
ytvnr authored and phiz71 committed Mar 28, 2022
1 parent b5752cc commit 877f812
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

<json-schema-validator.version>2.2.14</json-schema-validator.version>
<swagger-parser.version>2.0.22</swagger-parser.version>
<mockito.version>3.5.13</mockito.version>
<mockito.version>4.4.0</mockito.version>

<json-schema-generator-maven-plugin.version>1.3.0</json-schema-generator-maven-plugin.version>
<json-schema-generator-maven-plugin.outputDirectory>${project.build.directory}/schemas</json-schema-generator-maven-plugin.outputDirectory>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import io.gravitee.gateway.api.ExecutionContext;
import io.gravitee.gateway.api.Request;
import io.gravitee.gateway.api.Response;
import io.gravitee.gateway.api.http.stream.TransformableRequestStreamBuilder;
import io.gravitee.gateway.api.http.stream.TransformableResponseStreamBuilder;
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.stream.BufferedReadWriteStream;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.SimpleReadWriteStream;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import io.gravitee.policy.api.annotations.OnRequestContent;
Expand All @@ -53,7 +54,7 @@ public class JsonValidationPolicy {
/**
* The associated configuration to this JsonMetadata Policy
*/
private JsonValidationPolicyConfiguration configuration;
private final JsonValidationPolicyConfiguration configuration;

private static final JsonValidator validator = JsonSchemaFactory.byDefault().getValidator();

Expand All @@ -75,10 +76,18 @@ public ReadWriteStream onRequestContent(
) {
if (configuration.getScope() == null || configuration.getScope() == PolicyScope.REQUEST_CONTENT) {
logger.debug("Execute json schema validation policy on request content{}", request.id());
return TransformableRequestStreamBuilder
.on(request)
.chain(policyChain)
.transform(buffer -> {

return new BufferedReadWriteStream() {
final Buffer buffer = Buffer.buffer();

@Override
public SimpleReadWriteStream<Buffer> write(Buffer content) {
buffer.appendBuffer(content);
return this;
}

@Override
public void end() {
try {
JsonNode schema = JsonLoader.fromString(configuration.getSchema());
JsonNode content = JsonLoader.fromString(buffer.toString());
Expand All @@ -87,14 +96,18 @@ public ReadWriteStream onRequestContent(
if (!report.isSuccess()) {
request.metrics().setMessage(report.toString());
sendErrorResponse(JSON_INVALID_PAYLOAD_KEY, executionContext, policyChain, HttpStatusCode.BAD_REQUEST_400);
} else {
if (buffer.length() > 0) {
super.write(buffer);
}
super.end();
}
} catch (Exception ex) {
request.metrics().setMessage(ex.getMessage());
sendErrorResponse(JSON_INVALID_FORMAT_KEY, executionContext, policyChain, HttpStatusCode.BAD_REQUEST_400);
}
return buffer;
})
.build();
}
};
}
return null;
}
Expand All @@ -107,10 +120,18 @@ public ReadWriteStream onResponseContent(
PolicyChain policyChain
) {
if (configuration.getScope() == PolicyScope.RESPONSE_CONTENT) {
return TransformableResponseStreamBuilder
.on(response)
.chain(policyChain)
.transform(buffer -> {
logger.debug("Execute json schema validation policy on request content{}", request.id());
return new BufferedReadWriteStream() {
final Buffer buffer = Buffer.buffer();

@Override
public SimpleReadWriteStream<Buffer> write(Buffer content) {
buffer.appendBuffer(content);
return this;
}

@Override
public void end() {
try {
JsonNode schema = JsonLoader.fromString(configuration.getSchema());
JsonNode content = JsonLoader.fromString(buffer.toString());
Expand All @@ -125,7 +146,12 @@ public ReadWriteStream onResponseContent(
policyChain,
HttpStatusCode.INTERNAL_SERVER_ERROR_500
);
} else {
// In Straight Respond Mode, send the response to the user without replacement.
writeBufferAndEnd();
}
} else {
writeBufferAndEnd();
}
} catch (Exception ex) {
request.metrics().setMessage(ex.toString());
Expand All @@ -138,9 +164,15 @@ public ReadWriteStream onResponseContent(
);
}
}
return buffer;
})
.build();
}

private void writeBufferAndEnd() {
if (buffer.length() > 0) {
super.write(buffer);
}
super.end();
}
};
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
import io.gravitee.gateway.api.Response;
import io.gravitee.gateway.api.buffer.Buffer;
import io.gravitee.gateway.api.buffer.BufferFactory;
import io.gravitee.gateway.api.stream.BufferedReadWriteStream;
import io.gravitee.gateway.api.stream.ReadWriteStream;
import io.gravitee.gateway.api.stream.SimpleReadWriteStream;
import io.gravitee.policy.api.PolicyChain;
import io.gravitee.policy.api.PolicyResult;
import io.gravitee.policy.jsonvalidation.configuration.JsonValidationPolicyConfiguration;
import io.gravitee.policy.jsonvalidation.configuration.PolicyScope;
import io.gravitee.reporter.api.http.Metrics;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -83,8 +86,6 @@ public void beforeAll() {
metrics = Metrics.on(System.currentTimeMillis()).build();
HttpHeaders headers = spy(new HttpHeaders());

when(mockRequest.headers()).thenReturn(headers);
when(mockResponse.headers()).thenReturn(headers);
when(configuration.getErrorMessage()).thenReturn("{\"msg\":\"error\"}");
when(configuration.getSchema()).thenReturn(jsonschema);
when(mockRequest.metrics()).thenReturn(metrics);
Expand All @@ -100,8 +101,13 @@ public void shouldAcceptValidPayload() {
JsonValidationPolicy policy = new JsonValidationPolicy(configuration);
Buffer buffer = factory.buffer("{\"name\":\"foo\"}");
ReadWriteStream readWriteStream = policy.onRequestContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isTrue();
})
.doesNotThrowAnyException();
}
Expand All @@ -112,9 +118,14 @@ public void shouldValidateRejectInvalidPayload() {

Buffer buffer = factory.buffer("{\"name\":1}");
ReadWriteStream readWriteStream = policy.onRequestContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_PAYLOAD_KEY);
}

Expand All @@ -124,9 +135,14 @@ public void shouldValidateUncheckedRejectInvalidPayload() {

Buffer buffer = factory.buffer("{\"name\":1}");
ReadWriteStream readWriteStream = policy.onRequestContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_PAYLOAD_KEY);
}

Expand All @@ -136,9 +152,14 @@ public void shouldMalformedPayloadBeRejected() {

Buffer buffer = factory.buffer("{\"name\":");
ReadWriteStream readWriteStream = policy.onRequestContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_FORMAT_KEY);
}

Expand All @@ -149,9 +170,14 @@ public void shouldMalformedJsonSchemaBeRejected() {

Buffer buffer = factory.buffer("{\"name\":\"foo\"}");
ReadWriteStream readWriteStream = policy.onRequestContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_FORMAT_KEY);
}

Expand All @@ -167,8 +193,12 @@ public void shouldAcceptValidResponsePayload() {
mockExecutionContext,
mockPolicychain
);
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isTrue();
})
.doesNotThrowAnyException();
}
Expand All @@ -179,9 +209,13 @@ public void shouldValidateResponseInvalidPayload() {

Buffer buffer = factory.buffer("{\"name\":1}");
ReadWriteStream readWriteStream = policy.onResponseContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_RESPONSE_PAYLOAD_KEY, HttpStatusCode.INTERNAL_SERVER_ERROR_500);
}

Expand All @@ -192,9 +226,14 @@ public void shouldValidateResponseInvalidSchema() {

Buffer buffer = factory.buffer("{\"name\":\"foo\"}");
ReadWriteStream readWriteStream = policy.onResponseContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions(JsonValidationPolicy.JSON_INVALID_RESPONSE_FORMAT_KEY, HttpStatusCode.INTERNAL_SERVER_ERROR_500);
}

Expand All @@ -205,9 +244,14 @@ public void shouldValidateResponseInvalidPayloadStraightRespondMode() {

Buffer buffer = factory.buffer("{\"name\":1}");
ReadWriteStream readWriteStream = policy.onResponseContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isTrue();

policyAssertions();
}

Expand All @@ -219,9 +263,14 @@ public void shouldValidateResponseInvalidSchemaStraightRespondMode() {

Buffer buffer = factory.buffer("{\"name\":\"foo\"}");
ReadWriteStream readWriteStream = policy.onResponseContent(mockRequest, mockResponse, mockExecutionContext, mockPolicychain);

final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = spyEndHandler(readWriteStream);

readWriteStream.write(buffer);
readWriteStream.end();

assertThat(hasCalledEndOnReadWriteStreamParentClass).isFalse();

policyAssertions();
}

Expand All @@ -244,4 +293,19 @@ private void policyAssertions(String key, int statusCode) {
assertThat(value.statusCode()).isEqualTo(statusCode);
assertThat(value.key()).isEqualTo(key);
}

/**
* Replace the endHandler of the resulting ReadWriteStream of the policy execution.
* This endHandler will set an {@link AtomicBoolean} to {@code true} if its called.
* It will allow us to verify if super.end() has been called on {@link BufferedReadWriteStream#end()}
* @param readWriteStream: the {@link ReadWriteStream} to modify
* @return an AtomicBoolean set to {@code true} if {@link SimpleReadWriteStream#end()}, else {@code false}
*/
private AtomicBoolean spyEndHandler(ReadWriteStream readWriteStream) {
final AtomicBoolean hasCalledEndOnReadWriteStreamParentClass = new AtomicBoolean(false);
readWriteStream.endHandler(__ -> {
hasCalledEndOnReadWriteStreamParentClass.set(true);
});
return hasCalledEndOnReadWriteStreamParentClass;
}
}

0 comments on commit 877f812

Please sign in to comment.