diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 67b54e7..fef08ec 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -22,7 +22,10 @@ jobs: go-version: 1.21 - name: Libraries - run: sudo apt-get install -y libpcap-dev + run: sudo apt-get install -y libpcap-dev libluajit-5.1-dev + + - name: LuaJIT + run: git clone https://luajit.org/git/luajit-2.0.git && cd luajit-2.0 && make CCOPT="-static -fPIC" BUILDMODE="static" && sudo make install - name: Build run: go build -ldflags "-s -w" -o heplify *.go diff --git a/README.md b/README.md index 6a8f7bb..c3e5887 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,35 @@ Download [heplify.exe](https://github.com/sipcapture/heplify/releases) If you have Go 1.18+ installed, build the latest heplify binary by running `make`. +Now you should install LUA Jit: + +* Compile from sources: + + Install luajit dev libary + + `apt-get install libluajit-5.1-dev` + + or + + `yum install luajit-devel` + + or for macOS + + ```sh + # Assuming brew installs to /usr/local/ + brew install lua@5.1 luajit + ln -s /usr/local/lib/pkgconfig/luajit.pc /usr/local/lib/pkgconfig/luajit-5.1.pc + export PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/ + ``` + + [install](https://golang.org/doc/install) Go 1.11+ + + `go build cmd/heplify/heplify.go` + + + + + You can also build a docker image: ```bash diff --git a/config/config.go b/config/config.go index f41c5cf..d71b04c 100644 --- a/config/config.go +++ b/config/config.go @@ -31,6 +31,8 @@ type Config struct { SendRetries uint KeepAlive uint Version bool + ScriptFile string + ScriptHEPFilter []int SkipVerify bool HEPBufferDebug bool HEPBufferEnable bool diff --git a/decoder/decoder.go b/decoder/decoder.go index 08289cf..5968e37 100644 --- a/decoder/decoder.go +++ b/decoder/decoder.go @@ -4,11 +4,14 @@ import ( "bytes" "container/list" "net" + "reflect" "strconv" "strings" "sync/atomic" "time" + "unsafe" + "github.com/VictoriaMetrics/fastcache" "github.com/segmentio/encoding/json" "github.com/google/gopacket" @@ -24,7 +27,10 @@ import ( "github.com/sipcapture/heplify/protos" ) -var PacketQueue = make(chan *Packet, 20000) +var ( + PacketQueue = make(chan *Packet, 20000) + scriptCache = fastcache.New(32 * 1024 * 1024) +) type CachePayload struct { SrcIP net.IP `json:"src_ip" default:""` @@ -1074,3 +1080,93 @@ func (d *Decoder) SendPingHEPPacket() { PacketQueue <- pkt } + +func stb(s string) []byte { + sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + var res []byte + + bh := (*reflect.SliceHeader)((unsafe.Pointer(&res))) + bh.Data = sh.Data + bh.Len = sh.Len + bh.Cap = sh.Len + return res +} + +// Packet +func (pkt *Packet) GetVersion() uint32 { + if pkt != nil { + return uint32(pkt.Version) + } + return 0 +} + +func (pkt *Packet) GetProtocol() uint32 { + if pkt != nil { + return uint32(pkt.Protocol) + } + return 0 +} + +func (pkt *Packet) GetSrcIP() string { + if pkt != nil { + return pkt.SrcIP.String() + } + return "" +} + +func (pkt *Packet) GetDstIP() string { + if pkt != nil { + return pkt.DstIP.String() + } + return "" +} + +func (pkt *Packet) GetSrcPort() uint16 { + if pkt != nil { + return pkt.SrcPort + } + + return 0 +} + +func (pkt *Packet) GetDstPort() uint16 { + if pkt != nil { + return pkt.DstPort + } + return 0 +} + +func (pkt *Packet) GetTsec() uint32 { + if pkt != nil { + return pkt.Tsec + } + return 0 +} + +func (pkt *Packet) GetTmsec() uint32 { + if pkt != nil { + return pkt.Tmsec + } + return 0 +} + +func (pkt *Packet) GetProtoType() uint32 { + if pkt != nil { + return uint32(pkt.ProtoType) + } + return 0 +} + +func (pkt *Packet) GetPayload() string { + if pkt != nil { + return string(pkt.Payload) + } + return "" +} + +func (pkt *Packet) GetCID() string { + if pkt != nil { + return string(pkt.CID) + } + return "" +} diff --git a/decoder/luaengine.go b/decoder/luaengine.go new file mode 100644 index 0000000..8a68a99 --- /dev/null +++ b/decoder/luaengine.go @@ -0,0 +1,162 @@ +package decoder + +import ( + "fmt" + "net" + "strconv" + + "github.com/negbie/logp" + "github.com/sipcapture/golua/lua" + "github.com/sipcapture/heplify/decoder/luar" +) + +// LuaEngine +type LuaEngine struct { + /* pointer to modify */ + pkt **Packet + functions []string + LuaEngine *lua.State +} + +func (d *LuaEngine) GetHEPProtoType() uint32 { + return (*d.pkt).GetProtoType() +} + +func (d *LuaEngine) GetHEPSrcIP() string { + return (*d.pkt).GetSrcIP() +} + +func (d *LuaEngine) GetHEPSrcPort() uint16 { + return (*d.pkt).GetSrcPort() +} + +func (d *LuaEngine) GetHEPDstIP() string { + return (*d.pkt).GetDstIP() +} + +func (d *LuaEngine) GetHEPDstPort() uint16 { + return (*d.pkt).GetDstPort() +} + +func (d *LuaEngine) GetHEPTimeSeconds() uint32 { + return (*d.pkt).GetTsec() +} + +func (d *LuaEngine) GetHEPTimeUseconds() uint32 { + return (*d.pkt).GetTmsec() +} + +func (d *LuaEngine) GetRawMessage() string { + return (*d.pkt).GetPayload() +} + +func (d *LuaEngine) SetRawMessage(value string) { + if (*d.pkt) == nil { + logp.Err("can't set Raw message if HEP struct is nil, please check for nil in lua script") + return + } + pkt := *d.pkt + pkt.Payload = []byte(value) +} + +func (d *LuaEngine) SetHEPField(field string, value string) { + if (*d.pkt) == nil { + logp.Err("can't set HEP field if HEP struct is nil, please check for nil in lua script") + return + } + pkt := *d.pkt + + switch field { + case "ProtoType": + if i, err := strconv.Atoi(value); err == nil { + pkt.ProtoType = byte(i) + } + case "SrcIP": + pkt.SrcIP = net.ParseIP(value).To4() + case "SrcPort": + if i, err := strconv.Atoi(value); err == nil { + pkt.SrcPort = uint16(i) + } + case "DstIP": + pkt.DstIP = net.ParseIP(value).To4() + case "DstPort": + if i, err := strconv.Atoi(value); err == nil { + pkt.DstPort = uint16(i) + } + + case "CID": + pkt.CID = []byte(value) + + } +} + +func (d *LuaEngine) Logp(level string, message string, data interface{}) { + if level == "ERROR" { + logp.Err("[script] %s: %v", message, data) + } else { + logp.Debug("[script] %s: %v", message, data) + } +} + +func (d *LuaEngine) Close() { + d.LuaEngine.Close() +} + +// NewLuaEngine returns the script engine struct +func NewLuaEngine() (*LuaEngine, error) { + logp.Debug("script", "register Lua engine") + + d := &LuaEngine{} + d.LuaEngine = lua.NewState() + d.LuaEngine.OpenLibs() + + luar.Register(d.LuaEngine, "", luar.Map{ + "GetHEPProtoType": d.GetHEPProtoType, + "GetHEPSrcIP": d.GetHEPSrcIP, + "GetHEPSrcPort": d.GetHEPSrcPort, + "GetHEPDstIP": d.GetHEPDstIP, + "GetHEPDstPort": d.GetHEPDstPort, + "GetHEPTimeSeconds": d.GetHEPTimeSeconds, + "GetHEPTimeUseconds": d.GetHEPTimeUseconds, + "GetRawMessage": d.GetRawMessage, + "SetRawMessage": d.SetRawMessage, + "SetHEPField": d.SetHEPField, + "HashTable": HashTable, + "HashString": HashString, + "Logp": d.Logp, + "Print": fmt.Println, + }) + + _, code, err := scanCode() + if err != nil { + logp.Err("Error in scan script: %v", err) + return nil, err + } + + err = d.LuaEngine.DoString(code.String()) + if err != nil { + logp.Err("Error in lua script: %v", err) + return nil, err + } + + d.functions = extractFunc(code) + if len(d.functions) < 1 { + logp.Err("no function name found in lua scripts: %v", err) + return nil, fmt.Errorf("no function name found in lua scripts") + } + + return d, nil +} + +// Run will execute the script +func (d *LuaEngine) Run(pkt *Packet) error { + /* preload */ + d.pkt = &pkt + for _, v := range d.functions { + err := d.LuaEngine.DoString(v) + if err != nil { + return err + } + } + return nil +} diff --git a/decoder/luar/luaobject.go b/decoder/luar/luaobject.go new file mode 100644 index 0000000..aa27d46 --- /dev/null +++ b/decoder/luar/luaobject.go @@ -0,0 +1,443 @@ +package luar + +import ( + "errors" + "reflect" + + "github.com/sipcapture/golua/lua" +) + +// LuaObject encapsulates a Lua object like a table or a function. +// +// We do not make the type distinction since metatables can make tables callable +// and functions indexable. +type LuaObject struct { + l *lua.State + ref int +} + +var ( + ErrLuaObjectCallResults = errors.New("results must be a pointer to pointer/slice/struct") + ErrLuaObjectCallable = errors.New("LuaObject must be callable") + ErrLuaObjectIndexable = errors.New("not indexable") + ErrLuaObjectUnsharedState = errors.New("LuaObjects must share the same state") +) + +// NewLuaObject creates a new LuaObject from stack index. +func NewLuaObject(L *lua.State, idx int) *LuaObject { + L.PushValue(idx) + ref := L.Ref(lua.LUA_REGISTRYINDEX) + return &LuaObject{l: L, ref: ref} +} + +// NewLuaObjectFromName creates a new LuaObject from the object designated by +// the sequence of 'subfields'. +func NewLuaObjectFromName(L *lua.State, subfields ...interface{}) *LuaObject { + L.GetGlobal("_G") + defer L.Pop(1) + err := get(L, subfields...) + if err != nil { + return nil + } + val := NewLuaObject(L, -1) + L.Pop(1) + return val +} + +// NewLuaObjectFromValue creates a new LuaObject from a Go value. +// Note that this will convert any slices or maps into Lua tables. +func NewLuaObjectFromValue(L *lua.State, val interface{}) *LuaObject { + GoToLua(L, val) + return NewLuaObject(L, -1) +} + +// Call calls a Lua function, given the desired results and the arguments. +// 'results' must be a pointer to a pointer/struct/slice. +// +// - If a pointer, then only the first result is stored to that pointer. +// +// - If a struct with 'n' exported fields, then the first 'n' results are stored in the first 'n' exported fields. +// +// - If a slice, then all the results are stored in the slice. The slice is re-allocated if necessary. +// +// If the function returns more values than can be stored in the 'results' +// argument, they will be ignored. +// +// If 'results' is nil, results will be discarded. +func (lo *LuaObject) Call(results interface{}, args ...interface{}) error { + L := lo.l + // Push the callable value. + lo.Push() + if !L.IsFunction(-1) { + if !L.GetMetaField(-1, "__call") { + L.Pop(1) + return ErrLuaObjectCallable + } + // We leave the __call metamethod on stack. + L.Remove(-2) + } + + // Push the args. + for _, arg := range args { + GoToLuaProxy(L, arg) + } + + // Special case: discard the results. + if results == nil { + err := L.Call(len(args), 0) + if err != nil { + L.Pop(1) + return err + } + return nil + } + + resptr := reflect.ValueOf(results) + if resptr.Kind() != reflect.Ptr { + return ErrLuaObjectCallResults + } + res := resptr.Elem() + + switch res.Kind() { + case reflect.Ptr: + err := L.Call(len(args), 1) + defer L.Pop(1) + if err != nil { + return err + } + return LuaToGo(L, -1, res.Interface()) + + case reflect.Slice: + residx := L.GetTop() - len(args) + err := L.Call(len(args), lua.LUA_MULTRET) + if err != nil { + L.Pop(1) + return err + } + + nresults := L.GetTop() - residx + 1 + defer L.Pop(nresults) + t := res.Type() + + // Adjust the length of the slice. + if res.IsNil() || nresults > res.Len() { + v := reflect.MakeSlice(t, nresults, nresults) + res.Set(v) + } else if nresults < res.Len() { + res.SetLen(nresults) + } + + for i := 0; i < nresults; i++ { + err = LuaToGo(L, residx+i, res.Index(i).Addr().Interface()) + if err != nil { + return err + } + } + + case reflect.Struct: + exportedFields := []reflect.Value{} + for i := 0; i < res.NumField(); i++ { + if res.Field(i).CanInterface() { + exportedFields = append(exportedFields, res.Field(i).Addr()) + } + } + nresults := len(exportedFields) + err := L.Call(len(args), nresults) + if err != nil { + L.Pop(1) + return err + } + defer L.Pop(nresults) + residx := L.GetTop() - nresults + 1 + + for i := 0; i < nresults; i++ { + err = LuaToGo(L, residx+i, exportedFields[i].Interface()) + if err != nil { + return err + } + } + + default: + return ErrLuaObjectCallResults + } + + return nil +} + +// Close frees the Lua reference of this object. +func (lo *LuaObject) Close() { + lo.l.Unref(lua.LUA_REGISTRYINDEX, lo.ref) +} + +// get pushes the Lua value indexed at the sequence of 'subfields' from the +// indexable value on top of the stack. +// +// It pushes nothing on error. +// +// Numeric indices start from 1: see Set(). +func get(L *lua.State, subfields ...interface{}) error { + // TODO: See if worth exporting. + + // Duplicate iterable since the following loop removes the last table on stack + // and we don't want to pop it to be consistent with lua.GetField and + // lua.GetTable. + L.PushValue(-1) + + for _, field := range subfields { + if L.IsTable(-1) { + GoToLua(L, field) + L.GetTable(-2) + } else if L.GetMetaField(-1, "__index") { + L.PushValue(-2) + GoToLua(L, field) + err := L.Call(2, 1) + if err != nil { + L.Pop(1) + return err + } + } else { + return ErrLuaObjectIndexable + } + // Remove last iterable. + L.Remove(-2) + } + return nil +} + +// Get stores in 'a' the Lua value indexed at the sequence of 'subfields'. +// 'a' must be a pointer as in LuaToGo. +func (lo *LuaObject) Get(a interface{}, subfields ...interface{}) error { + lo.Push() + defer lo.l.Pop(1) + err := get(lo.l, subfields...) + if err != nil { + return err + } + defer lo.l.Pop(1) + return LuaToGo(lo.l, -1, a) +} + +// GetObject returns the LuaObject indexed at the sequence of 'subfields'. +func (lo *LuaObject) GetObject(subfields ...interface{}) (*LuaObject, error) { + lo.Push() + defer lo.l.Pop(1) + err := get(lo.l, subfields...) + if err != nil { + return nil, err + } + val := NewLuaObject(lo.l, -1) + lo.l.Pop(1) + return val, nil +} + +// Push pushes this LuaObject on the stack. +func (lo *LuaObject) Push() { + lo.l.RawGeti(lua.LUA_REGISTRYINDEX, lo.ref) +} + +// Set sets the value at the sequence of 'subfields' with the value 'a'. +// Numeric indices start from 1, as in Lua: if we started from zero, access to +// index 0 or negative indices would be shifted awkwardly. +func (lo *LuaObject) Set(a interface{}, subfields ...interface{}) error { + parentKeys := subfields[:len(subfields)-1] + parent, err := lo.GetObject(parentKeys...) + if err != nil { + return err + } + + L := parent.l + parent.Push() + defer L.Pop(1) + + lastField := subfields[len(subfields)-1] + if L.IsTable(-1) { + GoToLuaProxy(L, lastField) + GoToLuaProxy(L, a) + L.SetTable(-3) + } else if L.GetMetaField(-1, "__newindex") { + L.PushValue(-2) + GoToLuaProxy(L, lastField) + GoToLuaProxy(L, a) + err := L.Call(3, 0) + if err != nil { + L.Pop(1) + return err + } + } else { + return ErrLuaObjectIndexable + } + return nil +} + +// Setv copies values between two tables in the same Lua state. +// It overwrites existing values. +func (lo *LuaObject) Setv(src *LuaObject, keys ...string) error { + // TODO: Rename? This function seems to be too specialized, is it worth + // keeping at all? + L := lo.l + if L != src.l { + return ErrLuaObjectUnsharedState + } + lo.Push() + defer L.Pop(1) + loIdx := L.GetTop() + + var set func(int, string) + if L.IsTable(loIdx) { + set = L.SetField + } else if L.GetMetaField(loIdx, "__newindex") { + L.Pop(1) + set = func(idx int, key string) { + resultIdx := L.GetTop() + L.GetMetaField(loIdx, "__newindex") + L.PushValue(loIdx) + L.PushString(key) + L.PushValue(resultIdx) + L.Remove(resultIdx) + L.Call(3, 0) + } + } else { + return ErrLuaObjectIndexable + } + + src.Push() + defer src.l.Pop(1) + srcIdx := L.GetTop() + var get func(int, string) + if L.IsTable(srcIdx) { + get = L.GetField + } else if L.GetMetaField(srcIdx, "__index") { + L.Pop(1) + get = func(idx int, key string) { + L.GetMetaField(srcIdx, "__index") + L.PushValue(srcIdx) + L.PushString(key) + L.Call(2, 1) + } + } else { + return ErrLuaObjectIndexable + } + + for _, key := range keys { + get(srcIdx, key) + set(loIdx, key) + } + + return nil +} + +// LuaTableIter is the Go equivalent of a Lua table iterator. +type LuaTableIter struct { + lo *LuaObject + // keyRef is LUA_NOREF before iteration. + keyRef int + // Reference to the iterator in case the metamethod gets changed while + // iterating. + iterRef int + // TODO: See if this is an idiomatic implementation of error storage. + err error +} + +// Error returns the error that happened during last iteration, if any. +func (ti *LuaTableIter) Error() error { + return ti.err +} + +// Iter creates a Lua iterator. +func (lo *LuaObject) Iter() (*LuaTableIter, error) { + L := lo.l + lo.Push() + defer L.Pop(1) + if L.IsTable(-1) { + return &LuaTableIter{lo: lo, keyRef: lua.LUA_NOREF, iterRef: lua.LUA_NOREF}, nil + } else if L.GetMetaField(-1, "__pairs") { + // __pairs(t) = iterator, t, first-key. + L.PushValue(-2) + // Only keep iterator on stack, hence '1' result only. + err := L.Call(1, 1) + if err != nil { + L.Pop(1) + return nil, err + } + ref := L.Ref(lua.LUA_REGISTRYINDEX) + return &LuaTableIter{lo: lo, keyRef: lua.LUA_NOREF, iterRef: ref}, nil + } else { + return nil, ErrLuaObjectIndexable + } +} + +// Next gets the next key/value pair from the indexable value. +// +// 'value' must be a valid argument for LuaToGo. As a special case, 'value' can +// be nil to make it possible to loop over keys without caring about associated +// values. +func (ti *LuaTableIter) Next(key, value interface{}) bool { + if ti.lo == nil { + ti.err = errors.New("empty iterator") + return false + } + L := ti.lo.l + + if ti.iterRef == lua.LUA_NOREF { + // Must be a table. This requires the Iter() function to set + // ref=LUA_NOREF. + + // Push table. + ti.lo.Push() + defer L.Pop(1) + + if ti.keyRef == lua.LUA_NOREF { + L.PushNil() + } else { + L.RawGeti(lua.LUA_REGISTRYINDEX, ti.keyRef) + } + + if L.Next(-2) == 0 { + L.Unref(lua.LUA_REGISTRYINDEX, ti.keyRef) + return false + } + + } else { + L.RawGeti(lua.LUA_REGISTRYINDEX, ti.iterRef) + ti.lo.Push() + + if ti.keyRef == lua.LUA_NOREF { + L.PushNil() + } else { + L.RawGeti(lua.LUA_REGISTRYINDEX, ti.keyRef) + } + + err := L.Call(2, 2) + if err != nil { + L.Pop(1) + ti.err = err + return false + } + if L.IsNil(-2) { + L.Pop(2) + L.Unref(lua.LUA_REGISTRYINDEX, ti.iterRef) + return false + } + } + + err := LuaToGo(L, -2, key) + if err != nil { + ti.err = err + return false + } + if value != nil { + err = LuaToGo(L, -1, value) + if err != nil { + ti.err = err + return false + } + } + + // Drop value, key is now on top. + L.Pop(1) + + // Replace former key reference with new key. + L.Unref(lua.LUA_REGISTRYINDEX, ti.keyRef) + ti.keyRef = L.Ref(lua.LUA_REGISTRYINDEX) + return true +} diff --git a/decoder/luar/luar.go b/decoder/luar/luar.go new file mode 100644 index 0000000..68136c5 --- /dev/null +++ b/decoder/luar/luar.go @@ -0,0 +1,959 @@ +// Copyright (c) 2010-2016 Steve Donovan + +package luar + +import ( + "errors" + "fmt" + "reflect" + + "github.com/sipcapture/golua/lua" +) + +// ConvError records a conversion error from value 'From' to value 'To'. +type ConvError struct { + From interface{} + To interface{} +} + +// ErrTableConv arises when some table entries could not be converted. +// The table conversion result is usable. +// TODO: Work out a more relevant name. +// TODO: Should it be a type instead embedding the actual error? +var ErrTableConv = errors.New("some table elements could not be converted") + +func (l ConvError) Error() string { + return fmt.Sprintf("cannot convert %v to %v", l.From, l.To) +} + +// Lua 5.1 'lua_tostring' function only supports string and numbers. Extend it for internal purposes. +// From the Lua 5.3 source code. +func luaToString(L *lua.State, idx int) string { + switch L.Type(idx) { + case lua.LUA_TNUMBER: + L.PushValue(idx) + defer L.Pop(1) + return L.ToString(-1) + case lua.LUA_TSTRING: + return L.ToString(-1) + case lua.LUA_TBOOLEAN: + b := L.ToBoolean(idx) + if b { + return "true" + } + return "false" + case lua.LUA_TNIL: + return "nil" + } + return fmt.Sprintf("%s: %v", L.LTypename(idx), L.ToPointer(idx)) +} + +func luaDesc(L *lua.State, idx int) string { + return fmt.Sprintf("Lua value '%v' (%v)", luaToString(L, idx), L.LTypename(idx)) +} + +// NullT is the type of Null. +// Having a dedicated type allows us to make the distinction between zero values and Null. +type NullT int + +// Map is an alias for map of strings. +type Map map[string]interface{} + +var ( + // Null is the definition of 'luar.null' which is used in place of 'nil' when + // converting slices and structs. + Null = NullT(0) +) + +var ( + tslice = typeof((*[]interface{})(nil)) + tmap = typeof((*map[string]interface{})(nil)) + nullv = reflect.ValueOf(Null) +) + +// visitor holds the index to the table in LUA_REGISTRYINDEX with all the tables +// we ran across during a GoToLua conversion. +type visitor struct { + L *lua.State + index int +} + +func newVisitor(L *lua.State) visitor { + var v visitor + v.L = L + v.L.NewTable() + v.index = v.L.Ref(lua.LUA_REGISTRYINDEX) + return v +} + +func (v *visitor) close() { + v.L.Unref(lua.LUA_REGISTRYINDEX, v.index) +} + +// Mark value on top of the stack as visited using the registry index. +func (v *visitor) mark(val reflect.Value) { + ptr := val.Pointer() + if ptr == 0 { + // We do not mark uninitialized 'val' as this is meaningless and this would + // bind all uninitialized values to the same mark. + return + } + + v.L.RawGeti(lua.LUA_REGISTRYINDEX, v.index) + // Copy value on top. + v.L.PushValue(-2) + // Set value to table. + // TODO: Handle overflow. + v.L.RawSeti(-2, int(ptr)) + v.L.Pop(1) +} + +// Push visited value on top of the stack. +// If the value was not visited, return false and push nothing. +func (v *visitor) push(val reflect.Value) bool { + ptr := val.Pointer() + v.L.RawGeti(lua.LUA_REGISTRYINDEX, v.index) + v.L.RawGeti(-1, int(ptr)) + if v.L.IsNil(-1) { + // Not visited. + v.L.Pop(2) + return false + } + v.L.Replace(-2) + return true +} + +// Init makes and initializes a new pre-configured Lua state. +// +// It populates the 'luar' table with some helper functions/values: +// +// method: ProxyMethod +// unproxify: Unproxify +// +// chan: MakeChan +// complex: MakeComplex +// map: MakeMap +// slice: MakeSlice +// +// null: Null +// +// It replaces the 'pairs'/'ipairs' functions with ProxyPairs/ProxyIpairs +// respectively, so that __pairs/__ipairs can be used, Lua 5.2 style. It allows +// for looping over Go composite types and strings. +// +// It also replaces the 'type' function with ProxyType. +// +// It is not required for using the 'GoToLua' and 'LuaToGo' functions. +func Init() *lua.State { + var L = lua.NewState() + L.OpenLibs() + Register(L, "luar", Map{ + // Functions. + "unproxify": Unproxify, + + "method": ProxyMethod, + + "chan": MakeChan, + "complex": Complex, + "map": MakeMap, + "slice": MakeSlice, + + // Values. + "null": Null, + }) + Register(L, "", Map{ + "pairs": ProxyPairs, + "type": ProxyType, + }) + // 'ipairs' needs a special case for performance reasons. + RegProxyIpairs(L, "", "ipairs") + return L +} + +func isNil(v reflect.Value) bool { + nullables := [...]bool{ + reflect.Chan: true, + reflect.Func: true, + reflect.Interface: true, + reflect.Map: true, + reflect.Ptr: true, + reflect.Slice: true, + } + + kind := v.Type().Kind() + if int(kind) >= len(nullables) { + return false + } + return nullables[kind] && v.IsNil() +} + +func copyMapToTable(L *lua.State, v reflect.Value, visited visitor) { + n := v.Len() + L.CreateTable(0, n) + visited.mark(v) + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + goToLua(L, key, true, visited) + if isNil(val) { + val = nullv + } + goToLua(L, val, false, visited) + L.SetTable(-3) + } +} + +// Also for arrays. +func copySliceToTable(L *lua.State, v reflect.Value, visited visitor) { + vp := v + for v.Kind() == reflect.Ptr { + // For arrays. + v = v.Elem() + } + + n := v.Len() + L.CreateTable(n, 0) + if v.Kind() == reflect.Slice { + visited.mark(v) + } else if vp.Kind() == reflect.Ptr { + visited.mark(vp) + } + + for i := 0; i < n; i++ { + L.PushInteger(int64(i + 1)) + val := v.Index(i) + if isNil(val) { + val = nullv + } + goToLua(L, val, false, visited) + L.SetTable(-3) + } +} + +func copyStructToTable(L *lua.State, v reflect.Value, visited visitor) { + // If 'vstruct' is a pointer to struct, use the pointer to mark as visited. + vp := v + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + + n := v.NumField() + L.CreateTable(n, 0) + if vp.Kind() == reflect.Ptr { + visited.mark(vp) + } + + for i := 0; i < n; i++ { + st := v.Type() + field := st.Field(i) + key := field.Name + tag := field.Tag.Get("lua") + if tag != "" { + key = tag + } + goToLua(L, key, false, visited) + val := v.Field(i) + goToLua(L, val, false, visited) + L.SetTable(-3) + } +} + +func callGoFunction(L *lua.State, v reflect.Value, args []reflect.Value) []reflect.Value { + defer func() { + if x := recover(); x != nil { + L.RaiseError(fmt.Sprintf("error %s", x)) + } + }() + results := v.Call(args) + return results +} + +func goToLuaFunction(L *lua.State, v reflect.Value) lua.LuaGoFunction { + switch f := v.Interface().(type) { + case func(*lua.State) int: + return f + } + + t := v.Type() + argsT := make([]reflect.Type, t.NumIn()) + for i := range argsT { + argsT[i] = t.In(i) + } + + return func(L *lua.State) int { + var lastT reflect.Type + isVariadic := t.IsVariadic() + + if isVariadic { + n := len(argsT) + lastT = argsT[n-1].Elem() + argsT = argsT[:n-1] + } + + args := make([]reflect.Value, len(argsT)) + for i, t := range argsT { + val := reflect.New(t) + err := LuaToGo(L, i+1, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("cannot convert Go function argument #%v: %v", i, err)) + } + args[i] = val.Elem() + } + + if isVariadic { + n := L.GetTop() + for i := len(argsT) + 1; i <= n; i++ { + val := reflect.New(lastT) + err := LuaToGo(L, i, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("cannot convert Go function argument #%v: %v", i, err)) + } + args = append(args, val.Elem()) + } + argsT = argsT[:len(argsT)+1] + } + results := callGoFunction(L, v, args) + for _, val := range results { + GoToLuaProxy(L, val) + } + return len(results) + } +} + +// GoToLua pushes a Go value 'val' on the Lua stack. +// +// It unboxes interfaces. +// +// Pointers are followed recursively. Slices, structs and maps are copied over as tables. +func GoToLua(L *lua.State, a interface{}) { + visited := newVisitor(L) + goToLua(L, a, false, visited) + visited.close() +} + +// GoToLuaProxy is like GoToLua but pushes a proxy on the Lua stack when it makes sense. +// +// A proxy is a Lua userdata that wraps a Go value. +// +// Proxies have several uses: +// +// - Type checking in Go function calls, so variable of user-defined type are +// always profixied. +// +// - Reflexive modification of the Go data straight from the Lua code. We only +// allow this for compound types. +// +// - Call methods of user-defined types. +// +// Predeclared scalar types are never proxified as they have no methods and we +// only allow compound types to be set reflexively. +// +// Structs are always proxified since their type is always user-defined. If they +// they are not settable (e.g. not nested, not passed by reference, value of a +// map), then a copy is passed as a proxy (otherwise setting the fields from Lua +// would panic). This will not impact the corresponding Go value. +// +// Arrays are only proxified if they are settable (so that the user can set the +// Go value from the Lua side) or if they are of a user-defined type (method +// calls or function parameters). If the type user-defined but the array is not +// settable, then a proxy of a copy is made, just as for structs. +// +// Lua cannot dereference pointers and Go can only call methods over one level +// of indirection at maximum. Thus proxies wrap around values dereferenced up to +// the last pointer. +// +// Go functions can be passed to Lua. If the parameters require several levels +// of indirections, the arguments will be converted automatically. Since proxies +// can only wrap around one level of indirection, functions modifying the value +// of the pointers after one level of indirection will have no effect. +func GoToLuaProxy(L *lua.State, a interface{}) { + visited := newVisitor(L) + goToLua(L, a, true, visited) + visited.close() +} + +func goToLua(L *lua.State, a interface{}, proxify bool, visited visitor) { + var v reflect.Value + v, ok := a.(reflect.Value) + if !ok { + v = reflect.ValueOf(a) + } + if !v.IsValid() { + L.PushNil() + return + } + + if v.Kind() == reflect.Interface && !v.IsNil() { + // Unbox interface. + v = reflect.ValueOf(v.Interface()) + } + + // Follow pointers if not proxifying. We save the parent pointer Value in case + // we proxify since Lua cannot dereference pointers and has no use of + // multiple-level references, while single references are useful for method + // calls functions that make use of one level of indirection. + vp := v + for v.Kind() == reflect.Ptr { + vp = v + v = v.Elem() + } + + if !v.IsValid() { + L.PushNil() + return + } + + // As a special case, we always proxify Null, the empty element for slices and maps. + if v.CanInterface() && v.Interface() == Null { + makeValueProxy(L, v, cInterfaceMeta) + return + } + + switch v.Kind() { + case reflect.Float64, reflect.Float32: + if proxify && isNewType(v.Type()) { + makeValueProxy(L, vp, cNumberMeta) + } else { + L.PushNumber(v.Float()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if proxify && isNewType(v.Type()) { + makeValueProxy(L, vp, cNumberMeta) + } else { + L.PushNumber(float64(v.Int())) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if proxify && isNewType(v.Type()) { + makeValueProxy(L, vp, cNumberMeta) + } else { + L.PushNumber(float64(v.Uint())) + } + case reflect.String: + if proxify && isNewType(v.Type()) { + makeValueProxy(L, vp, cStringMeta) + } else { + L.PushString(v.String()) + } + case reflect.Bool: + if proxify && isNewType(v.Type()) { + makeValueProxy(L, vp, cInterfaceMeta) + } else { + L.PushBoolean(v.Bool()) + } + case reflect.Complex128, reflect.Complex64: + makeValueProxy(L, vp, cComplexMeta) + case reflect.Array: + if proxify { + // To check if it is a user-defined type, we compare its type to that of a + // new go array with the same length and the same element type. + vRawType := reflect.ArrayOf(v.Type().Len(), v.Type().Elem()) + if vRawType != v.Type() || v.CanSet() { + if !v.CanSet() { + vp = reflect.New(v.Type()) + reflect.Copy(vp.Elem(), v) + // 'vp' is a pointer of v.Type(), we want the dereferenced type. + vp = vp.Elem() + } + makeValueProxy(L, vp, cSliceMeta) + return + } + // Else don't proxify. + } + // See the case of struct. + if vp.Kind() == reflect.Ptr && visited.push(vp) { + return + } + copySliceToTable(L, vp, visited) + case reflect.Slice: + if proxify { + makeValueProxy(L, vp, cSliceMeta) + } else { + if visited.push(v) { + return + } + copySliceToTable(L, v, visited) + } + case reflect.Map: + if proxify { + makeValueProxy(L, vp, cMapMeta) + } else { + if visited.push(v) { + return + } + copyMapToTable(L, v, visited) + } + case reflect.Struct: + if proxify { + if vp.CanInterface() { + switch v := vp.Interface().(type) { + case error: + // TODO: Test proxification of errors. + L.PushString(v.Error()) + return + case *LuaObject: + // TODO: Move out of 'proxify' condition? LuaObject is meant to be + // manipulated from the Go side, it is not useful in Lua. + if v.l == L { + v.Push() + } else { + // TODO: What shall we do when LuaObject state is not the current + // state? Copy across states? Is it always possible? + L.PushNil() + } + return + default: + } + } + + // Structs are always user-defined types, so it makes sense to always + // proxify them. + if !v.CanSet() { + vp = reflect.New(v.Type()) + vp.Elem().Set(v) + } + makeValueProxy(L, vp, cStructMeta) + } else { + // Use vp instead of v to detect cycles from the very first element, if a pointer. + if vp.Kind() == reflect.Ptr && visited.push(vp) { + return + } + copyStructToTable(L, vp, visited) + } + case reflect.Chan: + makeValueProxy(L, vp, cChannelMeta) + case reflect.Func: + L.PushGoFunction(goToLuaFunction(L, v)) + default: + if val, ok := v.Interface().(error); ok { + L.PushString(val.Error()) + } else if v.IsNil() { + L.PushNil() + } else { + makeValueProxy(L, vp, cInterfaceMeta) + } + } +} + +func luaIsEmpty(L *lua.State, idx int) bool { + L.PushNil() + if idx < 0 { + idx-- + } + if L.Next(idx) != 0 { + L.Pop(2) + return false + } + return true +} + +func luaMapLen(L *lua.State, idx int) int { + L.PushNil() + if idx < 0 { + idx-- + } + len := 0 + for L.Next(idx) != 0 { + len++ + L.Pop(1) + } + return len +} + +func copyTableToMap(L *lua.State, idx int, v reflect.Value, visited map[uintptr]reflect.Value) (status error) { + t := v.Type() + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + te, tk := t.Elem(), t.Key() + + // See copyTableToSlice. + ptr := L.ToPointer(idx) + if !luaIsEmpty(L, idx) { + visited[ptr] = v + } + + L.PushNil() + if idx < 0 { + idx-- + } + for L.Next(idx) != 0 { + // key at -2, value at -1 + key := reflect.New(tk).Elem() + err := luaToGo(L, -2, key, visited) + if err != nil { + status = ErrTableConv + L.Pop(1) + continue + } + val := reflect.New(te).Elem() + err = luaToGo(L, -1, val, visited) + if err != nil { + status = ErrTableConv + L.Pop(1) + continue + } + v.SetMapIndex(key, val) + L.Pop(1) + } + + return +} + +// Also for arrays. TODO: Create special function for arrays? +func copyTableToSlice(L *lua.State, idx int, v reflect.Value, visited map[uintptr]reflect.Value) (status error) { + t := v.Type() + n := int(L.ObjLen(idx)) + + // Adjust the length of the array/slice. + if n > v.Len() { + if t.Kind() == reflect.Array { + n = v.Len() + } else { + // Slice + v.Set(reflect.MakeSlice(t, n, n)) + } + } else if n < v.Len() { + if t.Kind() == reflect.Array { + // Nullify remaining elements. + for i := n; i < v.Len(); i++ { + v.Index(i).Set(reflect.Zero(t.Elem())) + } + } else { + // Slice + v.SetLen(n) + } + } + + // Do not add empty slices to the list of visited elements. + // The empty Lua table is a single instance object and gets re-used across maps, slices and others. + // Arrays cannot be cyclic since the interface type will ask for slices. + if n > 0 && t.Kind() != reflect.Array { + ptr := L.ToPointer(idx) + visited[ptr] = v + } + + te := t.Elem() + for i := 1; i <= n; i++ { + L.RawGeti(idx, i) + val := reflect.New(te).Elem() + err := luaToGo(L, -1, val, visited) + if err != nil { + status = ErrTableConv + L.Pop(1) + continue + } + v.Index(i - 1).Set(val) + L.Pop(1) + } + + return +} + +func copyTableToStruct(L *lua.State, idx int, v reflect.Value, visited map[uintptr]reflect.Value) (status error) { + t := v.Type() + + // See copyTableToSlice. + ptr := L.ToPointer(idx) + if !luaIsEmpty(L, idx) { + visited[ptr] = v.Addr() + } + + // Associate Lua keys with Go fields: tags have priority over matching field + // name. + fields := map[string]string{} + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + tag := field.Tag.Get("lua") + if tag != "" { + fields[tag] = field.Name + continue + } + fields[field.Name] = field.Name + } + + L.PushNil() + if idx < 0 { + idx-- + } + for L.Next(idx) != 0 { + L.PushValue(-2) + // Warning: ToString changes the value on stack. + key := L.ToString(-1) + L.Pop(1) + f := v.FieldByName(fields[key]) + if f.CanSet() { + val := reflect.New(f.Type()).Elem() + err := luaToGo(L, -1, val, visited) + if err != nil { + status = ErrTableConv + L.Pop(1) + continue + } + f.Set(val) + } + L.Pop(1) + } + + return +} + +// LuaToGo converts the Lua value at index 'idx' to the Go value. +// +// The Go value must be a non-nil pointer. +// +// Conversions to strings and numbers are straightforward. +// +// Lua 'nil' is converted to the zero value of the specified Go value. +// +// If the Lua value is non-nil, pointers are dereferenced (multiple times if +// required) and the pointed value is the one that is set. If 'nil', then the Go +// pointer is set to 'nil'. To set a pointer's value to its zero value, use +// 'luar.null'. +// +// The Go value can be an interface, in which case the type is inferred. When +// converting a table to an interface, the Go value is a []interface{} slice if +// all its elements are indexed consecutively from 1, or a +// map[string]interface{} otherwise. +// +// Existing entries in maps and structs are kept. Arrays and slices are reset. +// +// Nil maps and slices are automatically allocated. +// +// Proxies are unwrapped to the Go value, if convertible. If both the proxy and +// the Go value are pointers, then the Go pointer will be set to the proxy +// pointer. +// Userdata that is not a proxy will be converted to a LuaObject if the Go value +// is an interface or a LuaObject. +func LuaToGo(L *lua.State, idx int, a interface{}) error { + // LuaToGo should not pop the Lua stack to be consistent with L.ToString(), etc. + // It is also easier in practice when we want to keep working with the value on stack. + + v := reflect.ValueOf(a) + // TODO: Test interfaces with methods. + // TODO: Allow unreferenced map? encoding/json does not do it. + if v.Kind() != reflect.Ptr { + return errors.New("not a pointer") + } + if v.IsNil() { + return errors.New("nil pointer") + } + + v = v.Elem() + // If the Lua value is 'nil' and the Go value is a pointer, nullify the pointer. + if v.Kind() == reflect.Ptr && L.IsNil(idx) { + v.Set(reflect.Zero(v.Type())) + return nil + } + + return luaToGo(L, idx, v, map[uintptr]reflect.Value{}) +} + +func luaToGo(L *lua.State, idx int, v reflect.Value, visited map[uintptr]reflect.Value) error { + // Derefence 'v' until a non-pointer. + // This initializes the values, which will be useless effort if the conversion + // fails. + // This must be done here and not in LuaToGo so that the copyTable* functions + // can also call luaToGo on pointers. + vp := v + for v.Kind() == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + vp = v + v = v.Elem() + } + kind := v.Kind() + + switch L.Type(idx) { + case lua.LUA_TNIL: + v.Set(reflect.Zero(v.Type())) + case lua.LUA_TBOOLEAN: + if kind != reflect.Bool && kind != reflect.Interface { + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + v.Set(reflect.ValueOf(L.ToBoolean(idx))) + case lua.LUA_TNUMBER: + switch k := unsizedKind(v); k { + case reflect.Int64, reflect.Uint64, reflect.Float64, reflect.Interface: + // We do not use ToInteger as it may truncate the value. Let Go truncate + // instead in Convert(). + f := reflect.ValueOf(L.ToNumber(idx)) + v.Set(f.Convert(v.Type())) + case reflect.Complex128: + v.SetComplex(complex(L.ToNumber(idx), 0)) + default: + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + case lua.LUA_TSTRING: + if kind != reflect.String && kind != reflect.Interface { + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + v.Set(reflect.ValueOf(L.ToString(idx))) + case lua.LUA_TUSERDATA: + if isValueProxy(L, idx) { + val, typ := valueOfProxy(L, idx) + if val.Interface() == Null { + // Special case for Null. + v.Set(reflect.Zero(v.Type())) + return nil + } + + // If both 'val' and 'v' are pointers, set the 'val' pointer to 'v'. + if typ.ConvertibleTo(vp.Type()) { + vp.Set(val.Convert(vp.Type())) + return nil + } + + // Otherwise dereference. + for !typ.ConvertibleTo(v.Type()) && val.Kind() == reflect.Ptr { + val = val.Elem() + typ = typ.Elem() + } + if !typ.ConvertibleTo(v.Type()) { + return ConvError{From: fmt.Sprintf("proxy (%v)", typ), To: v.Type()} + } + // We automatically convert between types. This behaviour is consistent + // with LuaToGo conversions elsewhere. + v.Set(val.Convert(v.Type())) + return nil + } else if kind != reflect.Interface || v.Type() != reflect.TypeOf(LuaObject{}) { + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + // Wrap the userdata into a LuaObject. + v.Set(reflect.ValueOf(NewLuaObject(L, idx))) + case lua.LUA_TTABLE: + // If several Lua objects point to the same value while they map to Go + // values of different types, 'visited' should be skipped. Since such a + // condition is hard to infere, we simply check if it is convertible. + // + // Lua source: + // t = { + // names = {"foo", "bar"}, + // altnames = names, + // } + // + // Go target: + // t := struct { + // names: []string + // altnames: map[string]string + // } + ptr := L.ToPointer(idx) + if val, ok := visited[ptr]; ok { + if v.Kind() == reflect.Struct && val.Type().ConvertibleTo(vp.Type()) { + vp.Set(val) + return nil + } else if val.Type().ConvertibleTo(v.Type()) { + v.Set(val) + return nil + } + } + + switch kind { + case reflect.Array: + fallthrough + case reflect.Slice: + return copyTableToSlice(L, idx, v, visited) + case reflect.Map: + return copyTableToMap(L, idx, v, visited) + case reflect.Struct: + return copyTableToStruct(L, idx, v, visited) + case reflect.Interface: + n := int(L.ObjLen(idx)) + + switch v.Elem().Kind() { + case reflect.Map: + return copyTableToMap(L, idx, v.Elem(), visited) + case reflect.Slice: + // Need to make/resize the slice here since interface values are not adressable. + v.Set(reflect.MakeSlice(v.Elem().Type(), n, n)) + return copyTableToSlice(L, idx, v.Elem(), visited) + } + + if luaMapLen(L, idx) != n { + v.Set(reflect.MakeMap(tmap)) + return copyTableToMap(L, idx, v.Elem(), visited) + } + v.Set(reflect.MakeSlice(tslice, n, n)) + return copyTableToSlice(L, idx, v.Elem(), visited) + default: + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + case lua.LUA_TFUNCTION: + if kind == reflect.Interface { + v.Set(reflect.ValueOf(NewLuaObject(L, idx))) + } else if vp.Type() == reflect.TypeOf(&LuaObject{}) { + vp.Set(reflect.ValueOf(NewLuaObject(L, idx))) + } else { + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + default: + return ConvError{From: luaDesc(L, idx), To: v.Type()} + } + + return nil +} + +func isNewType(t reflect.Type) bool { + types := [...]reflect.Type{ + reflect.Invalid: nil, // Invalid Kind = iota + reflect.Bool: typeof((*bool)(nil)), + reflect.Int: typeof((*int)(nil)), + reflect.Int8: typeof((*int8)(nil)), + reflect.Int16: typeof((*int16)(nil)), + reflect.Int32: typeof((*int32)(nil)), + reflect.Int64: typeof((*int64)(nil)), + reflect.Uint: typeof((*uint)(nil)), + reflect.Uint8: typeof((*uint8)(nil)), + reflect.Uint16: typeof((*uint16)(nil)), + reflect.Uint32: typeof((*uint32)(nil)), + reflect.Uint64: typeof((*uint64)(nil)), + reflect.Uintptr: typeof((*uintptr)(nil)), + reflect.Float32: typeof((*float32)(nil)), + reflect.Float64: typeof((*float64)(nil)), + reflect.Complex64: typeof((*complex64)(nil)), + reflect.Complex128: typeof((*complex128)(nil)), + reflect.String: typeof((*string)(nil)), + } + + pt := types[int(t.Kind())] + return pt != t +} + +// Register makes a number of Go values available in Lua code as proxies. +// 'values' is a map of strings to Go values. +// +// - If table is non-nil, then create or reuse a global table of that name and +// put the values in it. +// +// - If table is '' then put the values in the global table (_G). +// +// - If table is '*' then assume that the table is already on the stack. +// +// See GoToLuaProxy's documentation. +func Register(L *lua.State, table string, values Map) { + pop := true + if table == "*" { + pop = false + } else if len(table) > 0 { + L.GetGlobal(table) + if L.IsNil(-1) { + L.Pop(1) + L.NewTable() + L.SetGlobal(table) + L.GetGlobal(table) + } + } else { + L.GetGlobal("_G") + } + for name, val := range values { + GoToLuaProxy(L, val) + L.SetField(-2, name) + } + if pop { + L.Pop(1) + } +} + +// Closest we'll get to a typeof operator. +func typeof(a interface{}) reflect.Type { + return reflect.TypeOf(a).Elem() +} diff --git a/decoder/luar/proxy.go b/decoder/luar/proxy.go new file mode 100644 index 0000000..3f4a48b --- /dev/null +++ b/decoder/luar/proxy.go @@ -0,0 +1,297 @@ +package luar + +import ( + "fmt" + "reflect" + "strconv" + "sync" + + "github.com/sipcapture/golua/lua" +) + +// Lua proxy objects for Go slices, maps and structs +// TODO: Replace by interface{}? +type valueProxy struct { + v reflect.Value + t reflect.Type +} + +const ( + cNumberMeta = "numberMT" + cComplexMeta = "complexMT" + cStringMeta = "stringMT" + cSliceMeta = "sliceMT" + cMapMeta = "mapMT" + cStructMeta = "structMT" + cInterfaceMeta = "interfaceMT" + cChannelMeta = "channelMT" +) + +var ( + proxyIdCounter uintptr + proxyMap = map[uintptr]*valueProxy{} + proxymu sync.RWMutex +) + +// commonKind returns the kind to which v1 and v2 can be converted with the +// least information loss. +func commonKind(v1, v2 reflect.Value) reflect.Kind { + k1 := unsizedKind(v1) + k2 := unsizedKind(v2) + if k1 == k2 && (k1 == reflect.Uint64 || k1 == reflect.Int64) { + return k1 + } + if k1 == reflect.Complex128 || k2 == reflect.Complex128 { + return reflect.Complex128 + } + return reflect.Float64 +} + +func isPointerToPrimitive(v reflect.Value) bool { + return v.Kind() == reflect.Ptr && v.Elem().IsValid() && v.Elem().Type() != nil +} + +func isPredeclaredType(t reflect.Type) bool { + return t == reflect.TypeOf(0.0) || t == reflect.TypeOf("") +} + +func isValueProxy(L *lua.State, idx int) bool { + res := false + if L.IsUserdata(idx) { + L.GetMetaTable(idx) + if !L.IsNil(-1) { + L.GetField(-1, "luago.value") + res = !L.IsNil(-1) + L.Pop(1) + } + L.Pop(1) + } + return res +} + +func luaToGoValue(L *lua.State, idx int) (reflect.Value, reflect.Type) { + var a interface{} + err := LuaToGo(L, idx, &a) + if err != nil { + L.RaiseError(err.Error()) + } + return reflect.ValueOf(a), reflect.TypeOf(a) +} + +func makeValueProxy(L *lua.State, v reflect.Value, proxyMT string) { + // The metatable needs be set up in the Lua state before the proxy is created, + // otherwise closing the state will fail on calling the garbage collector. Not + // really sure why this happens though... + L.LGetMetaTable(proxyMT) + if L.IsNil(-1) { + flagValue := func() { + L.SetMetaMethod("__tostring", proxy__tostring) + L.SetMetaMethod("__gc", proxy__gc) + L.SetMetaMethod("__eq", proxy__eq) + L.PushBoolean(true) + L.SetField(-2, "luago.value") + L.Pop(1) + } + switch proxyMT { + case cNumberMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", interface__index) + L.SetMetaMethod("__lt", number__lt) + L.SetMetaMethod("__add", number__add) + L.SetMetaMethod("__sub", number__sub) + L.SetMetaMethod("__mul", number__mul) + L.SetMetaMethod("__div", number__div) + L.SetMetaMethod("__mod", number__mod) + L.SetMetaMethod("__pow", number__pow) + L.SetMetaMethod("__unm", number__unm) + flagValue() + case cComplexMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", complex__index) + L.SetMetaMethod("__add", number__add) + L.SetMetaMethod("__sub", number__sub) + L.SetMetaMethod("__mul", number__mul) + L.SetMetaMethod("__div", number__div) + L.SetMetaMethod("__pow", number__pow) + L.SetMetaMethod("__unm", number__unm) + flagValue() + case cStringMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", string__index) + L.SetMetaMethod("__len", string__len) + L.SetMetaMethod("__lt", string__lt) + L.SetMetaMethod("__concat", string__concat) + L.SetMetaMethod("__ipairs", string__ipairs) + L.SetMetaMethod("__pairs", string__ipairs) + flagValue() + case cSliceMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", slice__index) + L.SetMetaMethod("__newindex", slice__newindex) + L.SetMetaMethod("__len", slicemap__len) + L.SetMetaMethod("__ipairs", slice__ipairs) + L.SetMetaMethod("__pairs", slice__ipairs) + flagValue() + case cMapMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", map__index) + L.SetMetaMethod("__newindex", map__newindex) + L.SetMetaMethod("__len", slicemap__len) + L.SetMetaMethod("__ipairs", map__ipairs) + L.SetMetaMethod("__pairs", map__pairs) + flagValue() + case cStructMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", struct__index) + L.SetMetaMethod("__newindex", struct__newindex) + flagValue() + case cInterfaceMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", interface__index) + flagValue() + case cChannelMeta: + L.NewMetaTable(proxyMT) + L.SetMetaMethod("__index", channel__index) + flagValue() + } + } + + proxymu.Lock() + id := proxyIdCounter + proxyIdCounter++ + proxyMap[id] = &valueProxy{v: v, t: v.Type()} + proxymu.Unlock() + + L.Pop(1) + rawptr := L.NewUserdata(reflect.TypeOf(id).Size()) + *(*uintptr)(rawptr) = id + L.LGetMetaTable(proxyMT) + L.SetMetaTable(-2) +} + +func pushGoMethod(L *lua.State, name string, v reflect.Value) { + method := v.MethodByName(name) + if !method.IsValid() { + t := v.Type() + // Could not resolve this method. Perhaps it's defined on the pointer? + if t.Kind() != reflect.Ptr { + if v.CanAddr() { + // If we can get a pointer directly. + v = v.Addr() + } else { + // Otherwise create and initialize one. + vp := reflect.New(t) + vp.Elem().Set(v) + v = vp + } + } + method = v.MethodByName(name) + if !method.IsValid() { + L.PushNil() + return + } + } + GoToLua(L, method) +} + +// pushNumberValue pushes the number resulting from an arithmetic operation. +// +// At least one operand must be a proxy for this function to be called. See the +// main documentation for the conversion rules. +func pushNumberValue(L *lua.State, a interface{}, t1, t2 reflect.Type) { + v := reflect.ValueOf(a) + isComplex := unsizedKind(v) == reflect.Complex128 + mt := cNumberMeta + if isComplex { + mt = cComplexMeta + } + if t1 == t2 || isPredeclaredType(t2) { + makeValueProxy(L, v.Convert(t1), mt) + } else if isPredeclaredType(t1) { + makeValueProxy(L, v.Convert(t2), mt) + } else if isComplex { + complexType := reflect.TypeOf(0i) + makeValueProxy(L, v.Convert(complexType), cComplexMeta) + } else { + L.PushNumber(valueToNumber(L, v)) + } +} + +func slicer(L *lua.State, v reflect.Value, metatable string) lua.LuaGoFunction { + return func(L *lua.State) int { + L.CheckInteger(1) + L.CheckInteger(2) + i := L.ToInteger(1) - 1 + j := L.ToInteger(2) - 1 + if i < 0 || i >= v.Len() || i > j || j > v.Len() { + L.RaiseError("slice bounds out of range") + } + vn := v.Slice(i, j) + makeValueProxy(L, vn, metatable) + return 1 + } +} + +// Shorthand for kind-switches. +func unsizedKind(v reflect.Value) reflect.Kind { + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.Int64 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return reflect.Uint64 + case reflect.Float64, reflect.Float32: + return reflect.Float64 + case reflect.Complex128, reflect.Complex64: + return reflect.Complex128 + } + return v.Kind() +} + +func valueOfProxy(L *lua.State, idx int) (reflect.Value, reflect.Type) { + proxyId := *(*uintptr)(L.ToUserdata(idx)) + + proxymu.RLock() + val, ok := proxyMap[proxyId] + proxymu.RUnlock() + + if !ok { + L.RaiseError(fmt.Sprintf("No value proxy in arg #%d", idx)) + } + + return val.v, val.t +} + +func valueToComplex(L *lua.State, v reflect.Value) complex128 { + if unsizedKind(v) == reflect.Complex128 { + return v.Complex() + } + return complex(valueToNumber(L, v), 0) +} + +func valueToNumber(L *lua.State, v reflect.Value) float64 { + switch unsizedKind(v) { + case reflect.Int64: + return float64(v.Int()) + case reflect.Uint64: + return float64(v.Uint()) + case reflect.Float64: + return v.Float() + case reflect.String: + if f, err := strconv.ParseFloat(v.String(), 64); err == nil { + return f + } + } + L.RaiseError(fmt.Sprintf("cannot convert %#v to number", v)) + return 0 +} + +func valueToString(L *lua.State, v reflect.Value) string { + switch unsizedKind(v) { + case reflect.Uint64, reflect.Int64, reflect.Float64: + return fmt.Sprintf("%v", valueToNumber(L, v)) + case reflect.String: + return v.String() + } + L.RaiseError("cannot convert to string") + return "" +} diff --git a/decoder/luar/proxyfuncs.go b/decoder/luar/proxyfuncs.go new file mode 100644 index 0000000..47190aa --- /dev/null +++ b/decoder/luar/proxyfuncs.go @@ -0,0 +1,202 @@ +package luar + +// Those functions are meant to be registered in Lua to manipulate proxies. + +import ( + "reflect" + + "github.com/sipcapture/golua/lua" +) + +// Complex pushes a proxy to a Go complex on the stack. +// +// Arguments: real (number), imag (number) +// +// Returns: proxy (complex128) +func Complex(L *lua.State) int { + v1, _ := luaToGoValue(L, 1) + v2, _ := luaToGoValue(L, 2) + result := complex(valueToNumber(L, v1), valueToNumber(L, v2)) + makeValueProxy(L, reflect.ValueOf(result), cComplexMeta) + return 1 +} + +// MakeChan creates a 'chan interface{}' proxy and pushes it on the stack. +// +// Optional argument: size (number) +// +// Returns: proxy (chan interface{}) +func MakeChan(L *lua.State) int { + n := L.OptInteger(1, 0) + ch := make(chan interface{}, n) + makeValueProxy(L, reflect.ValueOf(ch), cChannelMeta) + return 1 +} + +// MakeMap creates a 'map[string]interface{}' proxy and pushes it on the stack. +// +// Returns: proxy (map[string]interface{}) +func MakeMap(L *lua.State) int { + m := reflect.MakeMap(tmap) + makeValueProxy(L, m, cMapMeta) + return 1 +} + +// MakeSlice creates a '[]interface{}' proxy and pushes it on the stack. +// +// Optional argument: size (number) +// +// Returns: proxy ([]interface{}) +func MakeSlice(L *lua.State) int { + n := L.OptInteger(1, 0) + s := reflect.MakeSlice(tslice, n, n+1) + makeValueProxy(L, s, cSliceMeta) + return 1 +} + +func ipairsAux(L *lua.State) int { + i := L.CheckInteger(2) + 1 + L.PushInteger(int64(i)) + L.PushInteger(int64(i)) + L.GetTable(1) + if L.Type(-1) == lua.LUA_TNIL { + return 1 + } + return 2 +} + +// ProxyIpairs implements Lua 5.2 'ipairs' functions. +// It respects the __ipairs metamethod. +// +// It is only useful for compatibility with Lua 5.1. +// +// Because it cannot call 'ipairs' for it might recurse infinitely, ProxyIpairs +// reimplements `ipairsAux` in Go which can be a performance issue in tight +// loops. +// +// You should call 'RegProxyIpairs' instead. +func ProxyIpairs(L *lua.State) int { + // See Lua >=5.2 source code. + if L.GetMetaField(1, "__ipairs") { + L.PushValue(1) + L.Call(1, 3) + return 3 + } + + L.CheckType(1, lua.LUA_TTABLE) + L.PushGoFunction(ipairsAux) + L.PushValue(1) + L.PushInteger(0) + return 3 +} + +// Register a function 'table.name' equivalent to ProxyIpairs that uses 'ipairs' +// when '__ipairs' is not present. +// +// This is much faster than ProxyIpairs. +func RegProxyIpairs(L *lua.State, table, name string) { + L.GetGlobal("ipairs") + ref := L.Ref(lua.LUA_REGISTRYINDEX) + + f := func(L *lua.State) int { + // See Lua >=5.2 source code. + if L.GetMetaField(1, "__ipairs") { + L.PushValue(1) + L.Call(1, 3) + return 3 + } + L.RawGeti(lua.LUA_REGISTRYINDEX, ref) + L.PushValue(1) + L.Call(1, 3) + return 3 + } + + Register(L, table, Map{ + name: f, + }) +} + +// ProxyMethod pushes the proxy method on the stack. +// +// Argument: proxy +// +// Returns: method (function) +func ProxyMethod(L *lua.State) int { + if !isValueProxy(L, 1) { + L.PushNil() + return 1 + } + v, _ := valueOfProxy(L, 1) + name := L.ToString(2) + pushGoMethod(L, name, v) + return 1 +} + +// ProxyPairs implements Lua 5.2 'pairs' functions. +// It respects the __pairs metamethod. +// +// It is only useful for compatibility with Lua 5.1. +func ProxyPairs(L *lua.State) int { + // See Lua >=5.2 source code. + if L.GetMetaField(1, "__pairs") { + L.PushValue(1) + L.Call(1, 3) + return 3 + } + + L.CheckType(1, lua.LUA_TTABLE) + L.GetGlobal("next") + L.PushValue(1) + L.PushNil() + return 3 +} + +// ProxyType pushes the proxy type on the stack. +// +// It behaves like Lua's "type" except for proxies for which it returns +// 'table', 'string' or 'number' with TYPE being the go type. +// +// Argument: proxy +// +// Returns: type (string) +func ProxyType(L *lua.State) int { + if !isValueProxy(L, 1) { + L.PushString(L.LTypename(1)) + return 1 + } + v, _ := valueOfProxy(L, 1) + + pointerLevel := "" + for v.Kind() == reflect.Ptr { + pointerLevel += "*" + v = v.Elem() + } + + prefix := "userdata" + switch unsizedKind(v) { + case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: + prefix = "table" + case reflect.String: + prefix = "string" + case reflect.Uint64, reflect.Int64, reflect.Float64, reflect.Complex128: + prefix = "number" + } + + L.PushString(prefix + "<" + pointerLevel + v.Type().String() + ">") + return 1 +} + +// Unproxify converts a proxy to an unproxified Lua value. +// +// Argument: proxy +// +// Returns: value (Lua value) +func Unproxify(L *lua.State) int { + if !isValueProxy(L, 1) { + L.PushNil() + return 1 + } + v, _ := valueOfProxy(L, 1) + GoToLua(L, v) + return 1 +} diff --git a/decoder/luar/proxymm.go b/decoder/luar/proxymm.go new file mode 100644 index 0000000..320e852 --- /dev/null +++ b/decoder/luar/proxymm.go @@ -0,0 +1,561 @@ +package luar + +// Metamethods. + +// Errors in metamethod will yield a call to RaiseError. +// It is not possible to return an error / bool / message to the caller when +// metamethods are called via Lua operators (e.g. __newindex). + +// TODO: Replicate Go/Lua error messages in RaiseError. + +import ( + "fmt" + "math" + "math/cmplx" + "reflect" + + "github.com/sipcapture/golua/lua" +) + +func channel__index(L *lua.State) int { + v, t := valueOfProxy(L, 1) + name := L.ToString(2) + switch name { + case "recv": + f := func(L *lua.State) int { + val, ok := v.Recv() + if ok { + GoToLuaProxy(L, val) + return 1 + } + return 0 + } + L.PushGoFunction(f) + case "send": + f := func(L *lua.State) int { + val := reflect.New(t.Elem()) + err := LuaToGo(L, 1, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("channel requires %v value type", t.Elem())) + } + v.Send(val.Elem()) + return 0 + } + L.PushGoFunction(f) + case "close": + f := func(L *lua.State) int { + v.Close() + return 0 + } + L.PushGoFunction(f) + default: + pushGoMethod(L, name, v) + } + return 1 +} + +func complex__index(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + name := L.ToString(2) + switch name { + case "real": + L.PushNumber(real(v.Complex())) + case "imag": + L.PushNumber(imag(v.Complex())) + default: + pushGoMethod(L, name, v) + } + return 1 +} + +func interface__index(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + name := L.ToString(2) + pushGoMethod(L, name, v) + return 1 +} + +// TODO: Should map[string] and struct allow direct method calls? Check if first letter is uppercase? +func map__index(L *lua.State) int { + v, t := valueOfProxy(L, 1) + key := reflect.New(t.Key()) + err := LuaToGo(L, 2, key.Interface()) + if err == nil { + key = key.Elem() + val := v.MapIndex(key) + if val.IsValid() { + GoToLuaProxy(L, val) + return 1 + } + } + if !L.IsNumber(2) && L.IsString(2) { + name := L.ToString(2) + pushGoMethod(L, name, v) + return 1 + } + if err != nil { + L.RaiseError(fmt.Sprintf("map requires %v key", t.Key())) + } + return 0 +} + +func map__ipairs(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + keys := v.MapKeys() + intKeys := map[uint64]reflect.Value{} + + // Filter integer keys. + for _, k := range keys { + if k.Kind() == reflect.Interface { + k = k.Elem() + } + switch unsizedKind(k) { + case reflect.Int64: + i := k.Int() + if i > 0 { + intKeys[uint64(i)] = k + } + case reflect.Uint64: + intKeys[k.Uint()] = k + } + } + + idx := uint64(0) + iter := func(L *lua.State) int { + idx++ + if _, ok := intKeys[idx]; !ok { + return 0 + } + GoToLuaProxy(L, idx) + val := v.MapIndex(intKeys[idx]) + GoToLuaProxy(L, val) + return 2 + } + L.PushGoFunction(iter) + return 1 +} + +func map__newindex(L *lua.State) int { + v, t := valueOfProxy(L, 1) + key := reflect.New(t.Key()) + err := LuaToGo(L, 2, key.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("map requires %v key", t.Key())) + } + key = key.Elem() + val := reflect.New(t.Elem()) + err = LuaToGo(L, 3, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("map requires %v value type", t.Elem())) + } + val = val.Elem() + v.SetMapIndex(key, val) + return 0 +} + +func map__pairs(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + keys := v.MapKeys() + idx := -1 + n := v.Len() + iter := func(L *lua.State) int { + idx++ + if idx == n { + return 0 + } + GoToLuaProxy(L, keys[idx]) + val := v.MapIndex(keys[idx]) + GoToLuaProxy(L, val) + return 2 + } + L.PushGoFunction(iter) + return 1 +} + +func number__add(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = v1.Uint() + v2.Uint() + case reflect.Int64: + result = v1.Int() + v2.Int() + case reflect.Float64: + result = valueToNumber(L, v1) + valueToNumber(L, v2) + case reflect.Complex128: + result = valueToComplex(L, v1) + valueToComplex(L, v2) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__div(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = v1.Uint() / v2.Uint() + case reflect.Int64: + result = v1.Int() / v2.Int() + case reflect.Float64: + result = valueToNumber(L, v1) / valueToNumber(L, v2) + case reflect.Complex128: + result = valueToComplex(L, v1) / valueToComplex(L, v2) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__lt(L *lua.State) int { + v1, _ := luaToGoValue(L, 1) + v2, _ := luaToGoValue(L, 2) + switch commonKind(v1, v2) { + case reflect.Uint64: + L.PushBoolean(v1.Uint() < v2.Uint()) + case reflect.Int64: + L.PushBoolean(v1.Int() < v2.Int()) + case reflect.Float64: + L.PushBoolean(valueToNumber(L, v1) < valueToNumber(L, v2)) + } + return 1 +} + +func number__mod(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = v1.Uint() % v2.Uint() + case reflect.Int64: + result = v1.Int() % v2.Int() + case reflect.Float64: + result = math.Mod(valueToNumber(L, v1), valueToNumber(L, v2)) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__mul(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = v1.Uint() * v2.Uint() + case reflect.Int64: + result = v1.Int() * v2.Int() + case reflect.Float64: + result = valueToNumber(L, v1) * valueToNumber(L, v2) + case reflect.Complex128: + result = valueToComplex(L, v1) * valueToComplex(L, v2) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__pow(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = math.Pow(float64(v1.Uint()), float64(v2.Uint())) + case reflect.Int64: + result = math.Pow(float64(v1.Int()), float64(v2.Int())) + case reflect.Float64: + result = math.Pow(valueToNumber(L, v1), valueToNumber(L, v2)) + case reflect.Complex128: + result = cmplx.Pow(valueToComplex(L, v1), valueToComplex(L, v2)) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__sub(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + var result interface{} + switch commonKind(v1, v2) { + case reflect.Uint64: + result = v1.Uint() - v2.Uint() + case reflect.Int64: + result = v1.Int() - v2.Int() + case reflect.Float64: + result = valueToNumber(L, v1) - valueToNumber(L, v2) + case reflect.Complex128: + result = valueToComplex(L, v1) - valueToComplex(L, v2) + } + pushNumberValue(L, result, t1, t2) + return 1 +} + +func number__unm(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + var result interface{} + switch unsizedKind(v1) { + case reflect.Uint64: + result = -v1.Uint() + case reflect.Int64: + result = -v1.Int() + case reflect.Float64, reflect.String: + result = -valueToNumber(L, v1) + case reflect.Complex128: + result = -v1.Complex() + } + v := reflect.ValueOf(result) + if unsizedKind(v1) == reflect.Complex128 { + makeValueProxy(L, v.Convert(t1), cComplexMeta) + } else if isNewType(t1) { + makeValueProxy(L, v.Convert(t1), cNumberMeta) + } else { + L.PushNumber(v.Float()) + } + return 1 +} + +// From Lua's specs: "A metamethod only is selected when both objects being +// compared have the same type and the same metamethod for the selected +// operation." Thus both arguments must be proxies for this function to be +// called. No need to check for type equality: Go's "==" operator will do it for +// us. +func proxy__eq(L *lua.State) int { + var a1 interface{} + _ = LuaToGo(L, 1, &a1) + var a2 interface{} + _ = LuaToGo(L, 2, &a2) + L.PushBoolean(a1 == a2) + return 1 +} + +func proxy__gc(L *lua.State) int { + proxyId := *(*uintptr)(L.ToUserdata(1)) + proxymu.Lock() + delete(proxyMap, proxyId) + proxymu.Unlock() + return 0 +} + +func proxy__tostring(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + L.PushString(fmt.Sprintf("%v", v)) + return 1 +} + +func slice__index(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + for v.Kind() == reflect.Ptr { + // For arrays. + v = v.Elem() + } + if L.IsNumber(2) { + idx := L.ToInteger(2) + if idx < 1 || idx > v.Len() { + L.RaiseError("slice/array get: index out of range") + } + v := v.Index(idx - 1) + GoToLuaProxy(L, v) + + } else if L.IsString(2) { + name := L.ToString(2) + if v.Kind() == reflect.Array { + pushGoMethod(L, name, v) + return 1 + } + switch name { + case "append": + f := func(L *lua.State) int { + narg := L.GetTop() + args := []reflect.Value{} + for i := 1; i <= narg; i++ { + elem := reflect.New(v.Type().Elem()) + err := LuaToGo(L, i, elem.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("slice requires %v value type", v.Type().Elem())) + } + args = append(args, elem.Elem()) + } + newslice := reflect.Append(v, args...) + makeValueProxy(L, newslice, cSliceMeta) + return 1 + } + L.PushGoFunction(f) + case "cap": + L.PushInteger(int64(v.Cap())) + case "slice": + L.PushGoFunction(slicer(L, v, cSliceMeta)) + default: + pushGoMethod(L, name, v) + } + } else { + L.RaiseError("non-integer slice/array index") + } + return 1 +} + +func slice__ipairs(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + n := v.Len() + idx := -1 + iter := func(L *lua.State) int { + idx++ + if idx == n { + return 0 + } + GoToLuaProxy(L, idx+1) // report as 1-based index + val := v.Index(idx) + GoToLuaProxy(L, val) + return 2 + } + L.PushGoFunction(iter) + return 1 +} + +func slice__newindex(L *lua.State) int { + v, t := valueOfProxy(L, 1) + for v.Kind() == reflect.Ptr { + // For arrays. + v = v.Elem() + t = t.Elem() + } + idx := L.ToInteger(2) + val := reflect.New(t.Elem()) + err := LuaToGo(L, 3, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("slice requires %v value type", t.Elem())) + } + val = val.Elem() + if idx < 1 || idx > v.Len() { + L.RaiseError("slice/array set: index out of range") + } + v.Index(idx - 1).Set(val) + return 0 +} + +func slicemap__len(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + for v.Kind() == reflect.Ptr { + // For arrays. + v = v.Elem() + } + L.PushInteger(int64(v.Len())) + return 1 +} + +// Lua accepts concatenation with string and number. +func string__concat(L *lua.State) int { + v1, t1 := luaToGoValue(L, 1) + v2, t2 := luaToGoValue(L, 2) + s1 := valueToString(L, v1) + s2 := valueToString(L, v2) + result := s1 + s2 + + if t1 == t2 || isPredeclaredType(t2) { + v := reflect.ValueOf(result) + makeValueProxy(L, v.Convert(t1), cStringMeta) + } else if isPredeclaredType(t1) { + v := reflect.ValueOf(result) + makeValueProxy(L, v.Convert(t2), cStringMeta) + } else { + L.PushString(result) + } + + return 1 +} + +func string__index(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + if L.IsNumber(2) { + idx := L.ToInteger(2) + if idx < 1 || idx > v.Len() { + L.RaiseError("index out of range") + } + v := v.Index(idx - 1).Convert(reflect.TypeOf("")) + GoToLuaProxy(L, v) + } else if L.IsString(2) { + name := L.ToString(2) + if name == "slice" { + L.PushGoFunction(slicer(L, v, cStringMeta)) + } else { + pushGoMethod(L, name, v) + } + } else { + L.RaiseError("non-integer string index") + } + return 1 +} + +func string__ipairs(L *lua.State) int { + v, _ := valueOfProxy(L, 1) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + r := []rune(v.String()) + n := len(r) + idx := -1 + iter := func(L *lua.State) int { + idx++ + if idx == n { + return 0 + } + GoToLuaProxy(L, idx+1) // report as 1-based index + GoToLuaProxy(L, string(r[idx])) + return 2 + } + L.PushGoFunction(iter) + return 1 +} + +func string__len(L *lua.State) int { + v1, _ := luaToGoValue(L, 1) + L.PushInteger(int64(v1.Len())) + return 1 +} + +func string__lt(L *lua.State) int { + v1, _ := luaToGoValue(L, 1) + v2, _ := luaToGoValue(L, 2) + L.PushBoolean(v1.String() < v2.String()) + return 1 +} + +func struct__index(L *lua.State) int { + v, t := valueOfProxy(L, 1) + name := L.ToString(2) + vp := v + if t.Kind() == reflect.Ptr { + v = v.Elem() + } + field := v.FieldByName(name) + if !field.IsValid() || !field.CanSet() { + // No such exported field, try for method. + pushGoMethod(L, name, vp) + } else { + GoToLuaProxy(L, field) + } + return 1 +} + +func struct__newindex(L *lua.State) int { + v, t := valueOfProxy(L, 1) + name := L.ToString(2) + if t.Kind() == reflect.Ptr { + v = v.Elem() + } + field := v.FieldByName(name) + if !field.IsValid() { + L.RaiseError(fmt.Sprintf("no field named `%s` for type %s", name, v.Type())) + } + val := reflect.New(field.Type()) + err := LuaToGo(L, 3, val.Interface()) + if err != nil { + L.RaiseError(fmt.Sprintf("struct field %v requires %v value type, error with target: %v", name, field.Type(), err)) + } + field.Set(val.Elem()) + return 0 +} diff --git a/decoder/scriptengine.go b/decoder/scriptengine.go new file mode 100644 index 0000000..888bec4 --- /dev/null +++ b/decoder/scriptengine.go @@ -0,0 +1,105 @@ +package decoder + +import ( + "bufio" + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "fmt" + "io" + "os" + "strings" + "unicode" + + "github.com/sipcapture/heplify/config" +) + +// ScriptEngine interface +type ScriptEngine interface { + Run(pkt *Packet) error + Close() +} + +// NewScriptEngine returns a script interface +func NewScriptEngine() (ScriptEngine, error) { + return NewLuaEngine() +} + +func scanCode() (string, *bytes.Buffer, error) { + buf := bytes.NewBuffer(nil) + + file := config.Cfg.ScriptFile + + if file != "" { + f, err := os.Open(file) + if err != nil { + return file, nil, err + } + _, err = io.Copy(buf, f) + if err != nil { + return file, nil, err + } + err = f.Close() + if err != nil { + return file, nil, err + } + } + + return file, buf, nil +} + +func extractFunc(r io.Reader) []string { + var funcs []string + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := cutSpace(scanner.Text()) + if strings.HasPrefix(line, "--") { + continue + } + if strings.HasPrefix(line, "function") { + if b, e := strings.Index(line, "("), strings.Index(line, ")"); b > -1 && e > -1 && b < e { + funcs = append(funcs, line[len("function"):e+1]) + } + } + } + return funcs +} + +func cutSpace(str string) string { + return strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return -1 + } + return r + }, str) +} + +// HashString returns md5, sha1 or sha256 sum +func HashString(algo, s string) string { + switch algo { + case "md5": + return fmt.Sprintf("%x", md5.Sum([]byte(s))) + case "sha1": + return fmt.Sprintf("%x", sha1.Sum([]byte(s))) + case "sha256": + return fmt.Sprintf("%x", sha256.Sum256([]byte(s))) + } + return s +} + +// HashTable is a simple kv store +func HashTable(op, key, val string) string { + /*switch op { + case "get": + if res := scriptCache.Get(nil, stb(key)); res != nil { + return string(res) + } + case "set": + scriptCache.Set(stb(key), stb(val)) + case "del": + scriptCache.Del(stb(key)) + } + */ + return "" +} diff --git a/example.lua b/example.lua new file mode 100644 index 0000000..1750414 --- /dev/null +++ b/example.lua @@ -0,0 +1,82 @@ +-- this function will be executed first +function checkRAW() + + local protoType = GetHEPProtoType() + + Logp("DEBUG", "protoType", protoType) + + -- Check if we have SIP type + if protoType ~= 1 then + return + end + + -- original SIP message Payload + local raw = GetRawMessage() + Logp("DEBUG", "raw", raw) + + -- local _, _, name, value = string.find(raw, "(Call-ID:)%s*:%s*(.+)") + -- local name, value = raw:match("(CSeq):%s+(.-)\n") + + -- Set the raw message back + SetRawMessage(raw) + + return + +end + +-- this function will be executed second +function checkHEP() + + -- get GetHEPSrcIP + local src_ip = GetHEPSrcIP() + + Logp("ERROR", "src_ip:", src_ip) + + -- a struct can be nil so better check it + if (src_ip == nil or src_ip == '') then + return + end + + if src_ip == "10.153.177.21" then + Logp("ERROR", "found bad src IP:", src_ip) + local new_ip = "1.1.1.1" + SetHEPField("SrcIP", new_ip) + Logp("ERROR", "replace to new src IP:", new_ip) + end + + local dst_ip = GetHEPDstIP() + + -- a struct can be nil so better check it + if (dst_ip == nil or dst_ip == '') then + return + end + + if dst_ip == "10.153.177.21" then + Logp("ERROR", "found bad dst IP:", dst_ip) + local new_ip = "8.8.8.8" + SetHEPField("DstIP", new_ip) + Logp("ERROR", "replace to new dst IP:", new_ip) + end + + -- ports + + local src_port = GetHEPSrcPort() + + if src_port == 5060 then + Logp("ERROR", "found bad src port", src_port) + local new_src_port = "9060" + SetHEPField("SrcPort", new_src_port) + Logp("ERROR", "set new port ", new_src_port) + end + + local dst_port = GetHEPDstPort() + + if dst_port == 5060 then + Logp("ERROR", "found bad dst port", dst_port) + end + + SetHEPField("DstPort", "9999") + + return + +end diff --git a/go.mod b/go.mod index 47b5f5b..b1007a3 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/sipcapture/heplify go 1.15 require ( + github.com/VictoriaMetrics/fastcache v1.12.2 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/gogo/protobuf v1.3.2 github.com/google/gopacket v1.1.19 @@ -11,6 +12,7 @@ require ( github.com/prometheus/client_golang v1.17.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/segmentio/encoding v0.3.6 + github.com/sipcapture/golua v0.0.0-20200610090950-538d24098d76 // indirect github.com/stretchr/testify v1.8.2 golang.org/x/net v0.18.0 ) diff --git a/go.sum b/go.sum index 4c555be..69faf17 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/VictoriaMetrics/fastcache v1.12.2 h1:N0y9ASrJ0F6h0QaC3o6uJb3NIZ9VKLjCM7NQbSmF7WI= +github.com/VictoriaMetrics/fastcache v1.12.2/go.mod h1:AmC+Nzz1+3G2eCPapF6UcsnkThDcMsQicp4xDukwJYI= github.com/alecthomas/kingpin/v2 v2.3.1/go.mod h1:oYL5vtsvEHZGHxU7DMp32Dvx+qL+ptGn6lWaot2vCNE= github.com/alecthomas/kingpin/v2 v2.3.2/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -44,6 +46,7 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= +github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -113,6 +116,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -230,6 +235,8 @@ github.com/segmentio/encoding v0.1.15 h1:btgfyAuFo3uLw7eOrRDPo8H4Bc881+bSPHzAEe0 github.com/segmentio/encoding v0.1.15/go.mod h1:RWhr02uzMB9gQC1x+MfYxedtmBibb9cZ6Vv9VxRSSbw= github.com/segmentio/encoding v0.3.6 h1:E6lVLyDPseWEulBmCmAKPanDd3jiyGDo5gMcugCRwZQ= github.com/segmentio/encoding v0.3.6/go.mod h1:n0JeuIqEQrQoPDGsjo8UNd1iA0U8d8+oHAA4E3G3OxM= +github.com/sipcapture/golua v0.0.0-20200610090950-538d24098d76 h1:LHLWVuD4zXJrQB8aEyN9QKIqAeWAonenmCwz2im1/+o= +github.com/sipcapture/golua v0.0.0-20200610090950-538d24098d76/go.mod h1:NxkBb6hztCHXAf1j/ENBqbofdUtm48P3hPjpedewJl8= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= diff --git a/main.go b/main.go index 19d48d6..d6b963c 100644 --- a/main.go +++ b/main.go @@ -15,7 +15,7 @@ import ( "github.com/sipcapture/heplify/sniffer" ) -const version = "heplify 1.65.16" +const version = "heplify 1.66.1" func createFlags() { @@ -34,6 +34,7 @@ func createFlags() { sys bool fNum int fSize uint64 + hepfilter string ) //long @@ -60,6 +61,10 @@ func createFlags() { flag.BoolVar(&ifaceConfig.WithErspan, "erspan", false, "erspan") flag.IntVar(&fNum, "fnum", 7, "The total num of log files to keep") flag.Uint64Var(&fSize, "fsize", 10*1024*1024, "The rotate size per log file based on byte") + //scripts + flag.StringVar(&config.Cfg.ScriptFile, "script-file", "", "Script file to execute on each packet") + flag.StringVar(&hepfilter, "script-hep-filter", "1", "HEP filter for script, comma separated list of HEP types") + //short flag.StringVar(&config.Cfg.Filter, "fi", "", "Filter interesting packets by any string") flag.StringVar(&config.Cfg.HepCollector, "hin", "", "HEP collector address [udp:127.0.0.1:9093]") @@ -93,6 +98,17 @@ func createFlags() { flag.StringVar(&ifaceConfig.Type, "t", "af_packet", "Capture types are [pcap, af_packet]") flag.Parse() + if hepfilter != "" { + hepfilter = strings.Replace(hepfilter, " ", "", -1) + for _, val := range strings.Split(hepfilter, ",") { + intVal, err := strconv.Atoi(val) + if err != nil { + continue + } + config.Cfg.ScriptHEPFilter = append(config.Cfg.ScriptHEPFilter, intVal) + } + } + config.Cfg.Iface = &ifaceConfig logp.ToStderr = &std logging.ToSyslog = &sys diff --git a/publish/publisher.go b/publish/publisher.go index 64f36d7..75c6a44 100644 --- a/publish/publisher.go +++ b/publish/publisher.go @@ -5,9 +5,12 @@ import ( "time" "github.com/negbie/logp" + "github.com/sipcapture/heplify/config" "github.com/sipcapture/heplify/decoder" ) +var scriptEnable bool + type Outputer interface { Output(msg []byte) SendPingPacket(msg []byte) @@ -16,6 +19,7 @@ type Outputer interface { type Publisher struct { pubCount uint64 outputer Outputer + script decoder.ScriptEngine } func NewPublisher(out Outputer) *Publisher { @@ -23,6 +27,18 @@ func NewPublisher(out Outputer) *Publisher { outputer: out, pubCount: 0, } + + if config.Cfg.ScriptFile != "" { + var err error + p.script, err = decoder.NewScriptEngine() + if err != nil { + logp.Err("%v, please fix and run killall -HUP heplify", err) + } else { + scriptEnable = true + //defer p.script.Close() + } + } + go p.Start(decoder.PacketQueue) go p.printStats() return p @@ -50,6 +66,7 @@ func (pub *Publisher) Start(pq chan *decoder.Packet) { for pkt := range pq { atomic.AddUint64(&pub.pubCount, 1) + var err error //Version == 100 just for forwarding... if pkt.Version == 100 { @@ -65,6 +82,23 @@ func (pub *Publisher) Start(pq chan *decoder.Packet) { pub.setHEPPing(msg) logp.Debug("publisher", "sent hep ping from collector") } else { + + if scriptEnable { + for _, v := range config.Cfg.ScriptHEPFilter { + if int(pkt.ProtoType) == v { + if err = pub.script.Run(pkt); err != nil { + logp.Err("%v", err) + } + break + } + } + + if pkt == nil || pkt.ProtoType == 1 && pkt.Payload == nil { + logp.Warn("nil struct after script processing") + continue + } + } + msg, err := EncodeHEP(pkt) if err != nil { logp.Warn("%v", err) diff --git a/scripts/build_binary.sh b/scripts/build_binary.sh index d21a104..a77813d 100755 --- a/scripts/build_binary.sh +++ b/scripts/build_binary.sh @@ -9,6 +9,4 @@ fi docker run --rm \ -v $PWD:/app \ golang:alpine \ - sh -c "apk --update add linux-headers musl-dev gcc libpcap-dev ca-certificates git && cd /app && CGO_ENABLED=1 GOOS=linux go build -a --ldflags '-linkmode external -extldflags \"-static -s -w\"' -o heplify ." - - + sh -c "apk --update add linux-headers musl-dev gcc libpcap-dev ca-certificates git luajit-dev && cd /app && CGO_ENABLED=1 GOOS=linux go build -buildvcs=false -a --ldflags '-linkmode external -extldflags \"-static -s -w\"' -o heplify ."