Skip to content

Commit

Permalink
chore: add explicit validation for incoming JSONs
Browse files Browse the repository at this point in the history
This commit adds validation for incoming JSONs from the Chat Backend. If any required fields are missing or have the wrong type, a bad request exception will be thrown on the SDK side.
  • Loading branch information
ttypic committed Sep 4, 2024
1 parent 727b8d2 commit 5c3bcaf
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 40 deletions.
2 changes: 2 additions & 0 deletions chat-android/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ dependencies {
implementation(libs.gson)

testImplementation(libs.junit)
testImplementation(libs.mockk)
testImplementation(libs.coroutine.test)
androidTestImplementation(libs.androidx.test.core)
androidTestImplementation(libs.androidx.test.runner)
androidTestImplementation(libs.androidx.junit)
Expand Down
117 changes: 95 additions & 22 deletions chat-android/src/main/java/com/ably/chat/ChatApi.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.ably.chat

import com.google.gson.JsonElement
import com.google.gson.JsonObject
import com.google.gson.JsonPrimitive
import io.ably.lib.http.HttpCore
import io.ably.lib.http.HttpUtils
import io.ably.lib.types.AblyException
Expand All @@ -17,27 +18,37 @@ private const val PROTOCOL_VERSION_PARAM_NAME = "v"
private val apiProtocolParam = Param(PROTOCOL_VERSION_PARAM_NAME, API_PROTOCOL_VERSION.toString())

// TODO make this class internal
class ChatApi(private val realtimeClient: RealtimeClient) {
class ChatApi(private val realtimeClient: RealtimeClient, private val clientId: String) {

/**
* Get messages from the Chat Backend
*
* @return paginated result with messages
*/
suspend fun getMessages(roomId: String, params: QueryOptions): PaginatedResult<Message> {
return makeAuthorizedPaginatedRequest(
url = "/chat/v1/rooms/$roomId/messages",
method = "GET",
params = params.toParams(),
) {
Message(
timeserial = it.asJsonObject.get("timeserial").asString,
clientId = it.asJsonObject.get("clientId").asString,
roomId = it.asJsonObject.get("roomId").asString,
text = it.asJsonObject.get("text").asString,
createdAt = it.asJsonObject.get("createdAt").asLong,
metadata = it.asJsonObject.get("metadata")?.asJsonObject?.toMap() ?: mapOf(),
headers = it.asJsonObject.get("headers")?.asJsonObject?.toMap() ?: mapOf(),
timeserial = it.requireString("timeserial"),
clientId = it.requireString("clientId"),
roomId = it.requireString("roomId"),
text = it.requireString("text"),
createdAt = it.requireLong("createdAt"),
metadata = it.asJsonObject.get("metadata")?.toMap() ?: mapOf(),
headers = it.asJsonObject.get("headers")?.toMap() ?: mapOf(),
)
}
}

suspend fun sendMessage(roomId: String, params: SendMessageParams): CreateMessageResponse {
/**
* Send message to the Chat Backend
*
* @return sent message instance
*/
suspend fun sendMessage(roomId: String, params: SendMessageParams): Message {
val body = JsonObject().apply {
addProperty("text", params.text)
params.headers?.let {
Expand All @@ -53,18 +64,26 @@ class ChatApi(private val realtimeClient: RealtimeClient) {
"POST",
body,
)?.let {
CreateMessageResponse(
timeserial = it.asJsonObject.get("timeserial").asString,
createdAt = it.asJsonObject.get("createdAt").asLong,
Message(
timeserial = it.requireString("timeserial"),
clientId = clientId,
roomId = roomId,
text = params.text,
createdAt = it.requireLong("createdAt"),
metadata = params.metadata ?: mapOf(),
headers = params.headers ?: mapOf(),
)
} ?: throw AblyException.fromErrorInfo(ErrorInfo("Send message endpoint returned empty value", HttpStatusCodes.InternalServerError))
}

/**
* return occupancy for specified room
*/
suspend fun getOccupancy(roomId: String): OccupancyEvent {
return this.makeAuthorizedRequest("/chat/v1/rooms/$roomId/occupancy", "GET")?.let {
OccupancyEvent(
connections = it.asJsonObject.get("connections").asInt,
presenceMembers = it.asJsonObject.get("presenceMembers").asInt,
connections = it.requireInt("connections"),
presenceMembers = it.requireInt("presenceMembers"),
)
} ?: throw AblyException.fromErrorInfo(ErrorInfo("Occupancy endpoint returned empty value", HttpStatusCodes.InternalServerError))
}
Expand All @@ -81,8 +100,7 @@ class ChatApi(private val realtimeClient: RealtimeClient) {
arrayOf(apiProtocolParam),
requestBody,
arrayOf(),
object :
AsyncHttpPaginatedResponse.Callback {
object : AsyncHttpPaginatedResponse.Callback {
override fun onResponse(response: AsyncHttpPaginatedResponse?) {
continuation.resume(response?.items()?.firstOrNull())
}
Expand All @@ -106,8 +124,7 @@ class ChatApi(private val realtimeClient: RealtimeClient) {
(params + apiProtocolParam).toTypedArray(),
null,
arrayOf(),
object :
AsyncHttpPaginatedResponse.Callback {
object : AsyncHttpPaginatedResponse.Callback {
override fun onResponse(response: AsyncHttpPaginatedResponse?) {
continuation.resume(response.toPaginatedResult(transform))
}
Expand All @@ -120,17 +137,15 @@ class ChatApi(private val realtimeClient: RealtimeClient) {
}
}

data class CreateMessageResponse(val timeserial: String, val createdAt: Long)

private fun JsonElement?.toRequestBody(useBinaryProtocol: Boolean = false): HttpCore.RequestBody =
HttpUtils.requestBodyFromGson(this, useBinaryProtocol)

private fun Map<String, String>.toJson() = JsonObject().apply {
forEach { (key, value) -> addProperty(key, value) }
}

private fun JsonObject.toMap() = buildMap<String, String> {
entrySet().filter { (_, value) -> value.isJsonPrimitive }.forEach { (key, value) -> put(key, value.asString) }
private fun JsonElement.toMap() = buildMap<String, String> {
requireJsonObject().entrySet().filter { (_, value) -> value.isJsonPrimitive }.forEach { (key, value) -> put(key, value.asString) }
}

private fun QueryOptions.toParams() = buildList {
Expand All @@ -147,3 +162,61 @@ private fun QueryOptions.toParams() = buildList {
),
)
}

private fun JsonElement.requireJsonObject(): JsonObject {
if (!isJsonObject) {
throw AblyException.fromErrorInfo(
ErrorInfo("Response value expected to be JsonObject, got primitive instead", HttpStatusCodes.InternalServerError),
)
}
return asJsonObject
}

private fun JsonElement.requireString(memberName: String): String {
val memberElement = requireField(memberName)
if (!memberElement.isJsonPrimitive) {
throw AblyException.fromErrorInfo(
ErrorInfo("Value for \"$memberName\" field expected to be JsonPrimitive, got object instead", HttpStatusCodes.InternalServerError),
)
}
return memberElement.asString
}

private fun JsonElement.requireLong(memberName: String): Long {
val memberElement = requireJsonPrimitive(memberName)
try {
return memberElement.asLong
} catch (formatException: NumberFormatException) {
throw AblyException.fromErrorInfo(
formatException,
ErrorInfo("Required numeric field \"$memberName\" is not a valid long", HttpStatusCodes.InternalServerError),
)
}
}

private fun JsonElement.requireInt(memberName: String): Int {
val memberElement = requireJsonPrimitive(memberName)
try {
return memberElement.asInt
} catch (formatException: NumberFormatException) {
throw AblyException.fromErrorInfo(
formatException,
ErrorInfo("Required numeric field \"$memberName\" is not a valid int", HttpStatusCodes.InternalServerError),
)
}
}

private fun JsonElement.requireJsonPrimitive(memberName: String): JsonPrimitive {
val memberElement = requireField(memberName)
if (!memberElement.isJsonPrimitive) {
throw AblyException.fromErrorInfo(
ErrorInfo("Value for \"$memberName\" field expected to be JsonPrimitive, got object instead", HttpStatusCodes.InternalServerError),
)
}
return memberElement.asJsonPrimitive
}

private fun JsonElement.requireField(memberName: String): JsonElement = requireJsonObject().get(memberName)
?: throw AblyException.fromErrorInfo(
ErrorInfo("Required field \"$memberName\" is missing", HttpStatusCodes.InternalServerError),
)
3 changes: 1 addition & 2 deletions chat-android/src/main/java/com/ably/chat/PaginatedResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ private class AsyncPaginatedResultWrapper<T>(
val asyncPaginatedResult: AsyncHttpPaginatedResponse,
val transform: (JsonElement) -> T,
) : PaginatedResult<T> {
override val items: List<T>
get() = asyncPaginatedResult.items()?.map(transform) ?: emptyList()
override val items: List<T> = asyncPaginatedResult.items()?.map(transform) ?: emptyList()

override suspend fun next(): PaginatedResult<T> = suspendCoroutine { continuation ->
asyncPaginatedResult.next(object : AsyncHttpPaginatedResponse.Callback {
Expand Down
172 changes: 172 additions & 0 deletions chat-android/src/test/java/com/ably/chat/ChatApiTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package com.ably.chat

import com.google.gson.JsonElement
import com.google.gson.JsonObject
import io.ably.lib.types.AblyException
import io.ably.lib.types.AsyncHttpPaginatedResponse
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Assert.assertThrows
import org.junit.Assert.assertTrue
import org.junit.Test

class ChatApiTest {

private val realtime = mockk<RealtimeClient>(relaxed = true)
private val chatApi = ChatApi(realtime, "clientId")

@Test
fun `getMessages should ignore unknown fields for Chat Backend`() = runTest {
every {
realtime.requestAsync("GET", "/chat/v1/rooms/roomId/messages", any(), any(), any(), any())
} answers {
val callback = lastArg<AsyncHttpPaginatedResponse.Callback>()
callback.onResponse(
buildAsyncHttpPaginatedResponse(
listOf(
JsonObject().apply {
addProperty("foo", "bar")
addProperty("timeserial", "timeserial")
addProperty("roomId", "roomId")
addProperty("clientId", "clientId")
addProperty("text", "hello")
addProperty("createdAt", 1_000_000)
},
),
),
)
}

val messages = chatApi.getMessages("roomId", QueryOptions())

assertEquals(
listOf(
Message(
timeserial = "timeserial",
roomId = "roomId",
clientId = "clientId",
text = "hello",
createdAt = 1_000_000L,
metadata = mapOf(),
headers = mapOf(),
),
),
messages.items,
)
}

@Test
fun `getMessages should throws AblyException if some required fields are missing`() = runTest {
every {
realtime.requestAsync("GET", "/chat/v1/rooms/roomId/messages", any(), any(), any(), any())
} answers {
val callback = lastArg<AsyncHttpPaginatedResponse.Callback>()
callback.onResponse(
buildAsyncHttpPaginatedResponse(
listOf(
JsonObject().apply {
addProperty("foo", "bar")
},
),
),
)
}

val exception = assertThrows(AblyException::class.java) {
runBlocking { chatApi.getMessages("roomId", QueryOptions()) }
}

assertTrue(exception.message!!.matches(""".*Required field "\w+" is missing""".toRegex()))
}

@Test
fun `sendMessage should ignore unknown fields for Chat Backend`() = runTest {
every {
realtime.requestAsync("POST", "/chat/v1/rooms/roomId/messages", any(), any(), any(), any())
} answers {
val callback = lastArg<AsyncHttpPaginatedResponse.Callback>()
callback.onResponse(
buildAsyncHttpPaginatedResponse(
listOf(
JsonObject().apply {
addProperty("foo", "bar")
addProperty("timeserial", "timeserial")
addProperty("createdAt", 1_000_000)
},
),
),
)
}

val message = chatApi.sendMessage("roomId", SendMessageParams(text = "hello"))

assertEquals(
Message(
timeserial = "timeserial",
roomId = "roomId",
clientId = "clientId",
text = "hello",
createdAt = 1_000_000L,
headers = mapOf(),
metadata = mapOf(),
),
message,
)
}

@Test
fun `sendMessage should throw exception if 'timeserial' field is not presented`() = runTest {
every {
realtime.requestAsync("POST", "/chat/v1/rooms/roomId/messages", any(), any(), any(), any())
} answers {
val callback = lastArg<AsyncHttpPaginatedResponse.Callback>()
callback.onResponse(
buildAsyncHttpPaginatedResponse(
listOf(
JsonObject().apply {
addProperty("foo", "bar")
addProperty("createdAt", 1_000_000)
},
),
),
)
}

assertThrows(AblyException::class.java) {
runBlocking { chatApi.sendMessage("roomId", SendMessageParams(text = "hello")) }
}
}

@Test
fun `getOccupancy should throw exception if 'connections' field is not presented`() = runTest {
every {
realtime.requestAsync("GET", "/chat/v1/rooms/roomId/occupancy", any(), any(), any(), any())
} answers {
val callback = lastArg<AsyncHttpPaginatedResponse.Callback>()
callback.onResponse(
buildAsyncHttpPaginatedResponse(
listOf(
JsonObject().apply {
addProperty("presenceMembers", 1_000)
},
),
),
)
}

assertThrows(AblyException::class.java) {
runBlocking { chatApi.getOccupancy("roomId") }
}
}
}

private fun buildAsyncHttpPaginatedResponse(items: List<JsonElement>): AsyncHttpPaginatedResponse {
val response = mockk<AsyncHttpPaginatedResponse>()
every {
response.items()
} returns items.toTypedArray()
return response
}
16 changes: 0 additions & 16 deletions chat-android/src/test/java/com/ably/chat/ExampleUnitTest.kt

This file was deleted.

Loading

0 comments on commit 5c3bcaf

Please sign in to comment.