Skip to content

Commit

Permalink
Remove needless checks
Browse files Browse the repository at this point in the history
  • Loading branch information
kou committed Jun 23, 2023
1 parent b9abe3a commit af69c9b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 335 deletions.
93 changes: 10 additions & 83 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -617,26 +617,20 @@ class ExpirationTimeDoGetScenario : public Scenario {
ARROW_ASSIGN_OR_RAISE(
auto info, client->GetFlightInfo(FlightDescriptor::Command("expiration_time")));
std::vector<std::shared_ptr<arrow::Table>> tables;
// First read from all endpoints
for (const auto& endpoint : info->endpoints()) {
ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
tables.push_back(table);
}
// Re-reads only from endpoints that have expiration time
for (const auto& endpoint : info->endpoints()) {
if (endpoint.expiration_time.has_value()) {
ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
tables.push_back(table);
if (tables.size() == 0) {
if (endpoint.expiration_time.has_value()) {
return Status::Invalid("endpoints[0] must not have expiration time");
}
} else {
auto reader = client->DoGet(endpoint.ticket);
if (reader.ok()) {
return Status::Invalid(
"Data that doesn't have expiration time "
"shouldn't be readable multiple times");
if (!endpoint.expiration_time.has_value()) {
return Status::Invalid("endpoints[", tables.size(),
"] must have expiration time");
}
}
ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
tables.push_back(table);
}
ARROW_ASSIGN_OR_RAISE(auto table, ConcatenateTables(tables));

Expand All @@ -645,13 +639,9 @@ class ExpirationTimeDoGetScenario : public Scenario {
ARROW_ASSIGN_OR_RAISE(auto builder,
RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
auto number_builder = builder->GetFieldAs<UInt32Builder>(0);
// First reads
ARROW_RETURN_NOT_OK(number_builder->Append(0));
ARROW_RETURN_NOT_OK(number_builder->Append(1));
ARROW_RETURN_NOT_OK(number_builder->Append(2));
// Re-reads only from endpoints that have expiration time
ARROW_RETURN_NOT_OK(number_builder->Append(1));
ARROW_RETURN_NOT_OK(number_builder->Append(2));
ARROW_ASSIGN_OR_RAISE(auto expected_record_batch, builder->Flush());
std::vector<std::shared_ptr<RecordBatch>> expected_record_batches{
expected_record_batch};
Expand Down Expand Up @@ -779,24 +769,12 @@ class ExpirationTimeRefreshFlightEndpointScenario : public Scenario {
Status RunClient(std::unique_ptr<FlightClient> client) override {
ARROW_ASSIGN_OR_RAISE(auto info,
client->GetFlightInfo(FlightDescriptor::Command("expiration")));
std::vector<std::shared_ptr<arrow::Table>> tables;
// First read from all endpoints
for (const auto& endpoint : info->endpoints()) {
ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
tables.push_back(table);
}
// Refresh all endpoints that have expiration time
std::vector<FlightEndpoint> refreshed_endpoints;
Timestamp max_expiration_time;
for (const auto& endpoint : info->endpoints()) {
if (!endpoint.expiration_time.has_value()) {
continue;
}
const auto& expiration_time = endpoint.expiration_time.value();
if (expiration_time > max_expiration_time) {
max_expiration_time = expiration_time;
}
ARROW_ASSIGN_OR_RAISE(auto refreshed_endpoint,
client->RefreshFlightEndpoint(endpoint));
if (!refreshed_endpoint.expiration_time.has_value()) {
Expand All @@ -809,57 +787,6 @@ class ExpirationTimeRefreshFlightEndpointScenario : public Scenario {
"Original:\n", endpoint.ToString(), "Refreshed:\n",
refreshed_endpoint.ToString());
}
refreshed_endpoints.push_back(std::move(refreshed_endpoint));
}
// Expire all not refreshed endpoints
{
std::vector<Timestamp> refreshed_expiration_times;
for (const auto& endpoint : refreshed_endpoints) {
refreshed_expiration_times.push_back(endpoint.expiration_time.value());
}
std::sort(refreshed_expiration_times.begin(), refreshed_expiration_times.end());
if (refreshed_expiration_times[0] < max_expiration_time) {
return Status::Invalid(
"One or more refreshed expiration time "
"are shorter than original expiration time\n",
"Original: ", max_expiration_time.time_since_epoch().count(), "\n",
"Refreshed: ", refreshed_expiration_times[0].time_since_epoch().count(),
"\n");
}
if (max_expiration_time > Timestamp::clock::now()) {
std::this_thread::sleep_for(max_expiration_time - Timestamp::clock::now());
}
}
// Re-reads only from refreshed endpoints
for (const auto& endpoint : refreshed_endpoints) {
ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
tables.push_back(table);
}
ARROW_ASSIGN_OR_RAISE(auto table, ConcatenateTables(tables));

// Build expected table
auto schema = arrow::schema({arrow::field("number", arrow::uint32(), false)});
ARROW_ASSIGN_OR_RAISE(auto builder,
RecordBatchBuilder::Make(schema, arrow::default_memory_pool()));
auto number_builder = builder->GetFieldAs<UInt32Builder>(0);
// First reads
ARROW_RETURN_NOT_OK(number_builder->Append(0));
ARROW_RETURN_NOT_OK(number_builder->Append(1));
ARROW_RETURN_NOT_OK(number_builder->Append(2));
// Re-reads only from refreshed endpoints
ARROW_RETURN_NOT_OK(number_builder->Append(1));
ARROW_RETURN_NOT_OK(number_builder->Append(2));
ARROW_ASSIGN_OR_RAISE(auto expected_record_batch, builder->Flush());
std::vector<std::shared_ptr<RecordBatch>> expected_record_batches{
expected_record_batch};
ARROW_ASSIGN_OR_RAISE(auto expected_table,
Table::FromRecordBatches(expected_record_batches));

// Check read data
if (!table->Equals(*expected_table)) {
return Status::Invalid("Read data isn't expected\n", "Expected:\n",
expected_table->ToString(), "Actual:\n", table->ToString());
}
return Status::OK();
}
Expand Down
170 changes: 12 additions & 158 deletions go/arrow/internal/flight_integration/scenario.go
Original file line number Diff line number Diff line change
Expand Up @@ -949,8 +949,19 @@ func (tester *expirationTimeDoGetScenarioTester) RunClient(addr string, opts ...
}

var recs []arrow.Record
// First read from all endpoints
for _, ep := range info.Endpoint {
if len(recs) == 0 {
if ep.ExpirationTime != nil {
return fmt.Errorf("endpoints[0] must not have " +
"expiration time")
}
} else {
if ep.ExpirationTime == nil {
return fmt.Errorf("endpoints[1] must have " +
"expiration time")
}
}

if len(ep.Location) != 0 {
return fmt.Errorf("expected to receive empty locations to use the original service: %s",
ep.Location)
Expand All @@ -977,39 +988,6 @@ func (tester *expirationTimeDoGetScenarioTester) RunClient(addr string, opts ...
return rdr.Err()
}
}
// Re-reads only from endpoints that have expiration time
for _, ep := range info.Endpoint {
stream, err := client.DoGet(ctx, ep.Ticket)
if err != nil {
return err
}

rdr, err := flight.NewRecordReader(stream)
if ep.ExpirationTime == nil {
if err == nil {
rdr.Release()
return fmt.Errorf("data that doesn't have " +
"expiration time shouldn't be " +
"readable multiple times")
}
continue
}

if err != nil {
return err
}
defer rdr.Release()

for rdr.Next() {
record := rdr.Record()
record.Retain()
defer record.Release()
recs = append(recs, record)
}
if rdr.Err() != nil {
return rdr.Err()
}
}

// Build expected records
mem := memory.DefaultAllocator
Expand All @@ -1023,8 +1001,6 @@ func (tester *expirationTimeDoGetScenarioTester) RunClient(addr string, opts ...
`[{"number": 0}]`,
`[{"number": 1}]`,
`[{"number": 2}]`,
`[{"number": 1}]`,
`[{"number": 2}]`,
})
defer expectedTable.Release()

Expand Down Expand Up @@ -1190,46 +1166,12 @@ func (tester *expirationTimeRefreshFlightEndpointScenarioTester) RunClient(addr
return err
}

var recs []arrow.Record
// First read from all endpoints
for _, ep := range info.Endpoint {
if len(ep.Location) != 0 {
return fmt.Errorf("expected to receive empty locations to use the original service: %s",
ep.Location)
}

stream, err := client.DoGet(ctx, ep.Ticket)
if err != nil {
return err
}

rdr, err := flight.NewRecordReader(stream)
if err != nil {
return err
}
defer rdr.Release()

for rdr.Next() {
record := rdr.Record()
record.Retain()
defer record.Release()
recs = append(recs, record)
}
if rdr.Err() != nil {
return rdr.Err()
}
}
// Refresh all endpoints that have expiration time
var refreshedEndpoints []*flight.FlightEndpoint
maxExpirationTime := time.Now()
for _, ep := range info.Endpoint {
if ep.ExpirationTime == nil {
continue
}
expirationTime := ep.ExpirationTime.AsTime()
if expirationTime.Sub(maxExpirationTime) > 0 {
maxExpirationTime = expirationTime
}
refreshedEndpoint, err := client.RefreshFlightEndpoint(ctx, ep)
if err != nil {
return err
Expand All @@ -1244,94 +1186,6 @@ func (tester *expirationTimeRefreshFlightEndpointScenarioTester) RunClient(addr
"Original: %s\nRefreshed: %s",
ep, refreshedEndpoint)
}
refreshedEndpoints = append(refreshedEndpoints, refreshedEndpoint)
}
// Expire all not refreshed endpoints
{
var refreshedExpirationTimes []time.Time
for _, ep := range refreshedEndpoints {
refreshedExpirationTimes = append(refreshedExpirationTimes,
ep.ExpirationTime.AsTime())
}
sort.Slice(refreshedExpirationTimes,
func(i int, j int) bool {
a := refreshedExpirationTimes[i]
b := refreshedExpirationTimes[j]
return a.Sub(b) > 0
})
if refreshedExpirationTimes[0].Sub(maxExpirationTime) <= 0 {
return fmt.Errorf(
"one or more refreshed expiration time "+
"are shorter than original expiration time\n"+
"Original: %s\n"+
"Refreshed: %s",
maxExpirationTime,
refreshedExpirationTimes[0])
}
duration := time.Until(maxExpirationTime)
if duration > 0 {
time.Sleep(duration)
}
}
// Re-reads only from refreshed endpoints
for _, ep := range refreshedEndpoints {
stream, err := client.DoGet(ctx, ep.Ticket)
if err != nil {
return err
}

rdr, err := flight.NewRecordReader(stream)
if err != nil {
return err
}
defer rdr.Release()

for rdr.Next() {
record := rdr.Record()
record.Retain()
defer record.Release()
recs = append(recs, record)
}
if rdr.Err() != nil {
return rdr.Err()
}
}

// Build expected records
mem := memory.DefaultAllocator
schema := arrow.NewSchema(
[]arrow.Field{
{Name: "number", Type: arrow.PrimitiveTypes.Uint32},
},
nil,
)
expectedTable, _ := array.TableFromJSON(mem, schema, []string{
`[{"number": 0}]`,
`[{"number": 1}]`,
`[{"number": 2}]`,
`[{"number": 1}]`,
`[{"number": 2}]`,
})
defer expectedTable.Release()

table := array.NewTableFromRecords(schema, recs)
defer table.Release()
if !array.TableEqual(table, expectedTable) {
return fmt.Errorf("read data isn't expected\n"+
"Expected:\n"+
"%s\n"+
"numRows: %d\n"+
"numCols: %d\n"+
"Actual:\n"+
"%s\n"+
"numRows: %d\n"+
"numCols: %d",
expectedTable.Schema(),
expectedTable.NumRows(),
expectedTable.NumCols(),
table.Schema(),
table.NumRows(),
table.NumCols())
}

return nil
Expand Down
Loading

0 comments on commit af69c9b

Please sign in to comment.