diff --git a/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DefaultDomainExceptionErrorObjectTests.cs b/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DefaultDomainExceptionErrorObjectTests.cs index 17a3263..86c0382 100644 --- a/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DefaultDomainExceptionErrorObjectTests.cs +++ b/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DefaultDomainExceptionErrorObjectTests.cs @@ -66,9 +66,9 @@ public async Task TestMapApiErrorObject() } [Theory] - [InlineData(HttpStatusCode.BadRequest, MyDomainErrorEnum.ClientError)] - [InlineData(HttpStatusCode.BadGateway, MyDomainErrorEnum.ServerError)] - public async Task TestDeserializerException(HttpStatusCode httpStatusCode, MyDomainErrorEnum error) + [InlineData(HttpStatusCode.BadRequest)] + [InlineData(HttpStatusCode.BadGateway)] + public async Task TestDeserializerException(HttpStatusCode httpStatusCode) { // Arrange _mockHttp @@ -76,18 +76,18 @@ public async Task TestDeserializerException(HttpStatusCode httpStatusCode, MyDom .Respond(_ => HtmlHttpResponseMessage(httpStatusCode)); // Act - var exception = await Assert.ThrowsAsync(() => + var exception = await Assert.ThrowsAsync(() => _defaultMapperApiClient.Api()); // Assert - Assert.Equal(error, exception.Error); + Assert.Equal(httpStatusCode, exception.StatusCode); Assert.IsType(exception.InnerException); } [Theory] - [InlineData(HttpStatusCode.BadRequest, MyDomainErrorEnum.ClientError)] - [InlineData(HttpStatusCode.BadGateway, MyDomainErrorEnum.ServerError)] - public async Task TestNoDeserializerFound(HttpStatusCode httpStatusCode, MyDomainErrorEnum error) + [InlineData(HttpStatusCode.BadRequest)] + [InlineData(HttpStatusCode.BadGateway)] + public async Task TestNoDeserializerFound(HttpStatusCode httpStatusCode) { // Arrange _mockHttp @@ -95,12 +95,11 @@ public async Task TestNoDeserializerFound(HttpStatusCode httpStatusCode, MyDomai .Respond(_ => FoobarHttpResponseMessage(httpStatusCode)); // Act - var exception = await Assert.ThrowsAsync(() => + var exception = await Assert.ThrowsAsync(() => _defaultMapperApiClient.Api()); // Assert - Assert.Equal(error, exception.Error); - Assert.IsType(exception.InnerException); + Assert.Equal(httpStatusCode, exception.StatusCode); } private static HttpResponseMessage JsonHttpResponseMessage(HttpStatusCode httpStatusCode, MyApiError myApiError) diff --git a/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DomainExceptionErrorObjectTests.cs b/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DomainExceptionErrorObjectTests.cs index cdf6c66..d8d2ac5 100644 --- a/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DomainExceptionErrorObjectTests.cs +++ b/Activout.RestClient.Newtonsoft.Json.Test/DomainExceptions/DomainExceptionErrorObjectTests.cs @@ -102,7 +102,7 @@ public async Task TestMapApiErrorObject() } [Fact] - public async Task TestNoDeserializerFound() + public async Task TestErrorResponseNotCompatibleWithHtml() { // Arrange _mockHttp @@ -110,11 +110,11 @@ public async Task TestNoDeserializerFound() .Respond(_ => HtmlHttpResponseMessage(HttpStatusCode.BadRequest)); // Act - var exception = await Assert.ThrowsAsync(() => + var exception = await Assert.ThrowsAsync(() => _myApiClient.Api()); // Assert - Assert.Equal(MyDomainErrorEnum.Unknown, exception.Error.ErrorEnum); + Assert.Equal(HttpStatusCode.BadRequest, exception.StatusCode); Assert.IsType(exception.InnerException); } diff --git a/Activout.RestClient/Implementation/RequestHandler.cs b/Activout.RestClient/Implementation/RequestHandler.cs index d50d806..2b5ae82 100644 --- a/Activout.RestClient/Implementation/RequestHandler.cs +++ b/Activout.RestClient/Implementation/RequestHandler.cs @@ -1,4 +1,3 @@ -#nullable disable using System; using System.Collections; using System.Collections.Generic; @@ -15,562 +14,586 @@ using Activout.RestClient.Serialization; using Microsoft.Extensions.Logging; -namespace Activout.RestClient.Implementation +namespace Activout.RestClient.Implementation; + +internal class RequestHandler { - internal class RequestHandler + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 + private const string DefaultHttpContentType = "application/octet-stream"; + + // https://tools.ietf.org/html/rfc7578#section-4.4 + private static readonly MediaType DefaultPartContentType = new MediaType("text/plain"); + + private readonly Type _actualReturnType; + private readonly int _bodyArgumentIndex = -1; + private readonly MediaType _contentType; + private readonly RestClientContext _context; + private readonly ITaskConverter? _converter; + private readonly Type _errorResponseType; + private readonly HttpMethod _httpMethod = HttpMethod.Get; + private readonly ParameterInfo[] _parameters; + private readonly Type _returnType; + private readonly ISerializer _serializer; + private readonly string _template; + private readonly IParamConverter[] _paramConverters; + private readonly IDomainExceptionMapper? _domainExceptionMapper; + private readonly List> _requestHeaders = new List>(); + + private bool IsDebugLoggingEnabled => _context.Logger.IsEnabled(LogLevel.Debug); + + public RequestHandler(MethodInfo method, RestClientContext context) { - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec7.html#sec7.2.1 - private const string DefaultHttpContentType = "application/octet-stream"; - - // https://tools.ietf.org/html/rfc7578#section-4.4 - private static readonly MediaType DefaultPartContentType = new MediaType("text/plain"); - - private readonly Type _actualReturnType; - private readonly int _bodyArgumentIndex = -1; - private readonly MediaType _contentType; - private readonly RestClientContext _context; - private readonly ITaskConverter _converter; - private readonly Type _errorResponseType; - private readonly HttpMethod _httpMethod = HttpMethod.Get; - private readonly ParameterInfo[] _parameters; - private readonly Type _returnType; - private readonly ISerializer _serializer; - private readonly string _template; - private readonly IParamConverter[] _paramConverters; - private readonly IDomainExceptionMapper _domainExceptionMapper; - private readonly List> _requestHeaders = new List>(); - - private bool DebugLoggingEnabled => _context.Logger.IsEnabled(LogLevel.Debug); - - public RequestHandler(MethodInfo method, RestClientContext context) - { - _returnType = method.ReturnType; - _actualReturnType = GetActualReturnType(); - _parameters = method.GetParameters(); - _paramConverters = GetParamConverters(context.ParamConverterManager); - _converter = CreateConverter(context); - _template = context.BaseTemplate ?? ""; - _serializer = context.DefaultSerializer; - _contentType = context.DefaultContentType; - _errorResponseType = context.ErrorResponseType; - _requestHeaders.AddRange(context.DefaultHeaders); - - var templateBuilder = new StringBuilder(context.BaseTemplate ?? ""); - foreach (var attribute in method.GetCustomAttributes(true)) - switch (attribute) - { - case ContentTypeAttribute contentTypeAttribute: - _contentType = MediaType.ValueOf(contentTypeAttribute.ContentType); - break; - - case ErrorResponseAttribute errorResponseAttribute: - _errorResponseType = errorResponseAttribute.Type; - break; - - case HeaderAttribute headerAttribute: - _requestHeaders.AddOrReplaceHeader(headerAttribute.Name, headerAttribute.Value, - headerAttribute.Replace); - break; - - case HttpMethodAttribute httpMethodAttribute: - templateBuilder.Append(httpMethodAttribute.Template); - _httpMethod = GetHttpMethod(httpMethodAttribute); - break; - - case PathAttribute pathAttribute: - templateBuilder.Append(pathAttribute.Template); - break; - } - - if (IsHttpMethodWithBody()) + _returnType = method.ReturnType; + _actualReturnType = GetActualReturnType(); + _parameters = method.GetParameters(); + _paramConverters = GetParamConverters(context.ParamConverterManager); + _converter = CreateConverter(context); + _template = context.BaseTemplate; + _serializer = context.DefaultSerializer; + _contentType = context.DefaultContentType; + _errorResponseType = context.ErrorResponseType ?? typeof(string); + _requestHeaders.AddRange(context.DefaultHeaders); + + var templateBuilder = new StringBuilder(context.BaseTemplate); + foreach (var attribute in method.GetCustomAttributes(true)) + switch (attribute) { - _bodyArgumentIndex = _parameters.Length - 1; - - if (_parameters.Length > 0 && - _parameters[_bodyArgumentIndex].ParameterType == typeof(CancellationToken)) - { - _bodyArgumentIndex--; - } - - if (_bodyArgumentIndex < 0) - { - throw new InvalidOperationException("No body argument found for method: " + method.Name); - } + case ContentTypeAttribute contentTypeAttribute: + _contentType = new MediaType(contentTypeAttribute.ContentType); + break; + + case ErrorResponseAttribute errorResponseAttribute: + _errorResponseType = errorResponseAttribute.Type; + break; + + case HeaderAttribute headerAttribute: + _requestHeaders.AddOrReplaceHeader(headerAttribute.Name, headerAttribute.Value, + headerAttribute.Replace); + break; + + case HttpMethodAttribute httpMethodAttribute: + templateBuilder.Append(httpMethodAttribute.Template); + _httpMethod = GetHttpMethod(httpMethodAttribute); + break; + + case PathAttribute pathAttribute: + templateBuilder.Append(pathAttribute.Template); + break; } - _serializer = context.SerializationManager.GetSerializer(_contentType); + if (IsHttpMethodWithBody()) + { + _bodyArgumentIndex = _parameters.Length - 1; - if (context.UseDomainException) + if (_parameters.Length > 0 && + _parameters[_bodyArgumentIndex].ParameterType == typeof(CancellationToken)) { - _domainExceptionMapper = context.DomainExceptionMapperFactory.CreateDomainExceptionMapper( - method, - _errorResponseType, - context.DomainExceptionType); + _bodyArgumentIndex--; } - _template = templateBuilder.ToString(); - _context = context; + if (_bodyArgumentIndex < 0) + { + throw new InvalidOperationException("No body argument found for method: " + method.Name); + } } - private bool IsHttpMethodWithBody() + _serializer = context.SerializationManager.GetSerializer(_contentType) ?? + throw new InvalidOperationException("No serializer found for content type: " + _contentType); + + if (context.DomainExceptionType != null) { - return _httpMethod == HttpMethod.Post || _httpMethod == HttpMethod.Put || _httpMethod == HttpMethod.Patch; + _domainExceptionMapper = context.DomainExceptionMapperFactory.CreateDomainExceptionMapper( + method, + _errorResponseType, + context.DomainExceptionType); } - private IParamConverter[] GetParamConverters(IParamConverterManager paramConverterManager) - { - var paramConverters = new IParamConverter[_parameters.Length]; - for (var i = 0; i < _parameters.Length; i++) - { - paramConverters[i] = paramConverterManager.GetConverter(_parameters[i].ParameterType, _parameters[i]); - } + _template = templateBuilder.ToString(); + _context = context; + } - return paramConverters; - } + private bool IsHttpMethodWithBody() + { + return _httpMethod == HttpMethod.Post || _httpMethod == HttpMethod.Put || _httpMethod == HttpMethod.Patch; + } - private static HttpMethod GetHttpMethod(HttpMethodAttribute attribute) + private IParamConverter[] GetParamConverters(IParamConverterManager paramConverterManager) + { + var paramConverters = new IParamConverter[_parameters.Length]; + for (var i = 0; i < _parameters.Length; i++) { - return attribute.HttpMethod; + paramConverters[i] = paramConverterManager.GetConverter(_parameters[i].ParameterType, _parameters[i]) + ?? throw new InvalidOperationException( + "No parameter converter found for type: " + _parameters[i].ParameterType); } - private ITaskConverter CreateConverter(RestClientContext context) + return paramConverters; + } + + private static HttpMethod GetHttpMethod(HttpMethodAttribute attribute) + { + return attribute.HttpMethod; + } + + private ITaskConverter? CreateConverter(RestClientContext context) + { + if (_actualReturnType == typeof(void)) { - return context.TaskConverterFactory.CreateTaskConverter(_actualReturnType); + return null; } - private bool IsVoidTask() + return context.TaskConverterFactory.CreateTaskConverter(_actualReturnType) ?? + throw new InvalidOperationException("Failed to create task converter for return type: " + + _actualReturnType); + } + + private bool IsVoidTask() + { + return _returnType == typeof(Task); + } + + private bool IsGenericTask() + { + return _returnType.BaseType == typeof(Task) && _returnType.IsGenericType; + } + + private Type GetActualReturnType() + { + if (IsVoidTask()) + return typeof(void); + if (IsGenericTask()) + return _returnType.GenericTypeArguments[0]; + return _returnType; + } + + private string ExpandTemplate(Dictionary routeParams) + { + var expanded = _template; + foreach (var entry in routeParams) { - return _returnType == typeof(Task); + expanded = expanded.Replace("{" + entry.Key + "}", entry.Value.ToString()); } - private bool IsGenericTask() + return expanded; + } + + // Based on PrepareRequestMessage at https://github.com/dotnet/corefx/blob/master/src/System.Net.Http/src/System/Net/Http/HttpClient.cs + private void PrepareRequestMessage(HttpRequestMessage request) + { + var baseUri = _context.BaseUri; + Uri? requestUri = null; + if (request.RequestUri == null && baseUri == null) throw new InvalidOperationException(); + if (request.RequestUri == null) { - return _returnType.BaseType == typeof(Task) && _returnType.IsGenericType; + requestUri = baseUri; } - - private Type GetActualReturnType() + else { - if (IsVoidTask()) - return typeof(void); - if (IsGenericTask()) - return _returnType.GenericTypeArguments[0]; - return _returnType; + // If the request Uri is an absolute Uri, just use it. Otherwise try to combine it with the base Uri. + if (!request.RequestUri.IsAbsoluteUri) + { + if (baseUri == null) + throw new InvalidOperationException(); + requestUri = new Uri(baseUri, request.RequestUri); + } } - private string ExpandTemplate(Dictionary routeParams) - { - var expanded = _template; - foreach (var entry in routeParams) - expanded = expanded.Replace("{" + entry.Key + "}", entry.Value.ToString()); + // We modified the original request Uri. Assign the new Uri to the request message. + if (requestUri != null) request.RequestUri = requestUri; + } + + public object? Send(object?[]? args) + { + var headers = new List>(); + headers.AddRange(_requestHeaders); + + var routeParams = new Dictionary(); + var queryParams = new List(); + var formParams = new List>(); + var partParams = new List>(); + var cancellationToken = GetParams(args, routeParams, queryParams, formParams, headers, partParams); - return expanded; + var requestUriString = ExpandTemplate(routeParams); + if (queryParams.Count != 0) + { + requestUriString = requestUriString + "?" + string.Join("&", queryParams); } - // Based on PrepareRequestMessage at https://github.com/dotnet/corefx/blob/master/src/System.Net.Http/src/System/Net/Http/HttpClient.cs - private void PrepareRequestMessage(HttpRequestMessage request) + var requestUri = new Uri(requestUriString, UriKind.RelativeOrAbsolute); + + var request = new HttpRequestMessage(_httpMethod, requestUri); + + SetHeaders(request, headers); + + if (IsHttpMethodWithBody()) { - var baseUri = _context.BaseUri; - Uri requestUri = null; - if (request.RequestUri == null && baseUri == null) throw new InvalidOperationException(); - if (request.RequestUri == null) + if (partParams.Count != 0) { - requestUri = baseUri; + request.Content = CreateMultipartFormDataContent(partParams); } - else + else if (formParams.Count != 0) { - // If the request Uri is an absolute Uri, just use it. Otherwise try to combine it with the base Uri. - if (!request.RequestUri.IsAbsoluteUri) - { - if (baseUri == null) - throw new InvalidOperationException(); - requestUri = new Uri(baseUri, request.RequestUri); - } + request.Content = new FormUrlEncodedContent(formParams); + } + else if (args != null) + { + request.Content = GetHttpContent(_serializer, args[_bodyArgumentIndex], _contentType); } - - // We modified the original request Uri. Assign the new Uri to the request message. - if (requestUri != null) request.RequestUri = requestUri; } - public object Send(object[] args) - { - var headers = new List>(); - headers.AddRange(_requestHeaders); + var task = SendRequestAndHandleResponse(request, cancellationToken); - var routeParams = new Dictionary(); - var queryParams = new List(); - var formParams = new List>(); - var partParams = new List>(); - var cancellationToken = GetParams(args, routeParams, queryParams, formParams, headers, partParams); + if (IsVoidTask()) + return task; + if (_returnType.BaseType == typeof(Task) && _returnType.IsGenericType && _converter != null) + return _converter.ConvertReturnType(task); + return task.Result; + } - var requestUriString = ExpandTemplate(routeParams); - if (queryParams.Any()) + private static MultipartFormDataContent CreateMultipartFormDataContent( + IEnumerable> partParams) + { + var content = new MultipartFormDataContent(); + foreach (var part in partParams) + { + if (!string.IsNullOrEmpty(part.FileName)) { - requestUriString = requestUriString + "?" + string.Join("&", queryParams); + content.Add(part.Content, part.Name, part.FileName); } - - var requestUri = new Uri(requestUriString, UriKind.RelativeOrAbsolute); - - var request = new HttpRequestMessage(_httpMethod, requestUri); - - SetHeaders(request, headers); - - if (IsHttpMethodWithBody()) + else if (!string.IsNullOrEmpty(part.Name)) { - if (partParams.Count != 0) - { - request.Content = CreateMultipartFormDataContent(partParams); - } - else if (formParams.Count != 0) - { - request.Content = new FormUrlEncodedContent(formParams); - } - else - { - request.Content = GetHttpContent(_serializer, args[_bodyArgumentIndex], _contentType); - } + content.Add(part.Content, part.Name); + } + else + { + content.Add(part.Content); } + } - var task = SendAsync(request, cancellationToken); + return content; + } - if (IsVoidTask()) - return task; - if (_returnType.BaseType == typeof(Task) && _returnType.IsGenericType) - return _converter.ConvertReturnType(task); - return task.Result; - } + private void SetHeaders(HttpRequestMessage request, List> headers) + { + headers.ForEach(p => request.Headers.Add(p.Key, p.Value.ToString())); + } - private static MultipartFormDataContent CreateMultipartFormDataContent( - IEnumerable> partParams) - { - var content = new MultipartFormDataContent(); - foreach (var part in partParams) - { - if (!string.IsNullOrEmpty(part.FileName)) - { - content.Add(part.Content, part.Name, part.FileName); - } - else if (!string.IsNullOrEmpty(part.Name)) - { - content.Add(part.Content, part.Name); - } - else - { - content.Add(part.Content); - } - } + private string? ConvertValueToString(object? value, ParameterInfo parameterInfo) + { + if (value == null) + return null; - return content; - } + var converter = _context.ParamConverterManager.GetConverter(value.GetType(), parameterInfo); + return converter?.ToString(value) ?? value.ToString(); + } - private void SetHeaders(HttpRequestMessage request, List> headers) - { - headers.ForEach(p => request.Headers.Add(p.Key, p.Value.ToString())); - } + private CancellationToken GetParams( + object?[]? args, + Dictionary pathParams, + List queryParams, + List> formParams, + List> headers, + List> parts) + { + var cancellationToken = CancellationToken.None; - private string ConvertValueToString(object value, ParameterInfo parameterInfo) + if (_parameters.Length > 0 && args == null || _parameters.Length != args?.Length) { - if (value == null) - return null; - - var converter = _context.ParamConverterManager.GetConverter(value.GetType(), parameterInfo); - return converter?.ToString(value) ?? value.ToString(); + throw new InvalidOperationException( + $"Argument count mismatch. Expected: {_parameters.Length}, Actual: {args?.Length ?? 0}"); } - private CancellationToken GetParams( - object[] args, - Dictionary pathParams, - List queryParams, - List> formParams, - List> headers, - List> parts) + for (var i = 0; i < _parameters.Length; i++) { - var cancellationToken = CancellationToken.None; + var rawValue = args[i]; + if (rawValue is CancellationToken ct) + { + cancellationToken = ct; + continue; + } - for (var i = 0; i < _parameters.Length; i++) + if (rawValue == null) { - var rawValue = args[i]; - if (rawValue is CancellationToken ct) - { - cancellationToken = ct; - continue; - } + continue; + } - var parameterAttributes = _parameters[i].GetCustomAttributes(false); - var parameterName = _parameters[i].Name; - var stringValue = _paramConverters[i].ToString(rawValue); - var handled = false; + var parameterAttributes = _parameters[i].GetCustomAttributes(false); + var parameterName = _parameters[i].Name ?? throw new InvalidOperationException( + "Parameter name not found for parameter at index: " + i); + var handled = false; - foreach (var attribute in parameterAttributes) + foreach (var attribute in parameterAttributes) + { + if (attribute is PartParamAttribute partAttribute) { - if (attribute is PartParamAttribute partAttribute) + if (_parameters[i].ParameterType.IsArray) { - if (_parameters[i].ParameterType.IsArray) - { - var items = (object[])rawValue; - parts.AddRange(items.SelectMany(item => - GetPartNameAndHttpContent(partAttribute, parameterName, item))); - } - else + if (rawValue is IEnumerable items and not string) { - parts.AddRange(GetPartNameAndHttpContent(partAttribute, parameterName, rawValue)); + foreach (var item in items) + { + parts.AddRange(GetPartNameAndHttpContent(partAttribute, parameterName, item)); + } } - - handled = true; } - else if (attribute is PathParamAttribute pathParamAttribute) + else { - pathParams[pathParamAttribute.Name ?? parameterName] = Uri.EscapeDataString(stringValue); - handled = true; + parts.AddRange(GetPartNameAndHttpContent(partAttribute, parameterName, rawValue)); } - else if (attribute is QueryParamAttribute queryParamAttribute) + + handled = true; + } + else if (attribute is PathParamAttribute pathParamAttribute) + { + var stringValue = _paramConverters[i].ToString(rawValue); + pathParams[pathParamAttribute.Name ?? parameterName] = Uri.EscapeDataString(stringValue); + handled = true; + } + else if (attribute is QueryParamAttribute queryParamAttribute) + { + if (rawValue is IDictionary dictionary) { - if (rawValue is IDictionary dictionary) + foreach (DictionaryEntry entry in dictionary) { - foreach (DictionaryEntry entry in dictionary) + var key = entry.Key.ToString(); + var value = ConvertValueToString(entry.Value, _parameters[i]); + if (key != null && value != null) { - var key = entry.Key?.ToString(); - var value = ConvertValueToString(entry.Value, _parameters[i]); - if (key != null && value != null) - { - queryParams.Add(Uri.EscapeDataString(key) + "=" + Uri.EscapeDataString(value)); - } + queryParams.Add(Uri.EscapeDataString(key) + "=" + Uri.EscapeDataString(value)); } } - else if (rawValue != null) - { - queryParams.Add(Uri.EscapeDataString(queryParamAttribute.Name ?? parameterName) + "=" + - Uri.EscapeDataString(stringValue)); - } - - handled = true; } - else if (attribute is FormParamAttribute formParamAttribute) + else if (rawValue != null) { - if (rawValue is IDictionary dictionary) + var stringValue = _paramConverters[i].ToString(rawValue); + queryParams.Add(Uri.EscapeDataString(queryParamAttribute.Name ?? parameterName) + "=" + + Uri.EscapeDataString(stringValue)); + } + + handled = true; + } + else if (attribute is FormParamAttribute formParamAttribute) + { + if (rawValue is IDictionary dictionary) + { + foreach (DictionaryEntry entry in dictionary) { - foreach (DictionaryEntry entry in dictionary) + var key = entry.Key.ToString(); + var value = ConvertValueToString(entry.Value, _parameters[i]); + if (key != null && value != null) { - var key = entry.Key?.ToString(); - var value = ConvertValueToString(entry.Value, _parameters[i]); - if (key != null && value != null) - { - formParams.Add(new KeyValuePair(key, value)); - } + formParams.Add(new KeyValuePair(key, value)); } } - else if (rawValue != null) - { - formParams.Add(new KeyValuePair(formParamAttribute.Name ?? parameterName, - stringValue)); - } - - handled = true; } - else if (attribute is HeaderParamAttribute headerParamAttribute) + else if (rawValue != null) { - if (rawValue is IDictionary dictionary) + var stringValue = _paramConverters[i].ToString(rawValue); + formParams.Add(new KeyValuePair(formParamAttribute.Name ?? parameterName, + stringValue)); + } + + handled = true; + } + else if (attribute is HeaderParamAttribute headerParamAttribute) + { + if (rawValue is IDictionary dictionary) + { + foreach (DictionaryEntry entry in dictionary) { - foreach (DictionaryEntry entry in dictionary) + var key = entry.Key.ToString(); + var value = ConvertValueToString(entry.Value, _parameters[i]); + if (key != null && value != null) { - var key = entry.Key?.ToString(); - var value = ConvertValueToString(entry.Value, _parameters[i]); - if (key != null && value != null) - { - headers.AddOrReplaceHeader(key, value, headerParamAttribute.Replace); - } + headers.AddOrReplaceHeader(key, value, headerParamAttribute.Replace); } } - else if (rawValue != null) - { - headers.AddOrReplaceHeader(headerParamAttribute.Name ?? parameterName, stringValue, - headerParamAttribute.Replace); - } - - handled = true; } - } + else if (rawValue != null) + { + var stringValue = _paramConverters[i].ToString(rawValue); + headers.AddOrReplaceHeader(headerParamAttribute.Name ?? parameterName, stringValue, + headerParamAttribute.Replace); + } - if (!handled) - { - pathParams[parameterName] = Uri.EscapeDataString(stringValue); + handled = true; } } - return cancellationToken; + if (!handled) + { + var stringValue = _paramConverters[i].ToString(rawValue); + pathParams[parameterName] = Uri.EscapeDataString(stringValue); + } } - private IEnumerable> GetPartNameAndHttpContent(PartParamAttribute partAttribute, - string parameterName, - object rawValue) - { - string fileName = null; - string partName = null; + return cancellationToken; + } - if (rawValue is Part part) - { - rawValue = part.InternalContent; - partName = part.Name; - fileName = part.FileName; - } + private IEnumerable> GetPartNameAndHttpContent(PartParamAttribute partAttribute, + string parameterName, + object? rawValue) + { + string? fileName = null; + string? partName = null; - if (rawValue is { }) + if (rawValue is Part part) + { + rawValue = part.InternalContent; + partName = part.Name; + fileName = part.FileName; + } + + if (rawValue is { }) + { + yield return new Part { - yield return new Part - { - Content = GetPartHttpContent(partAttribute, rawValue), - Name = partName ?? partAttribute.Name ?? parameterName, - FileName = fileName ?? partAttribute.FileName - }; - } + Content = GetPartHttpContent(partAttribute, rawValue), + Name = partName ?? partAttribute.Name ?? parameterName, + FileName = fileName ?? partAttribute.FileName + }; } + } - private HttpContent GetPartHttpContent(PartParamAttribute partAttribute, object value) + private HttpContent GetPartHttpContent(PartParamAttribute partAttribute, object value) + { + // TODO: prepare part serializer in advance + + var contentType = partAttribute.ContentType ?? DefaultPartContentType; + var serializer = _context.SerializationManager.GetSerializer(contentType) ?? + throw new InvalidOperationException("No serializer for part content type: " + contentType); + return GetHttpContent(serializer, value, contentType); + } + + private static HttpContent GetHttpContent(ISerializer serializer, object? value, MediaType contentType) + { + if (value is HttpContent httpContent) { - // TODO: prepare part serializer in advance + return httpContent; + } - var contentType = partAttribute.ContentType ?? DefaultPartContentType; - var serializer = _context.SerializationManager.GetSerializer(contentType); - return GetHttpContent(serializer, value, contentType); + if (serializer == null) + { + throw new InvalidOperationException("No serializer for: " + contentType); } - private static HttpContent GetHttpContent(ISerializer serializer, object value, MediaType contentType) + return serializer.Serialize(value, Encoding.UTF8, contentType); + } + + + private async Task SendRequestAndHandleResponse(HttpRequestMessage request, CancellationToken cancellationToken) + { + var response = await SendRequest(request, cancellationToken); + return await HandleResponse(request, response); + } + + private async Task SendRequest(HttpRequestMessage request, CancellationToken cancellationToken) + { + PrepareRequestMessage(request); + + if (IsDebugLoggingEnabled) { - if (value is HttpContent httpContent) - { - return httpContent; - } + _context.Logger.LogDebug("{Request}", request); - if (serializer == null) + if (request.Content != null) { - throw new InvalidOperationException("No serializer for: " + contentType); + await request.Content.LoadIntoBufferAsync(); + _context.Logger.LogDebug("{RequestContent}", + (await request.Content.ReadAsStringAsync(cancellationToken)).SafeSubstring(0, 1000)); } - - return serializer.Serialize(value, Encoding.UTF8, contentType); } + HttpResponseMessage response; + using (_context.RequestLogger.TimeOperation(request)) + { + response = await _context.HttpClient.SendAsync(request, cancellationToken); + } - private async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + if (IsDebugLoggingEnabled) { - PrepareRequestMessage(request); + _context.Logger.LogDebug("{Response}", response); - if (DebugLoggingEnabled) - { - _context.Logger.LogDebug("{Request}", request); + await response.Content.LoadIntoBufferAsync(); + _context.Logger.LogDebug("{ResponseContent}", + (await response.Content.ReadAsStringAsync(cancellationToken)).SafeSubstring(0, 1000)); + } - if (request.Content != null) - { - await request.Content.LoadIntoBufferAsync(); - _context.Logger.LogDebug("{RequestContent}", - (await request.Content.ReadAsStringAsync()).SafeSubstring(0, 1000)); - } - } + return response; + } + + private async Task HandleResponse(HttpRequestMessage request, HttpResponseMessage response) + { + if (_actualReturnType == typeof(HttpResponseMessage)) + { + return response; + } - HttpResponseMessage response; - using (_context.RequestLogger.TimeOperation(request)) + var shouldDisposeResponse = true; + try + { + if (_actualReturnType == typeof(HttpStatusCode)) { - response = await _context.HttpClient.SendAsync(request, cancellationToken); + return response.StatusCode; } - if (DebugLoggingEnabled) - { - _context.Logger.LogDebug("{Response}", response); + object? data; + var type = response.IsSuccessStatusCode ? _actualReturnType : _errorResponseType; - if (response.Content != null) - { - await response.Content.LoadIntoBufferAsync(); - _context.Logger.LogDebug("{ResponseContent}", - (await response.Content.ReadAsStringAsync()).SafeSubstring(0, 1000)); - } + if (type == typeof(void)) + { + data = null; } - - if (_actualReturnType == typeof(HttpStatusCode)) + else if (type.IsInstanceOfType(response.Content)) // HttpContent or a subclass like MultipartFormDataContent { - return response.StatusCode; + shouldDisposeResponse = false; + data = response.Content; } - - if (_actualReturnType == typeof(HttpResponseMessage)) + else { - return response; + data = await Deserialize(request, response, type); } - var data = await GetResponseData(request, response); - if (response.IsSuccessStatusCode) { return data; } - if (_context.UseDomainException) + if (_context.UseDomainException && _domainExceptionMapper != null) { throw await _domainExceptionMapper.CreateExceptionAsync(response, data); } throw new RestClientException(request.RequestUri, response.StatusCode, data); } - - private async Task GetResponseData(HttpRequestMessage request, HttpResponseMessage response) + finally { - var type = response.IsSuccessStatusCode ? _actualReturnType : _errorResponseType; - - if (type == typeof(void) || response.Content == null) - { - return null; - } - - // HttpContent or a subclass like MultipartFormDataContent - if (type.IsInstanceOfType(response.Content)) - { - return response.Content; - } - - var contentTypeMediaType = response.Content.Headers?.ContentType?.MediaType ?? DefaultHttpContentType; - var deserializer = _context.SerializationManager.GetDeserializer(new MediaType(contentTypeMediaType)); - if (deserializer == null) - { - throw await CreateNoDeserializerFoundException(request, response, contentTypeMediaType); - } - - try + if (shouldDisposeResponse) { - return await deserializer.Deserialize(response.Content, type); - } - catch (Exception e) - { - if (e is RestClientException) - { - throw; - } - - throw await CreateDeserializationException(request, response, e); + response.Dispose(); } } + } - private async Task CreateDeserializationException(HttpRequestMessage request, - HttpResponseMessage response, Exception e) - { - var errorResponse = response.Content == null ? null : await response.Content.ReadAsStringAsync(); - - if (response.IsSuccessStatusCode || !_context.UseDomainException) - { - return new RestClientException(request.RequestUri, response.StatusCode, errorResponse, e); - } + private async Task Deserialize(HttpRequestMessage request, HttpResponseMessage response, Type type) + { + var contentTypeMediaType = response.Content.Headers.ContentType?.MediaType ?? DefaultHttpContentType; + var deserializer = _context.SerializationManager.GetDeserializer(new MediaType(contentTypeMediaType)) ?? + throw new RestClientException(request.RequestUri, response.StatusCode, + "No deserializer found for " + contentTypeMediaType); - return await _domainExceptionMapper.CreateExceptionAsync(response, errorResponse, e); + try + { + return await deserializer.Deserialize(response.Content, type); } - - private async Task CreateNoDeserializerFoundException(HttpRequestMessage request, - HttpResponseMessage response, - string contentTypeMediaType) + catch (Exception e) { - var exception = (Exception)new RestClientException(request.RequestUri, response.StatusCode, - "No deserializer found for " + contentTypeMediaType); - - if (response.IsSuccessStatusCode || !_context.UseDomainException) + if (e is RestClientException) { - return exception; + throw; } - return await _domainExceptionMapper.CreateExceptionAsync(response, null, exception); + var errorResponse = await response.Content.ReadAsStringAsync(); + throw new RestClientException(request.RequestUri, response.StatusCode, errorResponse, e); } } } \ No newline at end of file diff --git a/Activout.RestClient/Implementation/RestClientBuilder.cs b/Activout.RestClient/Implementation/RestClientBuilder.cs index f9273a7..292a18a 100644 --- a/Activout.RestClient/Implementation/RestClientBuilder.cs +++ b/Activout.RestClient/Implementation/RestClientBuilder.cs @@ -56,6 +56,10 @@ public IRestClientBuilder ContentType(MediaType contentType) public IRestClientBuilder Header(string name, object value, bool isReplace = true) { + if (string.Equals("content-type", name, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException("Use ContentType method to set default content type."); + } _defaultHeaders.AddOrReplaceHeader(name, value, isReplace); return this; } @@ -159,7 +163,7 @@ public T Build() where T : class HttpClient: _httpClient ?? new HttpClient(), TaskConverterFactory: _taskConverterFactory ?? TaskConverter3Factory.Instance, ErrorResponseType: _errorResponseType, - DefaultContentType: _defaultContentType, + DefaultContentType: _defaultContentType ?? throw new InvalidOperationException("DefaultContentType is not set."), ParamConverterManager: _paramConverterManager ?? ParamConverterManager.Instance, DomainExceptionType: _domainExceptionType, DomainExceptionMapperFactory: _domainExceptionMapperFactory ?? diff --git a/Activout.RestClient/Implementation/RestClientContext.cs b/Activout.RestClient/Implementation/RestClientContext.cs index 41d6fb0..743a00f 100644 --- a/Activout.RestClient/Implementation/RestClientContext.cs +++ b/Activout.RestClient/Implementation/RestClientContext.cs @@ -17,7 +17,7 @@ internal record RestClientContext( HttpClient HttpClient, ITaskConverterFactory TaskConverterFactory, Type? ErrorResponseType, - MediaType? DefaultContentType, + MediaType DefaultContentType, IParamConverterManager ParamConverterManager, Type? DomainExceptionType, IDomainExceptionMapperFactory DomainExceptionMapperFactory,