Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions spannerlib/grpc-server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,38 +127,39 @@ func (s *spannerLibServer) ExecuteStreaming(request *pb.ExecuteRequest, stream g
if err != nil {
return err
}
defer func() { _ = api.CloseRows(context.Background(), request.Connection.Pool.Id, request.Connection.Id, id) }()
rows := &pb.Rows{Connection: request.Connection, Id: id}
metadata, err := api.Metadata(queryContext, request.Connection.Pool.Id, request.Connection.Id, id)
return s.streamRows(queryContext, rows, stream)
}

func (s *spannerLibServer) streamRows(queryContext context.Context, rows *pb.Rows, stream grpc.ServerStreamingServer[pb.RowData]) error {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving this logic into a separate function makes it easier to re-use. It will be re-used in the bi-directional gRPC stream implementation in a follow-up PR.

defer func() { _ = api.CloseRows(context.Background(), rows.Connection.Pool.Id, rows.Connection.Id, rows.Id) }()
metadata, err := api.Metadata(queryContext, rows.Connection.Pool.Id, rows.Connection.Id, rows.Id)
if err != nil {
return err
}

first := true
for {
if queryContext.Err() != nil {
return queryContext.Err()
}
if row, err := api.Next(queryContext, request.Connection.Pool.Id, request.Connection.Id, id); err != nil {
if row, err := api.Next(queryContext, rows.Connection.Pool.Id, rows.Connection.Id, rows.Id); err != nil {
return err
} else {
if row == nil {
stats, err := api.ResultSetStats(queryContext, request.Connection.Pool.Id, request.Connection.Id, id)
stats, err := api.ResultSetStats(queryContext, rows.Connection.Pool.Id, rows.Connection.Id, rows.Id)
if err != nil {
return err
}
nextMetadata, err := api.NextResultSet(queryContext, request.Connection.Pool.Id, request.Connection.Id, id)
if err != nil {
return err
}
res := &pb.RowData{Rows: rows, Stats: stats, HasMoreResults: nextMetadata != nil}
nextMetadata, nextResultSetErr := api.NextResultSet(queryContext, rows.Connection.Pool.Id, rows.Connection.Id, rows.Id)
res := &pb.RowData{Rows: rows, Stats: stats, HasMoreResults: nextMetadata != nil || nextResultSetErr != nil}
if first {
res.Metadata = metadata
first = false
}
if err := stream.Send(res); err != nil {
return err
}
if nextResultSetErr != nil {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This ensures that calling NextResultSet when using the streaming gRPC API returns an error if the second (or third, ...) SQL statement in a multi-statement SQL string fails.

return nextResultSetErr
}
if res.HasMoreResults {
metadata = nextMetadata
first = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Google.Api.Gax;
using Google.Cloud.Spanner.V1;
using Google.Cloud.SpannerLib.V1;
using Google.Protobuf.WellKnownTypes;
Expand All @@ -27,8 +28,14 @@

namespace Google.Cloud.SpannerLib.Grpc;

public class GrpcLibSpanner : ISpannerLib
public sealed class GrpcLibSpanner : ISpannerLib
{

/// <summary>
/// Creates a GrpcChannel that uses a Unix domain socket with the given file name.
/// </summary>
/// <param name="fileName">The file to use for communication</param>
/// <returns>A GrpcChannel over a Unix domain socket with the given file name</returns>
public static GrpcChannel ForUnixSocket(string fileName)
{
var endpoint = new UnixDomainSocketEndPoint(fileName);
Expand All @@ -50,6 +57,12 @@ public static GrpcChannel ForUnixSocket(string fileName)
});
}

/// <summary>
/// Creates a GrpcChannel that connects to the given IP address or host name. The GrpcChannel uses a TCP socket for
/// communication. The communication does not use encryption.
/// </summary>
/// <param name="address">The IP address or host name that the channel should connect to</param>
/// <returns>A GrpcChannel using a TCP socket for communication</returns>
public static GrpcChannel ForTcpSocket(string address)
{
return GrpcChannel.ForAddress($"http://{address}", new GrpcChannelOptions
Expand All @@ -63,22 +76,20 @@ public static GrpcChannel ForTcpSocket(string address)
}

private readonly Server _server;
private readonly V1.SpannerLib.SpannerLibClient _client;
private readonly GrpcChannel _channel;
Comment on lines -66 to -67
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced this with an implementation that always uses a (very simple) channel pool.

private readonly V1.SpannerLib.SpannerLibClient[] _clients;
private readonly GrpcChannel[] _channels;
private readonly bool _useStreamingRows;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this option, as there is no good reason why we would not use the streaming implementation.

private bool _disposed;

private V1.SpannerLib.SpannerLibClient Client => _clients[Random.Shared.Next(_clients.Length)];

public GrpcLibSpanner(bool useStreamingRows = true, Server.AddressType addressType = Server.AddressType.UnixDomainSocket)
public GrpcLibSpanner(
int numChannels = 4,
Server.AddressType addressType = Server.AddressType.UnixDomainSocket)
{
GaxPreconditions.CheckArgument(numChannels > 0, nameof(numChannels), "numChannels must be > 0");
_server = new Server();
var file = _server.Start(addressType: addressType);
_channel = addressType == Server.AddressType.Tcp ? ForTcpSocket(file) : ForUnixSocket(file);
_client = new V1.SpannerLib.SpannerLibClient(_channel);
_useStreamingRows = useStreamingRows;

var numChannels = 1;
_channels = new GrpcChannel[numChannels];
_clients = new V1.SpannerLib.SpannerLibClient[numChannels];
for (var i = 0; i < numChannels; i++)
Expand All @@ -88,21 +99,22 @@ public GrpcLibSpanner(bool useStreamingRows = true, Server.AddressType addressTy
}
}

~GrpcLibSpanner() => Dispose(false);

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
private void Dispose(bool disposing)
{
if (_disposed)
{
return;
}
try
{
_channel.Dispose();
foreach (var channel in _channels)
{
channel.Dispose();
Expand All @@ -129,35 +141,35 @@ T TranslateException<T>(Func<T> f)

public Pool CreatePool(string connectionString)
{
return FromProto(TranslateException(() => _client.CreatePool(new CreatePoolRequest
return FromProto(TranslateException(() => Client.CreatePool(new CreatePoolRequest
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This picks a channel from the 'pool' instead of always just using one fixed channel.

{
ConnectionString = connectionString,
})));
}

public void ClosePool(Pool pool)
{
TranslateException(() => _client.ClosePool(ToProto(pool)));
TranslateException(() => Client.ClosePool(ToProto(pool)));
}

public Connection CreateConnection(Pool pool)
{
return FromProto(pool, TranslateException(() => _client.CreateConnection(new CreateConnectionRequest
return FromProto(pool, TranslateException(() => Client.CreateConnection(new CreateConnectionRequest
{
Pool = ToProto(pool),
})));
}

public void CloseConnection(Connection connection)
{
TranslateException(() => _client.CloseConnection(ToProto(connection)));
TranslateException(() => Client.CloseConnection(ToProto(connection)));
}

public async Task CloseConnectionAsync(Connection connection, CancellationToken cancellationToken = default)
{
try
{
await _client.CloseConnectionAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
await Client.CloseConnectionAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
Expand All @@ -167,7 +179,7 @@ public async Task CloseConnectionAsync(Connection connection, CancellationToken

public CommitResponse? WriteMutations(Connection connection, BatchWriteRequest.Types.MutationGroup mutations)
{
var response = TranslateException(() => _client.WriteMutations(new WriteMutationsRequest
var response = TranslateException(() => Client.WriteMutations(new WriteMutationsRequest
{
Connection = ToProto(connection),
Mutations = mutations,
Expand All @@ -180,7 +192,7 @@ public async Task CloseConnectionAsync(Connection connection, CancellationToken
{
try
{
var response = await _client.WriteMutationsAsync(new WriteMutationsRequest
var response = await Client.WriteMutationsAsync(new WriteMutationsRequest
{
Connection = ToProto(connection),
Mutations = mutations,
Expand All @@ -195,15 +207,7 @@ public async Task CloseConnectionAsync(Connection connection, CancellationToken

public Rows Execute(Connection connection, ExecuteSqlRequest statement)
{
if (_useStreamingRows)
{
return ExecuteStreaming(connection, statement);
}
return FromProto(connection, TranslateException(() => _client.Execute(new ExecuteRequest
{
Connection = ToProto(connection),
ExecuteSqlRequest = statement,
})));
return ExecuteStreaming(connection, statement);
}

private StreamingRows ExecuteStreaming(Connection connection, ExecuteSqlRequest statement)
Expand All @@ -221,16 +225,7 @@ public async Task<Rows> ExecuteAsync(Connection connection, ExecuteSqlRequest st
{
try
{
if (_useStreamingRows)
{
return await ExecuteStreamingAsync(connection, statement, cancellationToken).ConfigureAwait(false);
}
var rows = await _client.ExecuteAsync(new ExecuteRequest
{
Connection = ToProto(connection),
ExecuteSqlRequest = statement,
}, cancellationToken: cancellationToken).ConfigureAwait(false);
return FromProto(connection, rows);
return await ExecuteStreamingAsync(connection, statement, cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
Expand All @@ -251,7 +246,7 @@ private async Task<StreamingRows> ExecuteStreamingAsync(Connection connection, E

public long[] ExecuteBatch(Connection connection, ExecuteBatchDmlRequest statements)
{
var response = TranslateException(() => _client.ExecuteBatch(new ExecuteBatchRequest
var response = TranslateException(() => Client.ExecuteBatch(new ExecuteBatchRequest
{
Connection = ToProto(connection),
ExecuteBatchDmlRequest = statements,
Expand Down Expand Up @@ -279,7 +274,7 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD
{
try
{
var stats = await _client.ExecuteBatchAsync(new ExecuteBatchRequest
var stats = await Client.ExecuteBatchAsync(new ExecuteBatchRequest
{
Connection = ToProto(connection),
ExecuteBatchDmlRequest = statements,
Expand All @@ -299,14 +294,14 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD

public ResultSetMetadata? Metadata(Rows rows)
{
return TranslateException(() => _client.Metadata(ToProto(rows)));
return TranslateException(() => Client.Metadata(ToProto(rows)));
}

public async Task<ResultSetMetadata?> MetadataAsync(Rows rows, CancellationToken cancellationToken = default)
{
try
{
return await _client.MetadataAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
return await Client.MetadataAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
Expand All @@ -316,14 +311,14 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD

public ResultSetMetadata? NextResultSet(Rows rows)
{
return TranslateException(() => _client.NextResultSet(ToProto(rows)));
return TranslateException(() => Client.NextResultSet(ToProto(rows)));
}

public async Task<ResultSetMetadata?> NextResultSetAsync(Rows rows, CancellationToken cancellationToken = default)
{
try
{
return await _client.NextResultSetAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
return await Client.NextResultSetAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
Expand All @@ -333,12 +328,12 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD

public ResultSetStats? Stats(Rows rows)
{
return TranslateException(() => _client.ResultSetStats(ToProto(rows)));
return TranslateException(() => Client.ResultSetStats(ToProto(rows)));
}

public ListValue? Next(Rows rows, int numRows, ISpannerLib.RowEncoding encoding)
{
var row = TranslateException(() =>_client.Next(new NextRequest
var row = TranslateException(() =>Client.Next(new NextRequest
{
Rows = ToProto(rows),
NumRows = numRows,
Expand All @@ -351,7 +346,7 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD
{
try
{
return await _client.NextAsync(new NextRequest
return await Client.NextAsync(new NextRequest
{
Rows = ToProto(rows),
NumRows = numRows,
Expand All @@ -366,14 +361,14 @@ public async Task<long[]> ExecuteBatchAsync(Connection connection, ExecuteBatchD

public void CloseRows(Rows rows)
{
TranslateException(() => _client.CloseRows(ToProto(rows)));
TranslateException(() => Client.CloseRows(ToProto(rows)));
}

public async Task CloseRowsAsync(Rows rows, CancellationToken cancellationToken = default)
{
try
{
await _client.CloseRowsAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
await Client.CloseRowsAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
Expand All @@ -383,7 +378,7 @@ public async Task CloseRowsAsync(Rows rows, CancellationToken cancellationToken

public void BeginTransaction(Connection connection, TransactionOptions transactionOptions)
{
TranslateException(() => _client.BeginTransaction(new BeginTransactionRequest
TranslateException(() => Client.BeginTransaction(new BeginTransactionRequest
{
Connection = ToProto(connection),
TransactionOptions = transactionOptions,
Expand All @@ -392,15 +387,15 @@ public void BeginTransaction(Connection connection, TransactionOptions transacti

public CommitResponse? Commit(Connection connection)
{
var response = TranslateException(() => _client.Commit(ToProto(connection)));
var response = TranslateException(() => Client.Commit(ToProto(connection)));
return response.CommitTimestamp == null ? null : response;
}

public async Task<CommitResponse?> CommitAsync(Connection connection, CancellationToken cancellationToken = default)
{
try
{
var response = await _client.CommitAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
var response = await Client.CommitAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
return response.CommitTimestamp == null ? null : response;
}
catch (RpcException exception)
Expand All @@ -411,30 +406,28 @@ public void BeginTransaction(Connection connection, TransactionOptions transacti

public void Rollback(Connection connection)
{
TranslateException(() => _client.Rollback(ToProto(connection)));
TranslateException(() => Client.Rollback(ToProto(connection)));
}

public async Task RollbackAsync(Connection connection, CancellationToken cancellationToken = default)
{
try
{
await _client.RollbackAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
await Client.RollbackAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false);
}
catch (RpcException exception)
{
throw SpannerException.ToSpannerException(exception);
}
}

Pool FromProto(V1.Pool pool) => new(this, pool.Id);

V1.Pool ToProto(Pool pool) => new() { Id = pool.Id };
private Pool FromProto(V1.Pool pool) => new(this, pool.Id);

Connection FromProto(Pool pool, V1.Connection proto) => new(pool, proto.Id);
private V1.Pool ToProto(Pool pool) => new() { Id = pool.Id };

V1.Connection ToProto(Connection connection) => new() { Id = connection.Id, Pool = ToProto(connection.Pool), };
private Connection FromProto(Pool pool, V1.Connection proto) => new(pool, proto.Id);

Rows FromProto(Connection connection, V1.Rows proto) => new(connection, proto.Id);
private V1.Connection ToProto(Connection connection) => new() { Id = connection.Id, Pool = ToProto(connection.Pool), };

V1.Rows ToProto(Rows rows) => new() { Id = rows.Id, Connection = ToProto(rows.SpannerConnection), };
private V1.Rows ToProto(Rows rows) => new() { Id = rows.Id, Connection = ToProto(rows.SpannerConnection), };
}
Loading
Loading