diff --git a/v2/json2/json_test.go b/v2/json2/json_test.go index 23a19a8..a8e4444 100644 --- a/v2/json2/json_test.go +++ b/v2/json2/json_test.go @@ -152,6 +152,23 @@ func executeInvalidJSON(t *testing.T, s *rpc.Server, res interface{}) error { return DecodeClientResponse(w.Body, res) } +func TestNewService(t *testing.T) { + s := rpc.NewServer() + s.RegisterCodec(NewCodec(), "application/json") + if err := s.RegisterService(new(Service1), "Service1.service2.service3"); err != nil { + t.Fatal(err) + } + + var res Service1Response + if err := execute(t, s, "Service1.service2.service3.Multiply", + &Service1Request{4, 2}, &res); err != nil { + t.Error("Expected err to be nil, but got:", err) + } + + println("result:", res.Result) + +} + func TestService(t *testing.T) { s := rpc.NewServer() s.RegisterCodec(NewCodec(), "application/json") diff --git a/v2/map.go b/v2/map.go index dda4216..b5cda62 100644 --- a/v2/map.go +++ b/v2/map.go @@ -30,6 +30,7 @@ type service struct { rcvr reflect.Value // receiver of methods for the service rcvrType reflect.Type // type of the receiver methods map[string]*serviceMethod // registered methods + services map[string]*service // 保存下一级的其他服务 } type serviceMethod struct { @@ -48,15 +49,67 @@ type serviceMap struct { services map[string]*service } +// 注册多个服务名的服务,每个服务名增加一个服务 +func (m *serviceMap) registryService(name string) (*service, error) { + // 切分服务全名 + parts := strings.Split(name, ".") + + var lastService *service + var newService *service + var serviceName string + + // 遍历服务名 + for _, part := range parts { + // 构建一个新的服务 + newService = &service{ + name: part, + methods: make(map[string]*serviceMethod), + services: make(map[string]*service), + } + + // 服务名字 + serviceName += part + "." + println(serviceName) + // 如果一开始是第一个服务名,要放到ServiceMap中 + if lastService == nil { + lastService = newService + + m.mutex.Lock() + + if m.services == nil { + m.services = make(map[string]*service) + } else if _, ok := m.services[lastService.name]; ok { + return nil, fmt.Errorf("rpc: service already defined: %q", + serviceName) + } + m.services[lastService.name] = lastService + + m.mutex.Unlock() + } else { + + if _, ok := lastService.services[newService.name]; ok { + return nil, fmt.Errorf("rpc: service already defined: %q", + serviceName) + } + + lastService.services[newService.name] = newService + lastService = newService + } + } + + return newService, nil +} + // register adds a new service using reflection to extract its methods. func (m *serviceMap) register(rcvr interface{}, name string) error { // Setup service. - s := &service{ - name: name, - rcvr: reflect.ValueOf(rcvr), - rcvrType: reflect.TypeOf(rcvr), - methods: make(map[string]*serviceMethod), + s, err := m.registryService(name) + if err != nil { + return err } + s.rcvr = reflect.ValueOf(rcvr) + s.rcvrType = reflect.TypeOf(rcvr) + if name == "" { s.name = reflect.Indirect(s.rcvr).Type().Name() if !isExported(s.name) { @@ -111,15 +164,15 @@ func (m *serviceMap) register(rcvr interface{}, name string) error { return fmt.Errorf("rpc: %q has no exported methods of suitable type", s.name) } - // Add to the map. - m.mutex.Lock() - defer m.mutex.Unlock() - if m.services == nil { - m.services = make(map[string]*service) - } else if _, ok := m.services[s.name]; ok { - return fmt.Errorf("rpc: service already defined: %q", s.name) - } - m.services[s.name] = s + // // Add to the map. + // m.mutex.Lock() + // defer m.mutex.Unlock() + // if m.services == nil { + // m.services = make(map[string]*service) + // } else if _, ok := m.services[s.name]; ok { + // return fmt.Errorf("rpc: service already defined: %q", s.name) + // } + // m.services[s.name] = s return nil } @@ -127,19 +180,38 @@ func (m *serviceMap) register(rcvr interface{}, name string) error { // // The method name uses a dotted notation as in "Service.Method". func (m *serviceMap) get(method string) (*service, *serviceMethod, error) { + // 分割方法名,考虑到可能有多级服务名 parts := strings.Split(method, ".") - if len(parts) != 2 { + if len(parts) < 2 { err := fmt.Errorf("rpc: service/method request ill-formed: %q", method) return nil, nil, err } + + // 实际方法名 + methodName := parts[len(parts)-1] + + // 按层次遍历服务 m.mutex.Lock() - service := m.services[parts[0]] - m.mutex.Unlock() - if service == nil { - err := fmt.Errorf("rpc: can't find service %q", method) - return nil, nil, err + var service *service + for index, part := range parts { + if index == len(parts)-1 { + break + } + + if service == nil { + service = m.services[part] + } else { + service = service.services[part] + } + + if service == nil { + err := fmt.Errorf("rpc: can't find service %q", method) + return nil, nil, err + } } - serviceMethod := service.methods[parts[1]] + m.mutex.Unlock() + + serviceMethod := service.methods[methodName] if serviceMethod == nil { err := fmt.Errorf("rpc: can't find method %q", method) return nil, nil, err