|
| 1 | +package bigslice |
| 2 | + |
| 3 | +import ( |
| 4 | + "fmt" |
| 5 | + "reflect" |
| 6 | + |
| 7 | + "github.com/grailbio/base/must" |
| 8 | + "github.com/grailbio/bigslice/slicefunc" |
| 9 | + "github.com/grailbio/bigslice/sliceio" |
| 10 | + "github.com/grailbio/bigslice/typecheck" |
| 11 | +) |
| 12 | + |
| 13 | +func PushReader(nshard int, sinkRead interface{}, prags ...Pragma) Slice { |
| 14 | + fn, ok := slicefunc.Of(sinkRead) |
| 15 | + if !ok || fn.In.NumOut() != 2 || fn.In.Out(0).Kind() != reflect.Int { |
| 16 | + typecheck.Panicf(1, "pushreader: invalid reader function type %T", sinkRead) |
| 17 | + } |
| 18 | + |
| 19 | + var ( |
| 20 | + sinkType = fn.In.Out(1) |
| 21 | + errorType = reflect.TypeOf((*error)(nil)).Elem() |
| 22 | + errorNilValue = reflect.Zero(errorType) |
| 23 | + ) |
| 24 | + |
| 25 | + type state struct { |
| 26 | + sunkC chan []reflect.Value |
| 27 | + err error |
| 28 | + } |
| 29 | + readerFuncImpl := func(args []reflect.Value) []reflect.Value { |
| 30 | + state := args[1].Interface().(*state) |
| 31 | + if state.sunkC == nil { |
| 32 | + state.sunkC = make(chan []reflect.Value, defaultChunksize) |
| 33 | + sinkImpl := func(args []reflect.Value) []reflect.Value { |
| 34 | + state.sunkC <- args |
| 35 | + return nil |
| 36 | + } |
| 37 | + sinkFunc := reflect.MakeFunc(sinkType, sinkImpl) |
| 38 | + go func() { |
| 39 | + defer close(state.sunkC) |
| 40 | + defer func() { |
| 41 | + if p := recover(); p != nil { |
| 42 | + state.err = fmt.Errorf("pushreader: panic from read: %v", p) |
| 43 | + } |
| 44 | + }() |
| 45 | + outs := reflect.ValueOf(sinkRead).Call([]reflect.Value{args[0], sinkFunc}) |
| 46 | + if errI := outs[0].Interface(); errI != nil { |
| 47 | + state.err = errI.(error) |
| 48 | + } |
| 49 | + }() |
| 50 | + } |
| 51 | + |
| 52 | + var rows int |
| 53 | + loop: |
| 54 | + for rows < args[2].Len() { |
| 55 | + select { |
| 56 | + case row := <-state.sunkC: |
| 57 | + if row == nil { |
| 58 | + state.err = sliceio.EOF |
| 59 | + break loop |
| 60 | + } |
| 61 | + must.True(len(row) == len(args[2:]), "%v, %v", len(row), len(args[2:])) |
| 62 | + for c := range row { |
| 63 | + args[2+c].Index(rows).Set(row[c]) |
| 64 | + } |
| 65 | + rows++ |
| 66 | + } |
| 67 | + } |
| 68 | + errValue := errorNilValue |
| 69 | + if state.err != nil { |
| 70 | + errValue = reflect.ValueOf(state.err) |
| 71 | + } |
| 72 | + return []reflect.Value{reflect.ValueOf(rows), errValue} |
| 73 | + } |
| 74 | + readerFuncArgTypes := []reflect.Type{reflect.TypeOf(int(0)), reflect.TypeOf(&state{})} |
| 75 | + for i := 0; i < sinkType.NumIn(); i++ { |
| 76 | + readerFuncArgTypes = append(readerFuncArgTypes, reflect.SliceOf(sinkType.In(i))) |
| 77 | + } |
| 78 | + readerFuncType := reflect.FuncOf(readerFuncArgTypes, |
| 79 | + []reflect.Type{reflect.TypeOf(int(0)), errorType}, false) |
| 80 | + readerFunc := reflect.MakeFunc(readerFuncType, readerFuncImpl) |
| 81 | + |
| 82 | + return ReaderFunc(nshard, readerFunc.Interface(), prags...) |
| 83 | +} |
0 commit comments