diff --git a/spark/sql/types/row.go b/spark/sql/types/row.go index b73ef57..30f4046 100644 --- a/spark/sql/types/row.go +++ b/spark/sql/types/row.go @@ -17,9 +17,15 @@ package types type Row interface { + // At returns field's value at the given index within a [Row]. + // It returns nil for invalid indices. At(index int) any + // Value returns field's value of the given column's name within a [Row]. + // It returns nil for invalid column's name. Value(name string) any + // Values returns values of all fields within a [Row] as a slice of any. Values() []any + // Len returns the number of fields within a [Row] Len() int FieldNames() []string } @@ -30,11 +36,18 @@ type rowImpl struct { } func (r *rowImpl) At(index int) any { + if index < 0 || index > len(r.values) { + return nil + } return r.values[index] } func (r *rowImpl) Value(name string) any { - return r.values[r.offsets[name]] + idx, ok := r.offsets[name] + if !ok { + return nil + } + return r.values[idx] } func (r *rowImpl) Values() []any { @@ -46,7 +59,7 @@ func (r *rowImpl) Len() int { } func (r *rowImpl) FieldNames() []string { - names := make([]string, len(r.offsets)) + var names []string for name := range r.offsets { names = append(names, name) } diff --git a/spark/sql/types/row_test.go b/spark/sql/types/row_test.go new file mode 100644 index 0000000..791a56a --- /dev/null +++ b/spark/sql/types/row_test.go @@ -0,0 +1,110 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// rowImplTest is a read-only sample [Row] to be used in all tests. +var rowImplSample rowImpl = rowImpl{ + values: []any{1, 2, 3, 4, 5}, + offsets: map[string]int{ + "one": 0, + "two": 1, + "three": 2, + "four": 3, + "five": 4, + }, +} + +func TestRowImpl_At(t *testing.T) { + testCases := []struct { + name string + input int + exp any + }{ + { + name: "index within range", + input: 2, + exp: 3, + }, + { + name: "index out of range", + input: 6, + exp: nil, + }, + { + name: "negative index", + input: -1, + exp: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + act := rowImplSample.At(tc.input) + require.Equal(t, tc.exp, act) + }) + } +} + +func TestRowImpl_Value(t *testing.T) { + testCases := []struct { + name string + input string + exp any + }{ + { + name: "valid field name", + input: "two", + exp: 2, + }, + { + name: "invalid field name", + input: "six", + exp: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + act := rowImplSample.Value(tc.input) + require.Equal(t, tc.exp, act) + }) + } +} + +func TestRowImpl_Values(t *testing.T) { + exp := []any{1, 2, 3, 4, 5} + act := rowImplSample.Values() + require.Equal(t, exp, act) +} + +func TestRowImpl_Len(t *testing.T) { + exp := 5 + act := rowImplSample.Len() + require.Equal(t, exp, act) +} + +func TestRowImpl_FieldNames(t *testing.T) { + exp := []string{"one", "two", "three", "four", "five"} + act := rowImplSample.FieldNames() + require.ElementsMatch(t, exp, act) +}