diff --git a/grpc_client/lib/grpc/client/adapters/gun.ex b/grpc_client/lib/grpc/client/adapters/gun.ex index 6ca8331b..8dde7778 100644 --- a/grpc_client/lib/grpc/client/adapters/gun.ex +++ b/grpc_client/lib/grpc/client/adapters/gun.ex @@ -206,9 +206,15 @@ defmodule GRPC.Client.Adapters.Gun do defp handle_nofin_response(adapter_payload, payload, stream, headers, opts) do # Regular response: fetch body and trailers with {:ok, body, trailers} <- recv_body(adapter_payload, payload, opts), - {:ok, response} <- parse_response(stream, headers, body, trailers) do + {:ok, response, embedded_trailers} <- parse_response(stream, headers, body, trailers) do if opts[:return_headers] do - {:ok, response, %{headers: headers, trailers: trailers}} + all_trailers = Map.merge(trailers, embedded_trailers) + + { + :ok, + response, + %{headers: headers, trailers: all_trailers} + } else {:ok, response} end @@ -410,15 +416,26 @@ defmodule GRPC.Client.Adapters.Gun do end end - defp read_stream(%{buffer: buffer, need_more: false, response_mod: res_mod, codec: codec} = s) do + defp read_stream( + %{buffer: buffer, need_more: false, response_mod: res_mod, codec: codec, opts: opts} = + stream + ) do case GRPC.Message.get_message(buffer) do + {{:trailers, trailers}, rest} -> + new_stream = + stream + |> update_stream_with_trailers(trailers, opts[:return_headers]) + |> Map.put(:buffer, rest) + + {{:ok, trailers}, new_stream} + {{_, message}, rest} -> reply = codec.decode(message, res_mod) - new_s = Map.put(s, :buffer, rest) - {{:ok, reply}, new_s} + new_stream = Map.put(stream, :buffer, rest) + {{:ok, reply}, new_stream} _ -> - read_stream(Map.put(s, :need_more, true)) + read_stream(Map.put(stream, :need_more, true)) end end @@ -431,8 +448,17 @@ defmodule GRPC.Client.Adapters.Gun do with :ok <- parse_trailers(trailers), compressor <- get_compressor(headers, accepted_compressors), body <- get_body(codec, body), - {:ok, msg} <- GRPC.Message.from_data(%{compressor: compressor}, body) do - {:ok, codec.decode(msg, res_mod)} + {:ok, msg, remaining} <- GRPC.Message.from_data(%{compressor: compressor}, body) do + {:ok, codec.decode(msg, res_mod), check_for_trailers(remaining, compressor)} + end + end + + defp check_for_trailers(<<>>, _compressor), do: %{} + + defp check_for_trailers(body, compressor) do + case GRPC.Message.from_data(%{compressor: compressor}, body) do + {:trailers, trailers, <<>>} -> trailers + _ -> %{} end end diff --git a/grpc_core/lib/grpc/message.ex b/grpc_core/lib/grpc/message.ex index 03ac2e33..12b30b4f 100644 --- a/grpc_core/lib/grpc/message.ex +++ b/grpc_core/lib/grpc/message.ex @@ -11,9 +11,11 @@ defmodule GRPC.Message do Message -> *{binary octet} """ + import Bitwise alias GRPC.RPCError - @max_message_length Bitwise.bsl(1, 32 - 1) + @max_message_length bsl(1, 32 - 1) + @trailers_flag 0b1000_0000 @doc """ Transforms Protobuf data into a gRPC body binary. @@ -46,12 +48,13 @@ defmodule GRPC.Message do iolist = opts[:iolist] codec = opts[:codec] max_length = opts[:max_message_length] || @max_message_length + additional_flags = opts[:message_flag] || 0 - {compress_flag, message} = + {flag, message} = if compressor do - {1, compressor.compress(message)} + {1 ||| additional_flags, compressor.compress(message)} else - {0, message} + {0 ||| additional_flags, message} end length = IO.iodata_length(message) @@ -59,7 +62,7 @@ defmodule GRPC.Message do if length > max_length do {:error, "Encoded message is too large (#{length} bytes)"} else - result = [compress_flag, <>, message] + result = [flag, <>, message] result = if function_exported?(codec, :pack_for_channel, 1), @@ -78,12 +81,14 @@ defmodule GRPC.Message do ## Examples iex> GRPC.Message.from_data(<<0, 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8>>) - <<1, 2, 3, 4, 5, 6, 7, 8>> + {<<1, 2, 3, 4, 5, 6, 7, 8>>, <<>>} """ - @spec from_data(binary) :: binary + @spec from_data(binary) :: {message :: binary, rest :: binary} def from_data(data) do - <<_flag::unsigned-integer-size(8), _length::bytes-size(4), message::binary>> = data - message + <<_flag::unsigned-integer-size(8), length::big-32, message::bytes-size(length), rest::binary>> = + data + + {message, rest} end @doc """ @@ -92,13 +97,16 @@ defmodule GRPC.Message do ## Examples iex> GRPC.Message.from_data(%{compressor: nil}, <<0, 0, 0, 0, 8, 1, 2, 3, 4, 5, 6, 7, 8>>) - {:ok, <<1, 2, 3, 4, 5, 6, 7, 8>>} + {:ok, <<1, 2, 3, 4, 5, 6, 7, 8>>, <<>>} """ - @spec from_data(map, binary) :: {:ok, binary} | {:error, GRPC.RPCError.t()} + @spec from_data(map, binary) :: + {:ok, message :: binary, rest :: binary} + | {:trailers, map, rest :: binary} + | {:error, GRPC.RPCError.t()} def from_data(%{compressor: nil}, data) do case data do - <<0, _length::bytes-size(4), message::binary>> -> - {:ok, message} + <<0, length::big-32, message::bytes-size(length), rest::binary>> -> + {:ok, message, rest} <<1, _length::bytes-size(4), _::binary>> -> {:error, @@ -107,6 +115,9 @@ defmodule GRPC.Message do message: "Compressed flag is set, but not specified in headers." )} + <<@trailers_flag, length::big-32, message::bytes-size(length), rest::binary>> -> + {:trailers, parse_trailers(message), rest} + _ -> {:error, RPCError.exception(status: :invalid_argument, message: "Message is malformed.")} end @@ -114,17 +125,29 @@ defmodule GRPC.Message do def from_data(%{compressor: compressor}, data) do case data do - <<1, _length::bytes-size(4), message::binary>> -> - {:ok, compressor.decompress(message)} + <<1, length::big-32, message::bytes-size(length), rest::binary>> -> + {:ok, compressor.decompress(message), rest} - <<0, _length::bytes-size(4), message::binary>> -> - {:ok, message} + <<0, length::big-32, message::bytes-size(length), rest::binary>> -> + {:ok, message, rest} + + <<@trailers_flag, length::big-32, message::bytes-size(length), rest::binary>> -> + {:trailers, parse_trailers(message), rest} _ -> {:error, RPCError.exception(status: :invalid_argument, message: "Message is malformed.")} end end + defp parse_trailers(data) do + data + |> String.split("\r\n") + |> Enum.reduce(%{}, fn line, acc -> + [k, v] = String.split(line, ":", parts: 2) + Map.put(acc, k, String.trim(v)) + end) + end + def from_frame(bin), do: from_frame(bin, []) def from_frame(<<>>, acc), do: Enum.reverse(acc) @@ -166,7 +189,10 @@ defmodule GRPC.Message do <> ) do - {{flag, message}, rest} + case flag do + @trailers_flag -> {{:trailers, message}, rest} + _ -> {{flag, message}, rest} + end end def get_message(_) do @@ -175,6 +201,10 @@ defmodule GRPC.Message do def get_message(data, nil = _compressor) do case data do + <<@trailers_flag::8, length::unsigned-integer-size(32), message::bytes-size(length), + rest::binary>> -> + {{:trailers, message}, rest} + <> -> {{flag, message}, rest} @@ -192,6 +222,10 @@ defmodule GRPC.Message do <<0::8, length::unsigned-integer-32, message::bytes-size(length), rest::binary>> -> {{0, message}, rest} + <<@trailers_flag::8, length::unsigned-integer-32, message::bytes-size(length), + rest::binary>> -> + {{:trailers, message}, rest} + _other -> data end diff --git a/grpc_core/test/grpc/message_test.exs b/grpc_core/test/grpc/message_test.exs index 04b29f32..8f37df0e 100644 --- a/grpc_core/test/grpc/message_test.exs +++ b/grpc_core/test/grpc/message_test.exs @@ -7,13 +7,17 @@ defmodule GRPC.MessageTest do message = String.duplicate("foo", 100) # 10th byte is the operating system ID - assert {:ok, - data = - <<1, 0, 0, 0, 27, 31, 139, 8, 0, 0, 0, 0, 0, 0, _, 75, 203, 207, 79, 27, 69, 196, - 33, 0, 41, 249, 122, 62, 44, 1, 0, 0>>, - 32} = GRPC.Message.to_data(message, %{compressor: GRPC.Compressor.Gzip}) - - assert {:ok, message} == GRPC.Message.from_data(%{compressor: GRPC.Compressor.Gzip}, data) + assert { + :ok, + data = + <<1, 0, 0, 0, 27, 31, 139, 8, 0, 0, 0, 0, 0, 0, _, 75, 203, 207, 79, 27, 69, 196, + 33, 0, 41, 249, 122, 62, 44, 1, 0, 0>>, + 32 + } = + GRPC.Message.to_data(message, %{compressor: GRPC.Compressor.Gzip}) + + assert {:ok, message, <<>>} == + GRPC.Message.from_data(%{compressor: GRPC.Compressor.Gzip}, data) end test "iodata can be passed to and returned from `to_data/2`" do @@ -25,13 +29,13 @@ defmodule GRPC.MessageTest do assert is_list(data) binary = IO.iodata_to_binary(data) - assert {:ok, IO.iodata_to_binary(message)} == + assert {:ok, IO.iodata_to_binary(message), <<>>} == GRPC.Message.from_data(%{compressor: GRPC.Compressor.Gzip}, binary) end test "to_data/2 invokes codec.pack_for_channel on the gRPC body if codec implements it" do message = "web-text" assert {:ok, base64_payload, _} = GRPC.Message.to_data(message, %{codec: GRPC.Codec.WebText}) - assert message == GRPC.Message.from_data(Base.decode64!(base64_payload)) + assert {message, ""} == GRPC.Message.from_data(Base.decode64!(base64_payload)) end end diff --git a/grpc_server/lib/grpc/server.ex b/grpc_server/lib/grpc/server.ex index 1e48c926..eb0eeb89 100644 --- a/grpc_server/lib/grpc/server.ex +++ b/grpc_server/lib/grpc/server.ex @@ -300,7 +300,7 @@ defmodule GRPC.Server do end case GRPC.Message.from_data(stream, body) do - {:ok, message} -> + {:ok, message, <<>>} -> request = codec.decode(message, req_mod) call_with_interceptors(res_stream, func_name, stream, request) diff --git a/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex b/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex index bc596772..1e8b8028 100644 --- a/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex +++ b/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex @@ -12,6 +12,7 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do @adapter GRPC.Server.Adapters.Cowboy @default_trailers HTTP2.server_trailers() + @trailers_flag 0b1000_0000 @type init_state :: { endpoint :: atom(), @@ -103,6 +104,7 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do handling_timer: timer_ref, pending_reader: nil, access_mode: access_mode, + codec: codec, exception_log_filter: exception_log_filter } } @@ -481,7 +483,34 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do def info({:stream_trailers, trailers}, req, state) do metadata = Map.get(state, :resp_trailers, %{}) metadata = GRPC.Transport.HTTP2.encode_metadata(metadata) - send_stream_trailers(req, Map.merge(metadata, trailers)) + all_trailers = Map.merge(metadata, trailers) + + req = check_sent_resp(req) + + if state.access_mode === :grpcweb do + # grpc_web requires trailers be sent as the last + # message block rather than in the HTTP trailers + # as javascript runtimes do not propagate trailers + # + # trailers are instead denoted with the "trailer flag" + # which has the MSB set to 1. + {:ok, data, _length} = + all_trailers + |> Enum.map_join("\r\n", fn {k, v} -> "#{k}: #{v}" end) + |> GRPC.Message.to_data(message_flag: @trailers_flag) + + packed = + if function_exported?(state.codec, :pack_for_channel, 1) do + state.codec.pack_for_channel(data) + else + data + end + + :cowboy_req.stream_body(packed, :nofin, req) + end + + :cowboy_req.stream_trailers(all_trailers, req) + {:ok, req, state} end @@ -616,11 +645,6 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do end end - defp send_stream_trailers(req, trailers) do - req = check_sent_resp(req) - :cowboy_req.stream_trailers(trailers, req) - end - defp check_sent_resp(%{has_sent_resp: _} = req) do req end