Skip to content

Commit

Permalink
Add CancelFlightInfoRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
kou committed Jun 24, 2023
1 parent f82ea8d commit 2ef69be
Show file tree
Hide file tree
Showing 20 changed files with 450 additions and 280 deletions.
4 changes: 2 additions & 2 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,8 @@ Status FlightClient::DoAction(const FlightCallOptions& options, const Action& ac
}

arrow::Result<CancelFlightInfoResult> FlightClient::CancelFlightInfo(
const FlightCallOptions& options, const FlightInfo& info) {
ARROW_ASSIGN_OR_RAISE(auto body, info.SerializeToString());
const FlightCallOptions& options, const CancelFlightInfoRequest& request) {
ARROW_ASSIGN_OR_RAISE(auto body, request.SerializeToString());
Action action{ActionType::kCancelFlightInfo.type, Buffer::FromString(body)};
ARROW_ASSIGN_OR_RAISE(auto stream, DoAction(options, action));
ARROW_ASSIGN_OR_RAISE(auto result, stream->Next());
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,12 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// CancelFlightInfoResult
///
/// \param[in] options Per-RPC options
/// \param[in] info The FlightInfo to be cancelled
/// \param[in] request The CancelFlightInfoRequest
/// \return Arrow result with a CancelFlightInfoResult
arrow::Result<CancelFlightInfoResult> CancelFlightInfo(const FlightCallOptions& options,
const FlightInfo& info);
arrow::Result<CancelFlightInfoResult> CancelFlightInfo(const FlightInfo& info) {
return CancelFlightInfo({}, info);
const CancelFlightInfoRequest& request);
arrow::Result<CancelFlightInfoResult> CancelFlightInfo(const CancelFlightInfoRequest& request) {
return CancelFlightInfo({}, request);
}

/// \brief Perform the CloseFlightInfo action
Expand Down
17 changes: 10 additions & 7 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,10 @@ class ExpirationTimeServer : public FlightServerBase {
std::unique_ptr<ResultStream>* result_stream) override {
std::vector<Result> results;
if (action.type == ActionType::kCancelFlightInfo.type) {
ARROW_ASSIGN_OR_RAISE(auto info,
FlightInfo::Deserialize(std::string_view(*action.body)));
ARROW_ASSIGN_OR_RAISE(auto request,
CancelFlightInfoRequest::Deserialize(std::string_view(*action.body)));
auto cancel_status = CancelStatus::kUnspecified;
for (const auto& endpoint : info->endpoints()) {
for (const auto& endpoint : request.info->endpoints()) {
auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
if (index_result.ok()) {
auto index = *index_result;
Expand Down Expand Up @@ -710,11 +710,13 @@ class ExpirationTimeCancelFlightInfoScenario : public Scenario {
Status RunClient(std::unique_ptr<FlightClient> client) override {
ARROW_ASSIGN_OR_RAISE(auto info,
client->GetFlightInfo(FlightDescriptor::Command("expiration")));
ARROW_ASSIGN_OR_RAISE(auto cancel_result, client->CancelFlightInfo(*info));
CancelFlightInfoRequest request{std::move(info)};
ARROW_ASSIGN_OR_RAISE(auto cancel_result, client->CancelFlightInfo(request));
if (cancel_result.status != CancelStatus::kCancelled) {
return Status::Invalid("CancelFlightInfo must return CANCEL_STATUS_CANCELLED: ",
cancel_result.ToString());
}
info = std::move(request.info);
for (const auto& endpoint : info->endpoints()) {
auto reader = client->DoGet(endpoint.ticket);
if (reader.ok()) {
Expand Down Expand Up @@ -1326,10 +1328,11 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
}

arrow::Result<CancelFlightInfoResult> CancelFlightInfo(
const ServerCallContext& context, const FlightInfo& info) override {
ARROW_RETURN_NOT_OK(AssertEq<size_t>(1, info.endpoints().size(),
const ServerCallContext& context, const CancelFlightInfoRequest& request) override {
const auto& info = request.info;
ARROW_RETURN_NOT_OK(AssertEq<size_t>(1, info->endpoints().size(),
"Expected 1 endpoint for CancelFlightInfo"));
const FlightEndpoint& endpoint = info.endpoints()[0];
const auto& endpoint = info->endpoints()[0];
ARROW_ASSIGN_OR_RAISE(auto ticket,
sql::StatementQueryTicket::Deserialize(endpoint.ticket.ticket));
ARROW_RETURN_NOT_OK(AssertEq<std::string>("PLAN HANDLE", ticket.statement_handle,
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/flight/serialization_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info) {
return Status::OK();
}

// CancelFlightInfoRequest

Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request) {
FlightInfo::Data data;
RETURN_NOT_OK(FromProto(pb_request.info(), &data));
request->info = std::make_unique<FlightInfo>(std::move(data));
return Status::OK();
}

Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* basic_auth) {
basic_auth->password = pb_basic_auth.password();
basic_auth->username = pb_basic_auth.username();
Expand Down Expand Up @@ -287,6 +297,12 @@ Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info) {
return Status::OK();
}

Status ToProto(const CancelFlightInfoRequest& request,
pb::CancelFlightInfoRequest* pb_request) {
RETURN_NOT_OK(ToProto(*request.info, pb_request->mutable_info()));
return Status::OK();
}

Status ToProto(const SchemaResult& result, pb::SchemaResult* pb_result) {
pb_result->set_schema(result.serialized_schema());
return Status::OK();
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/serialization_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ Status FromProto(const pb::FlightEndpoint& pb_endpoint, FlightEndpoint* endpoint
Status FromProto(const pb::RenewFlightEndpointRequest& pb_request,
RenewFlightEndpointRequest* request);
Status FromProto(const pb::FlightInfo& pb_info, FlightInfo::Data* info);
Status FromProto(const pb::CancelFlightInfoRequest& pb_request,
CancelFlightInfoRequest* request);
Status FromProto(const pb::SchemaResult& pb_result, std::string* result);
Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info);

Expand All @@ -68,6 +70,8 @@ Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint);
Status ToProto(const RenewFlightEndpointRequest& request,
pb::RenewFlightEndpointRequest* pb_request);
Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info);
Status ToProto(const CancelFlightInfoRequest& request,
pb::CancelFlightInfoRequest* pb_request);
Status ToProto(const ActionType& type, pb::ActionType* pb_type);
Status ToProto(const Action& action, pb::Action* pb_action);
Status ToProto(const Result& result, pb::Result* pb_result);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/flight/sql/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
/// \param[in] info The FlightInfo to cancel.
/// \return Arrow result with a canceled result.
::arrow::Result<CancelFlightInfoResult> CancelFlightInfo(
const FlightCallOptions& options, const FlightInfo& info) {
return impl_->CancelFlightInfo(options, info);
const FlightCallOptions& options, const CancelFlightInfoRequest& request) {
return impl_->CancelFlightInfo(options, request);
}

/// \brief Explicitly cancel a query.
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/arrow/flight/sql/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ Status FlightSqlServerBase::DoAction(const ServerCallContext& context,
std::vector<Result> results;
if (action.type == ActionType::kCancelFlightInfo.type) {
std::string_view body(*action.body);
ARROW_ASSIGN_OR_RAISE(auto info, FlightInfo::Deserialize(body));
ARROW_ASSIGN_OR_RAISE(auto result, CancelFlightInfo(context, *info));
ARROW_ASSIGN_OR_RAISE(auto request, CancelFlightInfoRequest::Deserialize(body));
ARROW_ASSIGN_OR_RAISE(auto result, CancelFlightInfo(context, request));
ARROW_ASSIGN_OR_RAISE(auto packed_result, PackActionResult(std::move(result)));

results.push_back(std::move(packed_result));
Expand Down Expand Up @@ -1085,13 +1085,15 @@ arrow::Result<ActionBeginTransactionResult> FlightSqlServerBase::BeginTransactio
}

arrow::Result<CancelFlightInfoResult> FlightSqlServerBase::CancelFlightInfo(
const ServerCallContext& context, const FlightInfo& info) {
const ServerCallContext& context, const CancelFlightInfoRequest& request) {
return Status::NotImplemented("CancelFlightInfo not implemented");
}

arrow::Result<CancelResult> FlightSqlServerBase::CancelQuery(
const ServerCallContext& context, const ActionCancelQueryRequest& request) {
ARROW_ASSIGN_OR_RAISE(auto result, CancelFlightInfo(context, *request.info));
CancelFlightInfoRequest cancel_flight_info_request;
cancel_flight_info_request.info = std::make_unique<FlightInfo>(*request.info);
ARROW_ASSIGN_OR_RAISE(auto result, CancelFlightInfo(context, cancel_flight_info_request));
return static_cast<CancelResult>(result.status);
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/flight/sql/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,10 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase {

/// \brief Attempt to explicitly cancel a FlightInfo.
/// \param[in] context The call context.
/// \param[in] info The FlightInfo to cancel.
/// \param[in] request The CancelFlightInfoRequest.
/// \return The cancellation result.
virtual arrow::Result<CancelFlightInfoResult> CancelFlightInfo(
const ServerCallContext& context, const FlightInfo& info);
const ServerCallContext& context, const CancelFlightInfoRequest& request);

/// \brief Attempt to explicitly cancel a query.
///
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/flight/sql/server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) {
TEST_F(TestFlightSqlServer, CancelFlightInfo) {
// Not supported
ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo({}, {}));
ASSERT_RAISES(NotImplemented, sql_client->CancelFlightInfo({}, *flight_info));
CancelFlightInfoRequest request{std::move(flight_info)};
ASSERT_RAISES(NotImplemented, sql_client->CancelFlightInfo({}, request));
}

TEST_F(TestFlightSqlServer, CancelQuery) {
Expand Down
42 changes: 40 additions & 2 deletions cpp/src/arrow/flight/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,44 @@ bool FlightInfo::Equals(const FlightInfo& other) const {
data_.ordered == other.data_.ordered;
}

std::string CancelFlightInfoRequest::ToString() const {
std::stringstream ss;
ss << "<CancelFlightInfoRequest info=" << info->ToString() << ">";
return ss.str();
}

bool CancelFlightInfoRequest::Equals(const CancelFlightInfoRequest& other) const {
return info == other.info;
}

arrow::Result<std::string> CancelFlightInfoRequest::SerializeToString() const {
pb::CancelFlightInfoRequest pb_request;
RETURN_NOT_OK(internal::ToProto(*this, &pb_request));

std::string out;
if (!pb_request.SerializeToString(&out)) {
return Status::IOError("Serialized CancelFlightInfoRequest exceeded 2 GiB limit");
}
return out;
}

arrow::Result<CancelFlightInfoRequest> CancelFlightInfoRequest::Deserialize(
std::string_view serialized) {
pb::CancelFlightInfoRequest pb_request;
if (serialized.size() > static_cast<size_t>(std::numeric_limits<int>::max())) {
return Status::Invalid(
"Serialized CancelFlightInfoRequest size should not exceed 2 GiB");
}
google::protobuf::io::ArrayInputStream input(serialized.data(),
static_cast<int>(serialized.size()));
if (!pb_request.ParseFromZeroCopyStream(&input)) {
return Status::Invalid("Not a valid CancelFlightInfoRequest");
}
CancelFlightInfoRequest out;
RETURN_NOT_OK(internal::FromProto(pb_request, &out));
return out;
}

Location::Location() { uri_ = std::make_shared<arrow::internal::Uri>(); }

Status FlightListing::Next(std::unique_ptr<FlightInfo>* info) {
Expand Down Expand Up @@ -556,8 +594,8 @@ std::string ActionType::ToString() const {
const ActionType ActionType::kCancelFlightInfo =
ActionType{"CancelFlightInfo",
"Explicitly cancel a running FlightInfo.\n"
"Request Message: FlightInfo to be canceled\n"
"Response Message: ActionCancelFlightInfoResult"};
"Request Message: CancelFlightInfoRequest\n"
"Response Message: CancelFlightInfoResult"};
const ActionType ActionType::kCloseFlightInfo =
ActionType{"CloseFlightInfo",
"Close the given FlightInfo explicitly.\n"
Expand Down
24 changes: 24 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,30 @@ class ARROW_FLIGHT_EXPORT FlightInfo {
mutable bool reconstructed_schema_;
};

/// \brief The request of the CancelFlightInfoRequest action.
struct ARROW_FLIGHT_EXPORT CancelFlightInfoRequest {
std::unique_ptr<FlightInfo> info;

std::string ToString() const;
bool Equals(const CancelFlightInfoRequest& other) const;

friend bool operator==(const CancelFlightInfoRequest& left,
const CancelFlightInfoRequest& right) {
return left.Equals(right);
}
friend bool operator!=(const CancelFlightInfoRequest& left,
const CancelFlightInfoRequest& right) {
return !(left == right);
}

/// \brief Serialize this message to its wire-format representation.
arrow::Result<std::string> SerializeToString() const;

/// \brief Deserialize this message from its wire-format representation.
static arrow::Result<CancelFlightInfoRequest> Deserialize(
std::string_view serialized);
};

/// \brief An iterator to FlightInfo instances returned by ListFlights.
class ARROW_FLIGHT_EXPORT FlightListing {
public:
Expand Down
9 changes: 9 additions & 0 deletions format/Flight.proto
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,15 @@ message Action {
bytes body = 2;
}

/*
* The request of the CancelFlightInfo action.
*
* The request should be stored in Action.body.
*/
message CancelFlightInfoRequest {
FlightInfo info = 1;
}

/*
* The request of the RenewFlightEndpoint action.
*
Expand Down
6 changes: 3 additions & 3 deletions go/arrow/flight/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type Client interface {
// in order to use the Handshake endpoints of the service.
Authenticate(context.Context, ...grpc.CallOption) error
AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
CancelFlightInfo(ctx context.Context, info *FlightInfo, opts ...grpc.CallOption) (CancelFlightInfoResult, error)
CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (CancelFlightInfoResult, error)
Close() error
CloseFlightInfo(ctx context.Context, info *FlightInfo, opts ...grpc.CallOption) error
RenewFlightEndpoint(ctx context.Context, request *RenewFlightEndpointRequest, opts ...grpc.CallOption) (*FlightEndpoint, error)
Expand Down Expand Up @@ -365,10 +365,10 @@ func ReadUntilEOF(stream FlightService_DoActionClient) error {
}
}

func (c *client) CancelFlightInfo(ctx context.Context, info *FlightInfo, opts ...grpc.CallOption) (result CancelFlightInfoResult, err error) {
func (c *client) CancelFlightInfo(ctx context.Context, request *CancelFlightInfoRequest, opts ...grpc.CallOption) (result CancelFlightInfoResult, err error) {
var action flight.Action
action.Type = CancelFlightInfoActionType
action.Body, err = proto.Marshal(info)
action.Body, err = proto.Marshal(request)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,8 @@ func (c *Client) CancelQuery(ctx context.Context, info *flight.FlightInfo, opts
return
}

func (c *Client) CancelFlightInfo(ctx context.Context, info *flight.FlightInfo, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
return c.Client.CancelFlightInfo(ctx, info, opts...)
func (c *Client) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
return c.Client.CancelFlightInfo(ctx, request, opts...)
}

func (c *Client) CloseFlightInfo(ctx context.Context, info *flight.FlightInfo, opts ...grpc.CallOption) error {
Expand Down
9 changes: 5 additions & 4 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ func (m *FlightServiceClientMock) AuthenticateBasicToken(_ context.Context, user
return args.Get(0).(context.Context), args.Error(1)
}

func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, info *flight.FlightInfo, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
args := m.Called(info, opts)
func (m *FlightServiceClientMock) CancelFlightInfo(ctx context.Context, request *flight.CancelFlightInfoRequest, opts ...grpc.CallOption) (flight.CancelFlightInfoResult, error) {
args := m.Called(request, opts)
return args.Get(0).(flight.CancelFlightInfoResult), args.Error(1)
}

Expand Down Expand Up @@ -624,11 +624,12 @@ func (s *FlightSqlClientSuite) TestCancelFlightInfo() {
info, err := s.sqlClient.Execute(context.Background(), query, s.callOpts...)
s.NoError(err)
s.Equal(&emptyFlightInfo, info)
request := flight.CancelFlightInfoRequest{Info: info}
mockedCancelResult := flight.CancelFlightInfoResult{
Status: flight.CancelStatusCancelled,
}
s.mockClient.On("CancelFlightInfo", info, s.callOpts).Return(mockedCancelResult, nil)
cancelResult, err := s.sqlClient.CancelFlightInfo(context.TODO(), info, s.callOpts...)
s.mockClient.On("CancelFlightInfo", &request, s.callOpts).Return(mockedCancelResult, nil)
cancelResult, err := s.sqlClient.CancelFlightInfo(context.TODO(), &request, s.callOpts...)
s.NoError(err)
s.Equal(mockedCancelResult, cancelResult)
}
Expand Down
Loading

0 comments on commit 2ef69be

Please sign in to comment.