diff --git a/api/v1/dto/gnp.go b/api/v1/dto/gnp.go index 3cd48f4..f286984 100644 --- a/api/v1/dto/gnp.go +++ b/api/v1/dto/gnp.go @@ -19,6 +19,7 @@ type GNPMetadata struct { } type GNPSpec struct { + Order uint32 `json:"order" yaml:"order"` Selector string `json:"selector,omitempty" yaml:"selector"` Ingress []GNPSpecRule `json:"ingress,omitempty" yaml:"ingress"` Egress []GNPSpecRule `json:"egress,omitempty" yaml:"egress"` @@ -27,9 +28,9 @@ type GNPSpec struct { type GNPSpecRule struct { Metadata map[string]string `json:"metadata,omitempty" yaml:"metadata"` Action string `json:"action" yaml:"action"` - Protocol string `json:"protocol,omitempty" yaml:"protocol"` - NotProtocol string `json:"notProtocol,omitempty" yaml:"notProtocol"` - IPVersion int `json:"ipVersion" yaml:"ipVersion"` + Protocol interface{} `json:"protocol,omitempty" yaml:"protocol"` + NotProtocol interface{} `json:"notProtocol,omitempty" yaml:"notProtocol"` + IPVersion *int `json:"ipVersion,omitempty" yaml:"ipVersion"` Source *GNPSpecRuleEntity `json:"source,omitempty" yaml:"source"` Destination *GNPSpecRuleEntity `json:"destination,omitempty" yaml:"destination"` } @@ -54,6 +55,7 @@ type GNPMetadataInput struct { } type GNPSpecInput struct { + Order *uint32 `json:"order" yaml:"order"` Selector string `json:"selector" yaml:"selector" validate:"omitempty,selector"` Ingress []GNPSpecRuleInput `json:"ingress" yaml:"ingress" validate:"omitempty,min=1,dive"` Egress []GNPSpecRuleInput `json:"egress" yaml:"egress" validate:"omitempty,min=1,dive"` @@ -62,9 +64,9 @@ type GNPSpecInput struct { type GNPSpecRuleInput struct { Metadata map[string]string `json:"metadata" yaml:"metadata"` Action string `json:"action" yaml:"action" validate:"required,action"` - Protocol string `json:"protocol" yaml:"protocol" validate:"omitempty,protocol"` - NotProtocol string `json:"notProtocol" yaml:"notProtocol" validate:"omitempty,protocol"` - IPVersion int `json:"ipVersion" yaml:"ipVersion" validate:"required,ip_version"` + Protocol interface{} `json:"protocol" yaml:"protocol" validate:"omitempty,protocol"` + NotProtocol interface{} `json:"notProtocol" yaml:"notProtocol" validate:"omitempty,protocol"` + IPVersion *int `json:"ipVersion" yaml:"ipVersion" validate:"omitempty,ip_version"` Source *GNPSpecRuleEntityInput `json:"source" yaml:"source" validate:"omitempty"` Destination *GNPSpecRuleEntityInput `json:"destination" yaml:"destination" validate:"omitempty"` } @@ -73,8 +75,8 @@ type GNPSpecRuleEntityInput struct { Selector string `json:"selector" yaml:"selector" validate:"omitempty,selector"` Nets []string `json:"nets" yaml:"nets" validate:"omitempty,min=1,unique"` NotNets []string `json:"notNets" yaml:"notNets" validate:"omitempty,min=1,unique"` - Ports []interface{} `json:"ports" yaml:"ports" validate:"omitempty,min=1,unique,dive"` - NotPorts []interface{} `json:"notPorts" yaml:"notPorts" validate:"omitempty,min=1,unique,dive"` + Ports []interface{} `json:"ports" yaml:"ports" validate:"omitempty,min=1,unique,dive,port"` + NotPorts []interface{} `json:"notPorts" yaml:"notPorts" validate:"omitempty,min=1,unique,dive,port"` } type GetGNPInput struct { @@ -84,3 +86,7 @@ type GetGNPInput struct { type DeleteGlobalNetworkPolicyInput struct { Metadata GNPMetadataInput `json:"metadata" yaml:"metadata" validate:"required"` } + +type ListGNPsInput struct { + IsOrder bool `form:"isOrder"` +} diff --git a/api/v1/dto/gns.go b/api/v1/dto/gns.go index e681153..c93ffae 100644 --- a/api/v1/dto/gns.go +++ b/api/v1/dto/gns.go @@ -37,6 +37,8 @@ type GNSSpecInput struct { Nets []string `json:"nets" yaml:"nets" validate:"min=1,unique"` } +type ListGNSsInput struct{} + type GetGNSInput struct { Name string `uri:"name" validate:"required"` } diff --git a/api/v1/dto/hep.go b/api/v1/dto/hep.go index 36d6618..3fac8cc 100644 --- a/api/v1/dto/hep.go +++ b/api/v1/dto/hep.go @@ -22,6 +22,8 @@ type HostEndpointMetadata struct { type HostEndpointSpec struct { InterfaceName string `json:"interfaceName" yaml:"interfaceName"` + TenantID uint64 `json:"tenantID" yaml:"tenantID"` + IP string `json:"ip" yaml:"ip"` IPs []string `json:"ips" yaml:"ips"` } @@ -32,36 +34,45 @@ type CreateHostEndpointInput struct { } type HostEndpointMetadataInput struct { - Name string `json:"name" yaml:"name" validate:"required,name"` + Name string `json:"name" yaml:"name" validate:"omitempty,name"` Labels map[string]string `json:"labels" yaml:"labels"` } type HostEndpointSpecInput struct { InterfaceName string `json:"interfaceName" yaml:"interfaceName"` + TenantID uint64 `json:"tenantID" yaml:"tenantID" validate:"omitempty"` + IP string `json:"ip" yaml:"ip" validate:"omitempty,ip"` IPs []string `json:"ips" yaml:"ips" validate:"min=1,unique,dive,ip"` } +type ListHostEndpointsInput struct { + TenantID *uint64 `form:"tenantID" yaml:"tenantID" validate:"omitempty"` + IP *string `form:"ip" yaml:"ip" validate:"omitempty,ip"` +} + type GetHostEndpointInput struct { - Name string `uri:"name" validate:"required"` + TenantID uint64 `uri:"tenantID" yaml:"tenantID" validate:"required"` + IP string `uri:"ip" yaml:"ip" validate:"required,ip"` } type DeleteHostEndpointInput struct { - Metadata HostEndpointMetadataInput `json:"metadata" yaml:"metadata" validate:"required"` + Spec HostEndpointSpecInput `json:"spec" yaml:"spec" validate:"required"` } -type FetchHostEndpointPolicyInput struct { - Name string `uri:"name" validate:"required"` +type FetchHostEndpointPoliciesInput struct { + TenantID *uint64 `form:"tenantID" yaml:"tenantID" validate:"omitempty"` + IP *string `form:"ip" yaml:"ip" validate:"omitempty,ip"` } type HostEndpointPolicy struct { - MetaData HostEndPointPolicyMetadata `json:"metadata"` + MetaData HostEndpointPolicyMetadata `json:"metadata"` HEP *HostEndpoint `json:"hostEndpoint"` ParsedGNPs []*ParsedGNP `json:"parsedGNPs"` ParsedHEPs []*ParsedHEP `json:"parsedHEPs"` ParsedGNSs []*ParsedGNS `json:"parsedGNSs"` } -type HostEndPointPolicyMetadata struct { +type HostEndpointPolicyMetadata struct { HEPVersions map[string]uint `json:"hepVersions"` GNPVersions map[string]uint `json:"gnpVersions"` GNSVersions map[string]uint `json:"gnsVersions"` @@ -76,29 +87,31 @@ type ParsedGNP struct { } type ParsedRule struct { - Action string `json:"action"` - IPVersion int `json:"ipVersion"` - Protocol string `json:"protocol"` - IsProtocolNegative bool `json:"isProtocolNegative"` - SrcNets []string `json:"srcNets"` - IsSrcNetNegative bool `json:"isSrcNetNegative"` - SrcGNSUUIDs []string `json:"srcGNSUUIDs"` - SrcHEPUUIDs []string `json:"srcHEPUUIDs"` - SrcPorts []string `json:"srcPorts"` - IsSrcPortNegative bool `json:"isSrcPortNegative"` - DstNets []string `json:"dstNets"` - IsDstNetNegative bool `json:"isDstNetNegative"` - DstGNSUUIDs []string `json:"dstGNSUUIDs"` - DstHEPUUIDs []string `json:"dstHEPUUIDs"` - DstPorts []string `json:"dstPorts"` - IsDstPortNegative bool `json:"isDstPortNegative"` + Action string `json:"action"` + IPVersion *int `json:"ipVersion"` + Protocol interface{} `json:"protocol"` + IsProtocolNegative bool `json:"isProtocolNegative"` + SrcNets []string `json:"srcNets"` + IsSrcNetNegative bool `json:"isSrcNetNegative"` + SrcGNSUUIDs []string `json:"srcGNSUUIDs"` + SrcHEPUUIDs []string `json:"srcHEPUUIDs"` + SrcPorts []string `json:"srcPorts"` + IsSrcPortNegative bool `json:"isSrcPortNegative"` + DstNets []string `json:"dstNets"` + IsDstNetNegative bool `json:"isDstNetNegative"` + DstGNSUUIDs []string `json:"dstGNSUUIDs"` + DstHEPUUIDs []string `json:"dstHEPUUIDs"` + DstPorts []string `json:"dstPorts"` + IsDstPortNegative bool `json:"isDstPortNegative"` } type ParsedHEP struct { - UUID string `json:"uuid"` - Name string `json:"name"` - IPsV4 []string `json:"ipsV4"` - IPsV6 []string `json:"ipsV6"` + UUID string `json:"uuid"` + TenantID uint64 `json:"tenantID"` + Name string `json:"name"` + IP string `json:"ip"` + IPsV4 []string `json:"ipsV4"` + IPsV6 []string `json:"ipsV6"` } type ParsedGNS struct { diff --git a/api/v1/handler/gnp.go b/api/v1/handler/gnp.go index acbe2f0..81cc2e6 100644 --- a/api/v1/handler/gnp.go +++ b/api/v1/handler/gnp.go @@ -16,6 +16,7 @@ import ( type gnpService interface { Create(ctx context.Context, input *model.CreateGlobalNetworkPolicyInput) (*entity.GlobalNetworkPolicy, *ierror.Error) + List(ctx context.Context, input *model.ListGNPsInput) ([]*entity.GlobalNetworkPolicy, *ierror.Error) Get(ctx context.Context, name string) (*entity.GlobalNetworkPolicy, *ierror.Error) Delete(ctx context.Context, name string) *ierror.Error } @@ -45,6 +46,21 @@ func (h *gnp) Create(c *gin.Context) { httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToGlobalNetworkPolicyDTO(gnsEntity)) } +func (h *gnp) List(c *gin.Context) { + in := new(dto.ListGNPsInput) + if ierr := httpbase.BindInput(c, in); ierr != nil { + httpbase.ReturnErrorResponse(c, ierr) + return + } + + gnpsEntity, ierr := h.service.List(c.Request.Context(), &model.ListGNPsInput{IsOrder: in.IsOrder}) + if ierr != nil { + httpbase.ReturnErrorResponse(c, ierr) + return + } + httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToListGlobalNetworkPolicyDTOs(gnpsEntity)) +} + func (h *gnp) Get(c *gin.Context) { in := new(dto.GetGNPInput) if ierr := httpbase.BindInput(c, in); ierr != nil { diff --git a/api/v1/handler/gns.go b/api/v1/handler/gns.go index eb80454..4e8ad05 100644 --- a/api/v1/handler/gns.go +++ b/api/v1/handler/gns.go @@ -16,6 +16,7 @@ import ( type gnsService interface { Create(ctx context.Context, input *model.CreateGlobalNetworkSetInput) (*entity.GlobalNetworkSet, *ierror.Error) + List(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.Error) Get(ctx context.Context, name string) (*entity.GlobalNetworkSet, *ierror.Error) Delete(ctx context.Context, name string) *ierror.Error } @@ -45,6 +46,15 @@ func (h *gns) Create(c *gin.Context) { httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToGlobalNetworkSetDTO(gnsEntity)) } +func (h *gns) List(c *gin.Context) { + gnpsEntity, ierr := h.service.List(c.Request.Context()) + if ierr != nil { + httpbase.ReturnErrorResponse(c, ierr) + return + } + httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToListGlobalNetworkSetDTOs(gnpsEntity)) +} + func (h *gns) Get(c *gin.Context) { in := new(dto.GetGNSInput) if ierr := httpbase.BindInput(c, in); ierr != nil { diff --git a/api/v1/handler/hep.go b/api/v1/handler/hep.go index dadeea4..baed832 100644 --- a/api/v1/handler/hep.go +++ b/api/v1/handler/hep.go @@ -16,9 +16,10 @@ import ( type hepService interface { Create(ctx context.Context, input *model.CreateHostEndpointInput) (*entity.HostEndpoint, *ierror.Error) - Get(ctx context.Context, name string) (*entity.HostEndpoint, *ierror.Error) - Delete(ctx context.Context, name string) *ierror.Error - FetchPolicies(ctx context.Context, input *model.FetchHostEndpointPolicyInput) (*model.HostEndPointPolicy, *ierror.Error) + List(ctx context.Context, input *model.ListHostEndpointsInput) ([]*entity.HostEndpoint, *ierror.Error) + Get(ctx context.Context, input *model.GetHostEndpointInput) (*entity.HostEndpoint, *ierror.Error) + Delete(ctx context.Context, input *model.DeleteHostEndpointInput) *ierror.Error + FetchPolicies(ctx context.Context, input *model.ListHostEndpointsInput) ([]*model.HostEndpointPolicy, *ierror.Error) } func NewHEP(s hepService) *hep { @@ -46,6 +47,21 @@ func (h *hep) Create(c *gin.Context) { httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToHostEndpointDTO(hepEntity)) } +func (h *hep) List(c *gin.Context) { + in := new(dto.ListHostEndpointsInput) + if ierr := httpbase.BindInput(c, in); ierr != nil { + httpbase.ReturnErrorResponse(c, ierr) + return + } + + gnpsEntity, ierr := h.service.List(c.Request.Context(), mapper.ToListHostEndpointsInput(in)) + if ierr != nil { + httpbase.ReturnErrorResponse(c, ierr) + return + } + httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToListHostEndpointDTOs(gnpsEntity)) +} + func (h *hep) Get(c *gin.Context) { in := new(dto.GetHostEndpointInput) if ierr := httpbase.BindInput(c, in); ierr != nil { @@ -53,7 +69,7 @@ func (h *hep) Get(c *gin.Context) { return } - hepEntity, ierr := h.service.Get(c.Request.Context(), in.Name) + hepEntity, ierr := h.service.Get(c.Request.Context(), mapper.ToGetHostEndpointInput(in)) if ierr != nil { httpbase.ReturnErrorResponse(c, ierr) return @@ -68,7 +84,11 @@ func (h *hep) Delete(c *gin.Context) { return } - if err := h.service.Delete(c.Request.Context(), in.Metadata.Name); err != nil { + if err := h.service.Delete(c.Request.Context(), &model.DeleteHostEndpointInput{ + TenantID: in.Spec.TenantID, + IP: in.Spec.IP, + IPs: in.Spec.IPs, + }); err != nil { httpbase.ReturnErrorResponse(c, err) return } @@ -76,15 +96,15 @@ func (h *hep) Delete(c *gin.Context) { } func (h *hep) FetchPolicies(c *gin.Context) { - in := new(dto.FetchHostEndpointPolicyInput) + in := new(dto.FetchHostEndpointPoliciesInput) if ierr := httpbase.BindInput(c, in); ierr != nil { httpbase.ReturnErrorResponse(c, ierr) return } - hostEndpointPolicy, ierr := h.service.FetchPolicies(c.Request.Context(), mapper.ToFetchHostEndPointPolicyInput(in)) + hostEndpointPolicies, ierr := h.service.FetchPolicies(c.Request.Context(), mapper.ToFetchHostEndpointPolicyInput(in)) if ierr != nil { httpbase.ReturnErrorResponse(c, ierr) return } - httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToFetchPoliciesOutput(hostEndpointPolicy)) + httpbase.ReturnSuccessResponse(c, http.StatusOK, mapper.ToFetchHEPPoliciesOutput(hostEndpointPolicies)) } diff --git a/api/v1/mapper/gnp.go b/api/v1/mapper/gnp.go index 6ce19e5..4b3f279 100644 --- a/api/v1/mapper/gnp.go +++ b/api/v1/mapper/gnp.go @@ -6,6 +6,14 @@ import ( "github.com/bamboo-firewall/be/domain/model" ) +func ToListGlobalNetworkPolicyDTOs(gnps []*entity.GlobalNetworkPolicy) []*dto.GlobalNetworkPolicy { + gnpDTOs := make([]*dto.GlobalNetworkPolicy, 0, len(gnps)) + for _, gnp := range gnps { + gnpDTOs = append(gnpDTOs, ToGlobalNetworkPolicyDTO(gnp)) + } + return gnpDTOs +} + func ToGlobalNetworkPolicyDTO(gnp *entity.GlobalNetworkPolicy) *dto.GlobalNetworkPolicy { if gnp == nil { return nil @@ -29,6 +37,7 @@ func ToGlobalNetworkPolicyDTO(gnp *entity.GlobalNetworkPolicy) *dto.GlobalNetwor Labels: gnp.Metadata.Labels, }, Spec: dto.GNPSpec{ + Order: gnp.Spec.Order, Selector: gnp.Spec.Selector, Ingress: specIngress, Egress: specEgress, @@ -45,7 +54,7 @@ func toRuleDTO(rule entity.GNPSpecRule) dto.GNPSpecRule { Action: rule.Action, Protocol: rule.Protocol, NotProtocol: rule.NotProtocol, - IPVersion: int(rule.IPVersion), + IPVersion: rule.IPVersion, Source: toRuleEntityDTO(rule.Source), Destination: toRuleEntityDTO(rule.Destination), } @@ -65,12 +74,12 @@ func toRuleEntityDTO(ruleEntity *entity.GNPSpecRuleEntity) *dto.GNPSpecRuleEntit } func ToCreateGlobalNetworkPolicyInput(in *dto.CreateGlobalNetworkPolicyInput) *model.CreateGlobalNetworkPolicyInput { - var specIngress []model.GNPSpecRuleInput + specIngress := make([]model.GNPSpecRuleInput, 0, len(in.Spec.Ingress)) for _, rule := range in.Spec.Ingress { specIngress = append(specIngress, toRuleInput(rule)) } - var specEgress []model.GNPSpecRuleInput + specEgress := make([]model.GNPSpecRuleInput, 0, len(in.Spec.Egress)) for _, rule := range in.Spec.Egress { specEgress = append(specEgress, toRuleInput(rule)) } @@ -81,6 +90,7 @@ func ToCreateGlobalNetworkPolicyInput(in *dto.CreateGlobalNetworkPolicyInput) *m Labels: in.Metadata.Labels, }, Spec: model.GNPSpecInput{ + Order: in.Spec.Order, Selector: in.Spec.Selector, Ingress: specIngress, Egress: specEgress, diff --git a/api/v1/mapper/gns.go b/api/v1/mapper/gns.go index 6a06238..67f9290 100644 --- a/api/v1/mapper/gns.go +++ b/api/v1/mapper/gns.go @@ -6,6 +6,14 @@ import ( "github.com/bamboo-firewall/be/domain/model" ) +func ToListGlobalNetworkSetDTOs(gnss []*entity.GlobalNetworkSet) []*dto.GlobalNetworkSet { + gnsDTOs := make([]*dto.GlobalNetworkSet, 0, len(gnss)) + for _, gns := range gnss { + gnsDTOs = append(gnsDTOs, ToGlobalNetworkSetDTO(gns)) + } + return gnsDTOs +} + func ToGlobalNetworkSetDTO(gns *entity.GlobalNetworkSet) *dto.GlobalNetworkSet { if gns == nil { return nil diff --git a/api/v1/mapper/hep.go b/api/v1/mapper/hep.go index 545471a..4ff08a7 100644 --- a/api/v1/mapper/hep.go +++ b/api/v1/mapper/hep.go @@ -3,9 +3,18 @@ package mapper import ( "github.com/bamboo-firewall/be/api/v1/dto" "github.com/bamboo-firewall/be/cmd/server/pkg/entity" + "github.com/bamboo-firewall/be/cmd/server/pkg/net" "github.com/bamboo-firewall/be/domain/model" ) +func ToListHostEndpointDTOs(heps []*entity.HostEndpoint) []*dto.HostEndpoint { + hepDTOs := make([]*dto.HostEndpoint, 0, len(heps)) + for _, hep := range heps { + hepDTOs = append(hepDTOs, ToHostEndpointDTO(hep)) + } + return hepDTOs +} + func ToHostEndpointDTO(hep *entity.HostEndpoint) *dto.HostEndpoint { if hep == nil { return nil @@ -20,6 +29,8 @@ func ToHostEndpointDTO(hep *entity.HostEndpoint) *dto.HostEndpoint { }, Spec: dto.HostEndpointSpec{ InterfaceName: hep.Spec.InterfaceName, + TenantID: hep.Spec.TenantID, + IP: net.IntToIP(hep.Spec.IP).String(), IPs: hep.Spec.IPs, }, Description: hep.Description, @@ -36,19 +47,65 @@ func ToCreateHostEndpointInput(in *dto.CreateHostEndpointInput) *model.CreateHos }, Spec: model.HostEndpointSpecInput{ InterfaceName: in.Spec.InterfaceName, + IP: in.Spec.IP, + TenantID: in.Spec.TenantID, IPs: in.Spec.IPs, }, Description: in.Description, } } -func ToFetchHostEndPointPolicyInput(in *dto.FetchHostEndpointPolicyInput) *model.FetchHostEndpointPolicyInput { - return &model.FetchHostEndpointPolicyInput{ - Name: in.Name, +func ToGetHostEndpointInput(in *dto.GetHostEndpointInput) *model.GetHostEndpointInput { + var ipInt uint32 + netIP := net.ParseIP(in.IP) + if netIP != nil { + ipInt = net.IPToInt(*netIP) + } + return &model.GetHostEndpointInput{ + TenantID: in.TenantID, + IP: ipInt, + } +} + +func ToListHostEndpointsInput(in *dto.ListHostEndpointsInput) *model.ListHostEndpointsInput { + var ipInt *uint32 + if in.IP != nil { + netIP := net.ParseIP(*in.IP) + if netIP != nil { + ip := net.IPToInt(*netIP) + ipInt = &ip + } + } + return &model.ListHostEndpointsInput{ + TenantID: in.TenantID, + IP: ipInt, + } +} + +func ToFetchHostEndpointPolicyInput(in *dto.FetchHostEndpointPoliciesInput) *model.ListHostEndpointsInput { + var ipInt *uint32 + if in.IP != nil { + netIP := net.ParseIP(*in.IP) + if netIP != nil { + ip := net.IPToInt(*netIP) + ipInt = &ip + } + } + return &model.ListHostEndpointsInput{ + TenantID: in.TenantID, + IP: ipInt, + } +} + +func ToFetchHEPPoliciesOutput(hepPolicies []*model.HostEndpointPolicy) []*dto.HostEndpointPolicy { + result := make([]*dto.HostEndpointPolicy, 0, len(hepPolicies)) + for _, hepPolicy := range hepPolicies { + result = append(result, ToFetchHEPPolicyOutput(hepPolicy)) } + return result } -func ToFetchPoliciesOutput(hostEndpointPolicy *model.HostEndPointPolicy) *dto.HostEndpointPolicy { +func ToFetchHEPPolicyOutput(hostEndpointPolicy *model.HostEndpointPolicy) *dto.HostEndpointPolicy { parsedGNPDTOs := make([]*dto.ParsedGNP, len(hostEndpointPolicy.ParsedGNPs)) for i, policy := range hostEndpointPolicy.ParsedGNPs { parsedGNPDTOs[i] = toParsedGNPDTO(policy) @@ -62,7 +119,7 @@ func ToFetchPoliciesOutput(hostEndpointPolicy *model.HostEndPointPolicy) *dto.Ho parsedGNSDTOs[i] = toParsedGNSDTO(set) } return &dto.HostEndpointPolicy{ - MetaData: dto.HostEndPointPolicyMetadata{ + MetaData: dto.HostEndpointPolicyMetadata{ HEPVersions: hostEndpointPolicy.MetaData.HEPVersions, GNPVersions: hostEndpointPolicy.MetaData.GNPVersions, GNSVersions: hostEndpointPolicy.MetaData.GNSVersions, diff --git a/build/init.sh b/build/init.sh index cd5dea6..7058007 100755 --- a/build/init.sh +++ b/build/init.sh @@ -6,13 +6,14 @@ BUILD_CMD_PATH="${BUILD_OUTPUT_PATH}/bin" PACKAGE_NAME="github.com/bamboo-firewall/be" VERSION="$(git describe --abbrev=0 --tags)" -BRANCH="$(git rev-parse --abbrev-ref HEAD)" +BRANCH="${BRANCH:-$(git rev-parse --abbrev-ref HEAD)}" +SHA="$(git describe --match=none --always --abbrev=8)" BUILD_TIME="$(date +%Y-%m-%dT%H:%M:%S%z)" ORGANIZATION="ATAOCloud" LDFLAGS="-s -w -X ${PACKAGE_NAME}/buildinfo.Version=${VERSION} \ - -X ${PACKAGE_NAME}/buildinfo.GitBranch=${BRANCH} \ + -X ${PACKAGE_NAME}/buildinfo.GitBranch=${BRANCH}.${SHA} \ -X ${PACKAGE_NAME}/buildinfo.BuildDate=${BUILD_TIME} \ -X ${PACKAGE_NAME}/buildinfo.Organization=${ORGANIZATION}" diff --git a/cmd/bamboofwcli/command/create.go b/cmd/bamboofwcli/command/create.go index 02868a1..b44f621 100644 --- a/cmd/bamboofwcli/command/create.go +++ b/cmd/bamboofwcli/command/create.go @@ -25,10 +25,10 @@ var createCMD = &cobra.Command{ * GlobalNetworkSet(or gns) * GlobalNetworkPolicy(or gnp)`, Example: ` # Create a global network policy - bbfwcli create gnp policy.yaml + bbfw create gnp -f policy.yaml # Create many global network policy - bbfwcli create gnp policy1.yaml policy2.yaml`, + bbfw create gnp -f policy1.yaml -f policy2.yaml`, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { if err := create(cmd, args); err != nil { @@ -70,7 +70,7 @@ func create(cmd *cobra.Command, args []string) error { for _, r := range resources { err = resourceMgr.Create(context.Background(), apiServer, r.Content) if err != nil { - fmt.Printf("Fail to create resource. Error: %v\n", err) + fmt.Printf("Fail to create resource: %s. Error: %v\n", r.Name, err) } else { fmt.Printf("Successsfully created resource from %s\n", r.Name) numHandled++ diff --git a/cmd/bamboofwcli/command/delete.go b/cmd/bamboofwcli/command/delete.go index 92ab660..ffdac07 100644 --- a/cmd/bamboofwcli/command/delete.go +++ b/cmd/bamboofwcli/command/delete.go @@ -13,25 +13,46 @@ import ( "github.com/bamboo-firewall/be/pkg/client" ) -var fileDeletes []string +var ( + deleteHEPByTenantID uint64 + deleteHEPByIP string + fileDeletes []string +) var deleteCMD = &cobra.Command{ Use: "delete [resourceType]", - Short: "Delete resources by name or filename", - Long: `The delete command is used to delete resources by name or filename. + Short: "Delete resources", + Long: `The delete command is used to delete resources by name(Global Network Policy, Global Network Set), +by tenantID,IP(Host Endpoint) or filename. Resource type available: * HostEndpoint(or hep) * GlobalNetworkSet(or gns) * GlobalNetworkPolicy(or gnp)`, Example: ` # Delete a policy with name - bbfwcli delete gnp allow_ssh + bbfw delete gnp allow_ssh # Delete many policy with name - bbfwcli delete hep allow_ssh allow_ping + bbfw delete gnp allow_ssh allow_ping # Delete many policy with filename - bbfwcli delete hep allow_ssh.yaml allow_ping.yaml`, + bbfw delete gnp -f allow_ssh.yaml -f allow_ping.yaml + + # Delete a set with name + bbfw delete gns server + + # Delete many sets with name + bbfw delete gns server vm + + # Delete many sets with filename + bbfw delete gns -f server.yaml -f vm.yaml + + # Delete a hep with tenantID and ip + bbfw delete hep --tenantID=1 --ip=192.168.1.1 + + # Delete many heps with filename + bbfw delete hep -f server.yaml -f vm.yaml +`, Args: cobra.MinimumNArgs(1), Run: func(cmd *cobra.Command, args []string) { if err := deleteResources(cmd, args); err != nil { @@ -42,6 +63,8 @@ var deleteCMD = &cobra.Command{ } func init() { + deleteCMD.Flags().Uint64Var(&deleteHEPByTenantID, "tenantID", 0, "HEP: get by tenantID") + deleteCMD.Flags().StringVar(&deleteHEPByIP, "ip", "", "HEP: get by ip") deleteCMD.Flags().StringArrayVarP(&fileDeletes, "file", "f", []string{}, "file to read") } @@ -52,13 +75,21 @@ func deleteResources(cmd *cobra.Command, args []string) error { return err } var resourcesName []string - if len(args) > 1 { - resourcesName = args[1:] - } - if len(resourcesName) > 0 && len(fileDeletes) > 0 { - return fmt.Errorf("cannot use name resource with file param together") - } else if len(resourcesName) == 0 && len(fileDeletes) == 0 { - return fmt.Errorf("must specify at least one resource to delete") + if resourceMgr.GetResourceType() == resouremanager.ResourceTypeHEP { + if deleteHEPByTenantID > 0 && deleteHEPByIP != "" && len(fileDeletes) > 0 { + return fmt.Errorf("cannot use tenantID, ip with file param together") + } else if (deleteHEPByTenantID == 0 || deleteHEPByIP == "") && len(fileDeletes) == 0 { + return fmt.Errorf("must specify tenantID, IP or file to delete") + } + } else { + if len(args) > 1 { + resourcesName = args[1:] + } + if len(resourcesName) > 0 && len(fileDeletes) > 0 { + return fmt.Errorf("cannot use name resource with file param together") + } else if len(resourcesName) == 0 && len(fileDeletes) == 0 { + return fmt.Errorf("must specify at least one resource to delete") + } } var resources []*common.ResourceFile @@ -80,14 +111,6 @@ func deleteResources(cmd *cobra.Command, args []string) error { for _, name := range resourcesName { switch resourceMgr.GetResourceType() { case resouremanager.ResourceTypeHEP: - resources = append(resources, &common.ResourceFile{ - Name: name, - Content: &dto.DeleteHostEndpointInput{ - Metadata: dto.HostEndpointMetadataInput{ - Name: name, - }, - }, - }) case resouremanager.ResourceTypeGNS: resources = append(resources, &common.ResourceFile{ Name: name, @@ -109,7 +132,19 @@ func deleteResources(cmd *cobra.Command, args []string) error { default: return fmt.Errorf("unsupported resource type: %s", resourceType) } + } + if resourceMgr.GetResourceType() == resouremanager.ResourceTypeHEP { + resources = append(resources, &common.ResourceFile{ + Name: fmt.Sprintf("%d_%s", deleteHEPByTenantID, deleteHEPByIP), + Content: &dto.DeleteHostEndpointInput{ + Spec: dto.HostEndpointSpecInput{ + TenantID: deleteHEPByTenantID, + IP: deleteHEPByIP, + IPs: []string{deleteHEPByIP}, + }, + }, + }) } } @@ -118,7 +153,7 @@ func deleteResources(cmd *cobra.Command, args []string) error { for _, r := range resources { err = resourceMgr.Delete(context.Background(), apiServer, r.Content) if err != nil { - fmt.Printf("fail to delete resource from: %v\n", err) + fmt.Printf("fail to delete resource %s from: %v\n", r.Name, err) } else { fmt.Printf("successsfully deleted resource from %s\n", r.Name) numHandled++ diff --git a/cmd/bamboofwcli/command/get.go b/cmd/bamboofwcli/command/get.go index 974d7b2..a592a86 100644 --- a/cmd/bamboofwcli/command/get.go +++ b/cmd/bamboofwcli/command/get.go @@ -16,18 +16,31 @@ import ( "github.com/bamboo-firewall/be/pkg/client" ) -var outputFormat string +var ( + getHEPByTenantID uint64 + getHEPByIP string + outputFormat string +) var getCMD = &cobra.Command{ Use: "get", - Short: "Get resource by name", + Short: "Get resource", Example: ` # Get a global network policy by name - bbfwcli get gnp allow_ssh + bbfw get gnp allow_ssh # Get a global network policy by name with json output format - bbfwcli get gnp allow_ssh -o json + bbfw get gnp allow_ssh -o json + + # Get a host endpoint + bbfw get hep --tenantID=1 --ip=192.168.123.0 + + # Get a global network set by name + bbfw get gns allow_ssh + + # Get a global network set by name with json output format + bbfw get gns my_set -o json `, - Args: cobra.ExactArgs(2), + Args: cobra.MaximumNArgs(2), Run: func(cmd *cobra.Command, args []string) { if err := get(cmd, args); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) @@ -37,6 +50,8 @@ var getCMD = &cobra.Command{ } func init() { + getCMD.Flags().Uint64Var(&getHEPByTenantID, "tenantID", 0, "HEP: get by tenantID") + getCMD.Flags().StringVar(&getHEPByIP, "ip", "", "HEP: get by ip") getCMD.Flags().StringVarP(&outputFormat, "output", "o", "", "output format(yaml|json). Default: yaml") } @@ -47,15 +62,30 @@ func get(cmd *cobra.Command, args []string) error { return err } - resourceName := args[1] + var resourceName string + if len(args) > 1 { + resourceName = args[1] + } var input interface{} switch resourceMgr.GetResourceType() { case resouremanager.ResourceTypeHEP: - input = &dto.GetHostEndpointInput{Name: resourceName} + if getHEPByTenantID == 0 || getHEPByIP == "" { + return fmt.Errorf("get HEP by tenantID or ip is required") + } + input = &dto.GetHostEndpointInput{ + TenantID: getHEPByTenantID, + IP: getHEPByIP, + } case resouremanager.ResourceTypeGNS: + if resourceName == "" { + return fmt.Errorf("no resource name provided") + } input = &dto.GetGNSInput{Name: resourceName} case resouremanager.ResourceTypeGNP: + if resourceName == "" { + return fmt.Errorf("no resource name provided") + } input = &dto.GetGNPInput{Name: resourceName} default: return fmt.Errorf("unsupported resource type: %s", resourceType) @@ -71,7 +101,12 @@ func get(cmd *cobra.Command, args []string) error { var output []byte switch common.FileExtension(outputFormat) { case common.FileExtensionJSON: - output, err = json.MarshalIndent(resource, "", " ") + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + encoder.SetEscapeHTML(false) + encoder.SetIndent("", " ") + err = encoder.Encode(resource) + output = buf.Bytes() default: var buf bytes.Buffer yamlEncoder := yaml.NewEncoder(&buf) diff --git a/cmd/bamboofwcli/command/list.go b/cmd/bamboofwcli/command/list.go new file mode 100644 index 0000000..16a14b2 --- /dev/null +++ b/cmd/bamboofwcli/command/list.go @@ -0,0 +1,140 @@ +package command + +import ( + "bytes" + "context" + "fmt" + "os" + "text/tabwriter" + "text/template" + + "github.com/spf13/cobra" + + "github.com/bamboo-firewall/be/api/v1/dto" + "github.com/bamboo-firewall/be/cmd/bamboofwcli/command/common" + "github.com/bamboo-firewall/be/cmd/bamboofwcli/command/resouremanager" + "github.com/bamboo-firewall/be/pkg/client" +) + +var ( + ListHEPsByTenantID uint64 + ListHEPsByIP string + + ListGNPsByIsOrder bool +) + +var listCMD = &cobra.Command{ + Use: "list", + Short: "List resource", + Example: ` # List global network sets + bbfw list gns + + # List global network policy + bbfw list gnp + + # List global network policy with order + bbfw list gnp --isOrder + + # List host endpoint + bbfw list hep + + # List host endpoint with tenantID + bbfw list hep --tenantID=1 + + # List host endpoint with IP + bbfw list hep --ip=192.168.0.1 + + # List host endpoint with tenantID and IP + bbfw list hep --tenantID=1 --ip=192.168.0.1, +`, + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + if err := list(cmd, args); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + }, +} + +func init() { + listCMD.Flags().Uint64Var(&ListHEPsByTenantID, "tenantID", 0, "Host Endpoint: filter by TenantID") + listCMD.Flags().StringVar(&ListHEPsByIP, "ip", "", "Host Endpoint: filter by IP") + listCMD.Flags().BoolVar(&ListGNPsByIsOrder, "isOrder", false, "Global Network Policy: filter by Order") +} + +func list(cmd *cobra.Command, args []string) error { + resourceType := args[0] + resourceMgr, err := common.GetResourceMgrByType(resourceType) + if err != nil { + return err + } + + var input interface{} + switch resourceMgr.GetResourceType() { + case resouremanager.ResourceTypeHEP: + listHEPsInput := &dto.ListHostEndpointsInput{} + if ListHEPsByTenantID > 0 { + listHEPsInput.TenantID = &ListHEPsByTenantID + } + if ListHEPsByIP != "" { + listHEPsInput.IP = &ListHEPsByIP + } + input = listHEPsInput + case resouremanager.ResourceTypeGNS: + case resouremanager.ResourceTypeGNP: + input = &dto.ListGNPsInput{IsOrder: ListGNPsByIsOrder} + default: + return fmt.Errorf("unsupported resources type: %s", resourceType) + } + + apiServer := client.NewAPIServer(os.Getenv(common.APIServerENV)) + + resources, err := resourceMgr.List(context.Background(), apiServer, input) + if err != nil { + return fmt.Errorf("list resources failed: %w", err) + } + + if err = printResources(resourceMgr, resources); err != nil { + return err + } + return nil +} + +func printResources(resourceMgr resouremanager.Resource, resources interface{}) error { + header := resourceMgr.GetHeader() + headerMap := resourceMgr.GetHeaderMap() + + buf := new(bytes.Buffer) + for _, h := range header { + buf.WriteString(h) + buf.WriteByte('\t') + } + buf.WriteByte('\n') + + buf.WriteString("{{range .}}") + + for _, h := range header { + value, ok := headerMap[h] + if !ok { + continue + } + buf.WriteString(value) + buf.WriteByte('\t') + } + buf.WriteByte('\n') + + buf.WriteString("{{end}}") + + tmpl, err := template.New("list").Parse(buf.String()) + if err != nil { + return fmt.Errorf("parse template failed: %w", err) + } + writer := tabwriter.NewWriter(os.Stdout, 5, 1, 3, ' ', 0) + err = tmpl.Execute(writer, resources) + if err != nil { + return fmt.Errorf("execute template failed: %w", err) + } + writer.Flush() + fmt.Printf("\n") + return nil +} diff --git a/cmd/bamboofwcli/command/resouremanager/gnp.go b/cmd/bamboofwcli/command/resouremanager/gnp.go index 42d2dd0..a11aef4 100644 --- a/cmd/bamboofwcli/command/resouremanager/gnp.go +++ b/cmd/bamboofwcli/command/resouremanager/gnp.go @@ -18,6 +18,11 @@ func (p *gnp) Create(ctx context.Context, apiServer APIServer, resource interfac return apiServer.CreateGNP(ctx, r) } +func (p *gnp) List(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { + r := resource.(*dto.ListGNPsInput) + return apiServer.ListGNPs(ctx, r) +} + func (p *gnp) Get(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { r := resource.(*dto.GetGNPInput) return apiServer.GetGNP(ctx, r) @@ -31,3 +36,16 @@ func (p *gnp) Delete(ctx context.Context, apiServer APIServer, resource interfac func (p *gnp) GetResourceType() ResourceType { return ResourceTypeGNP } + +func (p *gnp) GetHeader() []string { + return []string{"UUID", "NAME", "ORDER", "VERSION"} +} + +func (p *gnp) GetHeaderMap() map[string]string { + return map[string]string{ + "UUID": "{{.UUID}}", + "NAME": "{{.Metadata.Name}}", + "ORDER": "{{.Spec.Order}}", + "VERSION": "{{.Version}}", + } +} diff --git a/cmd/bamboofwcli/command/resouremanager/gns.go b/cmd/bamboofwcli/command/resouremanager/gns.go index 665f549..6a3b90c 100644 --- a/cmd/bamboofwcli/command/resouremanager/gns.go +++ b/cmd/bamboofwcli/command/resouremanager/gns.go @@ -18,6 +18,10 @@ func (s *gns) Create(ctx context.Context, apiServer APIServer, resource interfac return apiServer.CreateGNS(ctx, r) } +func (s *gns) List(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { + return apiServer.ListGNSs(ctx) +} + func (s *gns) Get(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { r := resource.(*dto.GetGNSInput) return apiServer.GetGNS(ctx, r) @@ -31,3 +35,16 @@ func (s *gns) Delete(ctx context.Context, apiServer APIServer, resource interfac func (s *gns) GetResourceType() ResourceType { return ResourceTypeGNS } + +func (s *gns) GetHeader() []string { + return []string{"UUID", "NAME", "NETS", "VERSION"} +} + +func (s *gns) GetHeaderMap() map[string]string { + return map[string]string{ + "UUID": "{{.UUID}}", + "NAME": "{{.Metadata.Name}}", + "NETS": "{{.Spec.Nets}}", + "VERSION": "{{.Version}}", + } +} diff --git a/cmd/bamboofwcli/command/resouremanager/hep.go b/cmd/bamboofwcli/command/resouremanager/hep.go index 0f00ffc..07daf36 100644 --- a/cmd/bamboofwcli/command/resouremanager/hep.go +++ b/cmd/bamboofwcli/command/resouremanager/hep.go @@ -18,6 +18,11 @@ func (h *hep) Create(ctx context.Context, apiServer APIServer, resource interfac return apiServer.CreateHEP(ctx, r) } +func (h *hep) List(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { + r := resource.(*dto.ListHostEndpointsInput) + return apiServer.ListHEPs(ctx, r) +} + func (h *hep) Get(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) { r := resource.(*dto.GetHostEndpointInput) return apiServer.GetHEP(ctx, r) @@ -31,3 +36,18 @@ func (h *hep) Delete(ctx context.Context, apiServer APIServer, resource interfac func (h *hep) GetResourceType() ResourceType { return ResourceTypeHEP } + +func (h *hep) GetHeader() []string { + return []string{"UUID", "NAME", "TENANT_ID", "IP", "IPS", "VERSION"} +} + +func (h *hep) GetHeaderMap() map[string]string { + return map[string]string{ + "UUID": "{{.UUID}}", + "NAME": "{{.Metadata.Name}}", + "TENANT_ID": "{{.Spec.TenantID}}", + "IP": "{{.Spec.IP}}", + "IPS": "{{.Spec.IPs}}", + "VERSION": "{{.Version}}", + } +} diff --git a/cmd/bamboofwcli/command/resouremanager/resource_manager.go b/cmd/bamboofwcli/command/resouremanager/resource_manager.go index 524210c..fcd20ba 100644 --- a/cmd/bamboofwcli/command/resouremanager/resource_manager.go +++ b/cmd/bamboofwcli/command/resouremanager/resource_manager.go @@ -17,19 +17,25 @@ const ( type Resource interface { Create(ctx context.Context, apiServer APIServer, resource interface{}) error + List(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) Get(ctx context.Context, apiServer APIServer, resource interface{}) (interface{}, error) Delete(ctx context.Context, apiServer APIServer, resource interface{}) error GetResourceType() ResourceType + GetHeader() []string + GetHeaderMap() map[string]string } type APIServer interface { CreateHEP(ctx context.Context, input *dto.CreateHostEndpointInput) error + ListHEPs(ctx context.Context, input *dto.ListHostEndpointsInput) ([]*dto.HostEndpoint, error) GetHEP(ctx context.Context, input *dto.GetHostEndpointInput) (*dto.HostEndpoint, error) DeleteHEP(ctx context.Context, input *dto.DeleteHostEndpointInput) error CreateGNS(ctx context.Context, input *dto.CreateGlobalNetworkSetInput) error + ListGNSs(ctx context.Context) ([]*dto.GlobalNetworkSet, error) GetGNS(ctx context.Context, input *dto.GetGNSInput) (*dto.GlobalNetworkSet, error) DeleteGNS(ctx context.Context, input *dto.DeleteGlobalNetworkSetInput) error CreateGNP(ctx context.Context, input *dto.CreateGlobalNetworkPolicyInput) error + ListGNPs(ctx context.Context, input *dto.ListGNPsInput) ([]*dto.GlobalNetworkPolicy, error) GetGNP(ctx context.Context, input *dto.GetGNPInput) (*dto.GlobalNetworkPolicy, error) DeleteGNP(ctx context.Context, input *dto.DeleteGlobalNetworkPolicyInput) error } diff --git a/cmd/bamboofwcli/command/root.go b/cmd/bamboofwcli/command/root.go index 1ef336d..8684ebc 100644 --- a/cmd/bamboofwcli/command/root.go +++ b/cmd/bamboofwcli/command/root.go @@ -22,6 +22,7 @@ Description: func Execute() { rootCMD.AddCommand(createCMD) + rootCMD.AddCommand(listCMD) rootCMD.AddCommand(getCMD) rootCMD.AddCommand(deleteCMD) rootCMD.AddCommand(versionCMD) @@ -29,12 +30,12 @@ func Execute() { rootCMD.AddCommand(&cobra.Command{ Use: "completion", DisableFlagsInUseLine: true, - Short: "Generate bash completion script for shell(bash, zsh)", + Short: "Generate a completion script for bash or zsh shell", Example: ` # Gen completion for bash shell - bbfwcli completion bash + bbfw completion bash # Gen completion for zsh shell - bbfwcli completion zsh`, + bbfw completion zsh`, Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { switch args[0] { @@ -43,7 +44,7 @@ func Execute() { case "zsh": rootCMD.GenZshCompletion(os.Stdout) default: - fmt.Fprintf(os.Stderr, "Unknown shell bash: %s\n", args[0]) + fmt.Fprintf(os.Stderr, "Unknown shell: %s\n", args[0]) } }, }) diff --git a/cmd/server/middleware/log.go b/cmd/server/middleware/log.go new file mode 100644 index 0000000..d8b937a --- /dev/null +++ b/cmd/server/middleware/log.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "fmt" + "time" + + "github.com/gin-gonic/gin" +) + +var LogFormatterMiddleware = func(param gin.LogFormatterParams) string { + var statusColor, methodColor, resetColor string + if param.IsOutputColor() { + statusColor = param.StatusCodeColor() + methodColor = param.MethodColor() + resetColor = param.ResetColor() + } + + if param.Latency > time.Minute { + param.Latency = param.Latency.Truncate(time.Second) + } + return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v | %s \n", + param.TimeStamp.Format(time.RFC3339), + statusColor, param.StatusCode, resetColor, + param.Latency, + param.ClientIP, + methodColor, param.Method, resetColor, + param.Path, + param.ErrorMessage, + ) +} diff --git a/cmd/server/pkg/common/errlist/common_err.go b/cmd/server/pkg/common/errlist/common_err.go index 8631675..ff18a95 100644 --- a/cmd/server/pkg/common/errlist/common_err.go +++ b/cmd/server/pkg/common/errlist/common_err.go @@ -5,9 +5,12 @@ import "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" var ( ErrDatabase = ierror.NewCoreError("err_database", "") - ErrNotFoundHostEndpoint = ierror.NewCoreError("err_not_found_host_endpoint", "") - ErrNotFoundGlobalNetworkPolicy = ierror.NewCoreError("err_not_found_global_network_policy", "") - ErrNotFoundGlobalNetworkSet = ierror.NewCoreError("err_not_found_global_network_set", "") + ErrNotFoundHostEndpoint = ierror.NewCoreError("err_not_found_host_endpoint", "") + ErrNotFoundGlobalNetworkPolicy = ierror.NewCoreError("err_not_found_global_network_policy", "") + ErrNotFoundGlobalNetworkSet = ierror.NewCoreError("err_not_found_global_network_set", "") + ErrDuplicateHostEndpoint = ierror.NewCoreError("err_duplicate_host_endpoint", "") + ErrDuplicateGlobalNetworkPolicy = ierror.NewCoreError("err_duplicate_global_network_policy", "") + ErrDuplicateGlobalNetworkSet = ierror.NewCoreError("err_duplicate_global_network_set", "") ErrUnmarshalFailed = ierror.NewCoreError("err_unmarshal_failed", "") ) diff --git a/cmd/server/pkg/entity/common.go b/cmd/server/pkg/entity/common.go index e34f656..469f43c 100644 --- a/cmd/server/pkg/entity/common.go +++ b/cmd/server/pkg/entity/common.go @@ -1,17 +1,33 @@ package entity -type IPVersion int +import ( + "strings" -const ( - IPVersion4 IPVersion = 4 - IPVersion6 IPVersion = 6 + "github.com/google/uuid" ) -type Protocol string +const ( + IPVersion4 = 4 + IPVersion6 = 6 +) const ( - ProtocolTCP Protocol = "tcp" - ProtocolUDP Protocol = "udp" - ProtocolICMP Protocol = "icmp" - ProtocolSCTP Protocol = "sctp" + ProtocolTCP = "tcp" + ProtocolUDP = "udp" + ProtocolICMP = "icmp" + // ProtocolSCTP Stream Control Transmission Protocol (SCTP) is a network protocol that allows for the reliable transmission of data between two endpoints in a computer network. + // The Stream Control Transmission Protocol (SCTP) is a computer networking communications protocol in the transport layer of the Internet protocol suite. + // Originally intended for Signaling System 7 (SS7) message transport in telecommunication, the protocol provides the message-oriented feature of the User Datagram Protocol (UDP), + // while ensuring reliable, in-sequence transport of messages with congestion control like the Transmission Control Protocol (TCP). + //Unlike UDP and TCP, the protocol supports multihoming and redundant paths to increase resilience and reliability. + ProtocolSCTP = "sctp" + ProtocolUDPLite = "udplite" + + ProtocolNumTCP = 6 + ProtocolNumUDP = 17 + ProtocolNumSCTP = 132 ) + +func NewMinifyUUID() string { + return strings.Replace(uuid.New().String(), "-", "", -1) +} diff --git a/cmd/server/pkg/entity/gnp.go b/cmd/server/pkg/entity/gnp.go index 0dcae22..6a29d2f 100644 --- a/cmd/server/pkg/entity/gnp.go +++ b/cmd/server/pkg/entity/gnp.go @@ -6,6 +6,10 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +const ( + PolicyOrderLowest = ^uint32(0) +) + type RuleAction string const ( @@ -32,6 +36,7 @@ type GNPMetadata struct { } type GNPSpec struct { + Order uint32 `bson:"order"` Selector string `bson:"selector,omitempty"` Ingress []GNPSpecRule `bson:"ingress,omitempty"` Egress []GNPSpecRule `bson:"egress,omitempty"` @@ -40,9 +45,9 @@ type GNPSpec struct { type GNPSpecRule struct { Metadata map[string]string `bson:"metadata,omitempty"` Action string `bson:"action"` - IPVersion IPVersion `bson:"ip_version"` - Protocol string `bson:"protocol,omitempty"` - NotProtocol string `bson:"not_protocol,omitempty"` + IPVersion *int `bson:"ip_version,omitempty"` + Protocol interface{} `bson:"protocol,omitempty"` + NotProtocol interface{} `bson:"not_protocol,omitempty"` Source *GNPSpecRuleEntity `bson:"source,omitempty"` Destination *GNPSpecRuleEntity `bson:"destination,omitempty"` } diff --git a/cmd/server/pkg/entity/gns.go b/cmd/server/pkg/entity/gns.go index 6ab1b83..24162f5 100644 --- a/cmd/server/pkg/entity/gns.go +++ b/cmd/server/pkg/entity/gns.go @@ -6,6 +6,19 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +var ( + GNSEmpty = GlobalNetworkSet{ + ID: primitive.NewObjectID(), + UUID: NewMinifyUUID(), + Version: 1, + Metadata: GNSMetadata{ + Name: "default-empty", + }, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +) + type GlobalNetworkSet struct { ID primitive.ObjectID `bson:"_id"` UUID string `bson:"uuid"` diff --git a/cmd/server/pkg/entity/hep.go b/cmd/server/pkg/entity/hep.go index 1ba8b8d..7ce27cd 100644 --- a/cmd/server/pkg/entity/hep.go +++ b/cmd/server/pkg/entity/hep.go @@ -6,6 +6,10 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" ) +const ( + DefaultTenantID uint64 = 1 +) + type HostEndpoint struct { ID primitive.ObjectID `bson:"_id"` UUID string `bson:"uuid"` @@ -24,6 +28,8 @@ type HostEndpointMetadata struct { type HostEndpointSpec struct { InterfaceName string `bson:"interface_name"` + IP uint32 `json:"ip"` + TenantID uint64 `bson:"tenant_id"` IPs []string `bson:"ips"` IPsV4 []string `bson:"ips_v4,omitempty"` IPsV6 []string `bson:"ips_v6,omitempty"` diff --git a/cmd/server/pkg/httpbase/error.go b/cmd/server/pkg/httpbase/error.go index 449f794..ded9e44 100644 --- a/cmd/server/pkg/httpbase/error.go +++ b/cmd/server/pkg/httpbase/error.go @@ -60,6 +60,10 @@ var ( return newClientIError(ctx, ErrorCodeBadRequest, msgID).SetHTTPStatus(http.StatusBadRequest) } + ErrBadRequest = func(ctx context.Context, msgID string) *ierror.Error { + return newClientIError(ctx, ErrorCodeBadRequest, msgID).SetHTTPStatus(http.StatusBadRequest) + } + ErrDatabase = func(ctx context.Context, msgID string) *ierror.Error { return newClientIError(ctx, ErrorCodeDatabase, msgID).SetHTTPStatus(http.StatusInternalServerError) } diff --git a/cmd/server/pkg/net/ip.go b/cmd/server/pkg/net/ip.go index 1b0401d..011dca1 100644 --- a/cmd/server/pkg/net/ip.go +++ b/cmd/server/pkg/net/ip.go @@ -1,6 +1,7 @@ package net import ( + "encoding/binary" "encoding/json" "net" ) @@ -62,3 +63,16 @@ func (i IP) Network() *IPNet { } return ipnet } + +func IPToInt(ip IP) uint32 { + if len(ip.IP) == 16 { + return binary.BigEndian.Uint32(ip.IP[12:16]) + } + return binary.BigEndian.Uint32(ip.IP) +} + +func IntToIP(nn uint32) IP { + ip := make(net.IP, 4) + binary.BigEndian.PutUint32(ip, nn) + return IP{ip} +} diff --git a/cmd/server/pkg/repository/gnp.go b/cmd/server/pkg/repository/gnp.go index aca8735..bb04b89 100644 --- a/cmd/server/pkg/repository/gnp.go +++ b/cmd/server/pkg/repository/gnp.go @@ -3,26 +3,79 @@ package repository import ( "context" "errors" + "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/writeconcern" "github.com/bamboo-firewall/be/cmd/server/pkg/common/errlist" "github.com/bamboo-firewall/be/cmd/server/pkg/entity" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" + "github.com/bamboo-firewall/be/domain/model" ) -func (r *PolicyDB) UpsertGroupPolicy(ctx context.Context, gnp *entity.GlobalNetworkPolicy) *ierror.CoreError { - filter := bson.D{{Key: "_id", Value: gnp.ID}} - update := bson.D{{Key: "$set", Value: gnp}} - opts := options.Update().SetUpsert(true) - _, err := r.mongo.Database.Collection(gnp.CollectionName()).UpdateOne(ctx, filter, update, opts) +func (r *PolicyDB) UpsertGroupPolicy(ctx context.Context, gnp *entity.GlobalNetworkPolicy) (*entity.GlobalNetworkPolicy, *ierror.CoreError) { + session, err := r.mongo.Database.Client().StartSession() if err != nil { - return errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(err) } + defer session.EndSession(ctx) - return nil + sessionCallback := func(sessionCtx mongo.SessionContext) (interface{}, error) { + filter := bson.D{{Key: "metadata.name", Value: gnp.Metadata.Name}} + existedGNP := new(entity.GlobalNetworkPolicy) + err = r.mongo.Database.Collection(gnp.CollectionName()).FindOne(ctx, filter).Decode(existedGNP) + if err != nil && !errors.Is(mongo.ErrNoDocuments, err) { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("find global network policy failed: %w", err)) + } + + // gnp is existed + if !errors.Is(mongo.ErrNoDocuments, err) { + gnp.ID = existedGNP.ID + gnp.UUID = existedGNP.UUID + gnp.Version = existedGNP.Version + gnp.CreatedAt = existedGNP.CreatedAt + } + + filter = bson.D{{Key: "_id", Value: gnp.ID}} + update := bson.D{{Key: "$set", Value: gnp}} + opts := options.Update().SetUpsert(true) + _, err = r.mongo.Database.Collection(gnp.CollectionName()).UpdateOne(ctx, filter, update, opts) + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return nil, errlist.ErrDuplicateGlobalNetworkPolicy.WithChild(fmt.Errorf("global network policy already exists: %w", err)) + } + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update gnp failed: %w", err)) + } + + updateVersion := bson.M{ + "$inc": bson.M{ + "version": 1, + }, + } + optUpdateVersions := options.FindOneAndUpdate().SetReturnDocument(options.After) + err = r.mongo.Database.Collection(gnp.CollectionName()).FindOneAndUpdate(ctx, filter, updateVersion, optUpdateVersions).Decode(gnp) + if err != nil { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update version gnp failed: %w", err)) + } + + return nil, nil + } + + opts := options.Transaction().SetWriteConcern(writeconcern.Majority()).SetReadConcern(readconcern.Snapshot()) + _, sessionErr := session.WithTransaction(ctx, sessionCallback, opts) + if sessionErr != nil { + var coreErr *ierror.CoreError + if errors.As(sessionErr, &coreErr) { + return nil, coreErr + } + return nil, errlist.ErrDatabase.WithChild(sessionErr) + } + + return gnp, nil } func (r *PolicyDB) GetGNPByName(ctx context.Context, name string) (*entity.GlobalNetworkPolicy, *ierror.CoreError) { @@ -34,7 +87,7 @@ func (r *PolicyDB) GetGNPByName(ctx context.Context, name string) (*entity.Globa if errors.Is(err, mongo.ErrNoDocuments) { return nil, errlist.ErrNotFoundGlobalNetworkPolicy } - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("find global network policy failed: %w", err)) } return gnp, nil } @@ -44,19 +97,25 @@ func (r *PolicyDB) DeleteGNPByName(ctx context.Context, name string) *ierror.Cor _, err := r.mongo.Database.Collection(entity.GlobalNetworkPolicy{}.CollectionName()).DeleteOne(ctx, filter) if err != nil { - return errlist.ErrDatabase.WithChild(err) + return errlist.ErrDatabase.WithChild(fmt.Errorf("delete global network policy failed: %w", err)) } return nil } -func (r *PolicyDB) ListGNP(ctx context.Context) ([]*entity.GlobalNetworkPolicy, *ierror.CoreError) { +func (r *PolicyDB) ListGNPs(ctx context.Context, input *model.ListGNPsInput) ([]*entity.GlobalNetworkPolicy, *ierror.CoreError) { + var opts []*options.FindOptions + if input != nil { + if input.IsOrder { + opts = append(opts, options.Find().SetSort(bson.D{{"spec.order", 1}})) + } + } policies := make([]*entity.GlobalNetworkPolicy, 0) - cursor, err := r.mongo.Database.Collection(entity.GlobalNetworkPolicy{}.CollectionName()).Find(ctx, bson.D{}) + cursor, err := r.mongo.Database.Collection(entity.GlobalNetworkPolicy{}.CollectionName()).Find(ctx, bson.D{}, opts...) if err != nil { - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("list global network policies failed: %w", err)) } if err = cursor.All(ctx, &policies); err != nil { - return nil, errlist.ErrUnmarshalFailed.WithChild(err) + return nil, errlist.ErrUnmarshalFailed.WithChild(fmt.Errorf("decode global network policies failed: %w", err)) } return policies, nil } diff --git a/cmd/server/pkg/repository/gns.go b/cmd/server/pkg/repository/gns.go index 9ebfeed..d7884a4 100644 --- a/cmd/server/pkg/repository/gns.go +++ b/cmd/server/pkg/repository/gns.go @@ -3,26 +3,79 @@ package repository import ( "context" "errors" + "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/writeconcern" "github.com/bamboo-firewall/be/cmd/server/pkg/common/errlist" "github.com/bamboo-firewall/be/cmd/server/pkg/entity" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" ) -func (r *PolicyDB) UpsertGNS(ctx context.Context, gns *entity.GlobalNetworkSet) *ierror.CoreError { - filter := bson.D{{Key: "_id", Value: gns.ID}} - update := bson.D{{Key: "$set", Value: gns}} - opts := options.Update().SetUpsert(true) - _, err := r.mongo.Database.Collection(gns.CollectionName()).UpdateOne(ctx, filter, update, opts) +func (r *PolicyDB) UpsertGNS(ctx context.Context, gns *entity.GlobalNetworkSet) (*entity.GlobalNetworkSet, *ierror.CoreError) { + session, err := r.mongo.Database.Client().StartSession() if err != nil { - return errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(err) } + defer session.EndSession(ctx) - return nil + sessionCallback := func(sessionCtx mongo.SessionContext) (interface{}, error) { + filter := bson.D{{Key: "metadata.name", Value: gns.Metadata.Name}} + existedGNS := new(entity.GlobalNetworkSet) + err = r.mongo.Database.Collection(gns.CollectionName()).FindOne(ctx, filter).Decode(existedGNS) + if err != nil && !errors.Is(mongo.ErrNoDocuments, err) { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("find global network set failed: %w", err)) + } + + // gns is existed + if !errors.Is(mongo.ErrNoDocuments, err) { + gns.ID = existedGNS.ID + gns.UUID = existedGNS.UUID + gns.Version = existedGNS.Version + gns.CreatedAt = existedGNS.CreatedAt + } + + filter = bson.D{{Key: "_id", Value: gns.ID}} + update := bson.D{{Key: "$set", Value: gns}} + opts := options.Update().SetUpsert(true) + _, err = r.mongo.Database.Collection(gns.CollectionName()).UpdateOne(ctx, filter, update, opts) + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return nil, errlist.ErrDuplicateGlobalNetworkSet. + WithChild(fmt.Errorf("global network set already exists: %w", err)) + } + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update gns failed: %w", err)) + } + + updateVersion := bson.M{ + "$inc": bson.M{ + "version": 1, + }, + } + optUpdateVersions := options.FindOneAndUpdate().SetReturnDocument(options.After) + err = r.mongo.Database.Collection(gns.CollectionName()).FindOneAndUpdate(ctx, filter, updateVersion, optUpdateVersions).Decode(gns) + if err != nil { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update version gns failed: %w", err)) + } + + return nil, nil + } + + opts := options.Transaction().SetWriteConcern(writeconcern.Majority()).SetReadConcern(readconcern.Snapshot()) + _, sessionErr := session.WithTransaction(ctx, sessionCallback, opts) + if sessionErr != nil { + var coreErr *ierror.CoreError + if errors.As(sessionErr, &coreErr) { + return nil, coreErr + } + return nil, errlist.ErrDatabase.WithChild(sessionErr) + } + + return gns, nil } func (r *PolicyDB) GetGNSByName(ctx context.Context, name string) (*entity.GlobalNetworkSet, *ierror.CoreError) { @@ -34,7 +87,7 @@ func (r *PolicyDB) GetGNSByName(ctx context.Context, name string) (*entity.Globa if errors.Is(err, mongo.ErrNoDocuments) { return nil, errlist.ErrNotFoundGlobalNetworkSet } - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("find global network set failed: %w", err)) } return gns, nil } @@ -44,19 +97,19 @@ func (r *PolicyDB) DeleteGNSByName(ctx context.Context, name string) *ierror.Cor _, err := r.mongo.Database.Collection(entity.GlobalNetworkSet{}.CollectionName()).DeleteOne(ctx, filter) if err != nil { - return errlist.ErrDatabase.WithChild(err) + return errlist.ErrDatabase.WithChild(fmt.Errorf("delete global network set failed: %w", err)) } return nil } -func (r *PolicyDB) ListGNS(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.CoreError) { +func (r *PolicyDB) ListGNSs(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.CoreError) { sets := make([]*entity.GlobalNetworkSet, 0) cursor, err := r.mongo.Database.Collection(entity.GlobalNetworkSet{}.CollectionName()).Find(ctx, bson.D{}) if err != nil { - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("list global network sets failed: %w", err)) } if err = cursor.All(ctx, &sets); err != nil { - return nil, errlist.ErrUnmarshalFailed.WithChild(err) + return nil, errlist.ErrUnmarshalFailed.WithChild(fmt.Errorf("decode global network sets failed: %w", err)) } return sets, nil } diff --git a/cmd/server/pkg/repository/hep.go b/cmd/server/pkg/repository/hep.go index b5dc82c..442aa06 100644 --- a/cmd/server/pkg/repository/hep.go +++ b/cmd/server/pkg/repository/hep.go @@ -3,30 +3,89 @@ package repository import ( "context" "errors" + "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/mongo/writeconcern" "github.com/bamboo-firewall/be/cmd/server/pkg/common/errlist" "github.com/bamboo-firewall/be/cmd/server/pkg/entity" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" + "github.com/bamboo-firewall/be/domain/model" ) -func (r *PolicyDB) UpsertHostEndpoint(ctx context.Context, hep *entity.HostEndpoint) *ierror.CoreError { - filter := bson.D{{Key: "_id", Value: hep.ID}} - update := bson.D{{Key: "$set", Value: hep}} - opts := options.Update().SetUpsert(true) - _, err := r.mongo.Database.Collection(hep.CollectionName()).UpdateOne(ctx, filter, update, opts) +func (r *PolicyDB) UpsertHostEndpoint(ctx context.Context, hep *entity.HostEndpoint) (*entity.HostEndpoint, *ierror.CoreError) { + session, err := r.mongo.Database.Client().StartSession() if err != nil { - return errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(err) } + defer session.EndSession(ctx) - return nil + sessionCallback := func(sessionCtx mongo.SessionContext) (interface{}, error) { + filter := bson.D{{Key: "spec.tenant_id", Value: hep.Spec.TenantID}, {Key: "spec.ip", Value: hep.Spec.IP}} + existedHEP := new(entity.HostEndpoint) + err = r.mongo.Database.Collection(hep.CollectionName()).FindOne(ctx, filter).Decode(existedHEP) + if err != nil && !errors.Is(mongo.ErrNoDocuments, err) { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("find host endpoint failed: %w", err)) + } + + // hep is existed + if !errors.Is(mongo.ErrNoDocuments, err) { + hep.ID = existedHEP.ID + hep.UUID = existedHEP.UUID + hep.Version = existedHEP.Version + hep.CreatedAt = existedHEP.CreatedAt + } + + filter = bson.D{{Key: "_id", Value: hep.ID}} + update := bson.D{{Key: "$set", Value: hep}} + opts := options.Update().SetUpsert(true) + _, err = r.mongo.Database.Collection(hep.CollectionName()).UpdateOne(ctx, filter, update, opts) + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return nil, errlist.ErrDuplicateHostEndpoint.WithChild(fmt.Errorf("host endpoint already exists: %w", err)) + } + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update host endpoint failed: %w", err)) + } + + updateVersion := bson.M{ + "$inc": bson.M{ + "version": 1, + }, + } + optUpdateVersions := options.FindOneAndUpdate().SetReturnDocument(options.After) + err = r.mongo.Database.Collection(hep.CollectionName()).FindOneAndUpdate(ctx, filter, updateVersion, optUpdateVersions).Decode(hep) + if err != nil { + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("update version host endpoint failed: %w", err)) + } + + return nil, nil + } + + opts := options.Transaction().SetWriteConcern(writeconcern.Majority()).SetReadConcern(readconcern.Snapshot()) + _, sessionErr := session.WithTransaction(ctx, sessionCallback, opts) + if sessionErr != nil { + var coreErr *ierror.CoreError + if errors.As(sessionErr, &coreErr) { + return nil, coreErr + } + return nil, errlist.ErrDatabase.WithChild(sessionErr) + } + + return hep, nil } -func (r *PolicyDB) GetHostEndpointByName(ctx context.Context, name string) (*entity.HostEndpoint, *ierror.CoreError) { - filter := bson.D{{Key: "metadata.name", Value: name}} +func (r *PolicyDB) GetHostEndpoint(ctx context.Context, input *model.GetHostEndpointInput) (*entity.HostEndpoint, *ierror.CoreError) { + var filter bson.D + if input != nil { + filter = bson.D{ + {Key: "spec.tenant_id", Value: input.TenantID}, + {Key: "spec.ip", Value: input.IP}, + } + } hep := new(entity.HostEndpoint) err := r.mongo.Database.Collection(hep.CollectionName()).FindOne(ctx, filter).Decode(hep) @@ -34,29 +93,39 @@ func (r *PolicyDB) GetHostEndpointByName(ctx context.Context, name string) (*ent if errors.Is(err, mongo.ErrNoDocuments) { return nil, errlist.ErrNotFoundHostEndpoint } - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("get host endpoint failed: %w", err)) } return hep, nil } -func (r *PolicyDB) DeleteHostEndpointByName(ctx context.Context, name string) *ierror.CoreError { - filter := bson.D{{Key: "metadata.name", Value: name}} +func (r *PolicyDB) DeleteHostEndpoint(ctx context.Context, tenantID uint64, ip uint32) *ierror.CoreError { + filter := bson.D{{Key: "spec.tenant_id", Value: tenantID}, {Key: "spec.ip", Value: ip}} _, err := r.mongo.Database.Collection(entity.HostEndpoint{}.CollectionName()).DeleteOne(ctx, filter) if err != nil { - return errlist.ErrDatabase.WithChild(err) + return errlist.ErrDatabase.WithChild(fmt.Errorf("delete host endpoint failed: %w", err)) } return nil } -func (r *PolicyDB) ListHostEndpoints(ctx context.Context) ([]*entity.HostEndpoint, *ierror.CoreError) { +func (r *PolicyDB) ListHostEndpoints(ctx context.Context, input *model.ListHostEndpointsInput) ([]*entity.HostEndpoint, *ierror.CoreError) { + filter := bson.D{} + if input != nil { + if input.TenantID != nil { + filter = append(filter, bson.E{Key: "spec.tenant_id", Value: *input.TenantID}) + } + if input.IP != nil { + filter = append(filter, bson.E{Key: "spec.ip", Value: *input.IP}) + } + } + heps := make([]*entity.HostEndpoint, 0) - cursor, err := r.mongo.Database.Collection(entity.HostEndpoint{}.CollectionName()).Find(ctx, bson.D{}) + cursor, err := r.mongo.Database.Collection(entity.HostEndpoint{}.CollectionName()).Find(ctx, filter) if err != nil { - return nil, errlist.ErrDatabase.WithChild(err) + return nil, errlist.ErrDatabase.WithChild(fmt.Errorf("list host endpoints failed: %w", err)) } if err = cursor.All(ctx, &heps); err != nil { - return nil, errlist.ErrUnmarshalFailed.WithChild(err) + return nil, errlist.ErrUnmarshalFailed.WithChild(fmt.Errorf("decode host endpoints failed: %w", err)) } return heps, nil } diff --git a/cmd/server/pkg/storage/mongo.go b/cmd/server/pkg/storage/mongo.go index 0c6661a..01ae824 100644 --- a/cmd/server/pkg/storage/mongo.go +++ b/cmd/server/pkg/storage/mongo.go @@ -2,12 +2,16 @@ package storage import ( "context" + "fmt" "log/slog" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + + "github.com/bamboo-firewall/be/cmd/server/pkg/entity" ) type PolicyDB struct { @@ -28,7 +32,59 @@ func NewPolicyDB(uri string) (*PolicyDB, error) { if err = client.Ping(context.Background(), readpref.Primary()); err != nil { return nil, err } - return &PolicyDB{Database: client.Database(cs.Database)}, nil + pm := &PolicyDB{ + Database: client.Database(cs.Database), + } + if err = pm.createIndexes(); err != nil { + return nil, err + } + return pm, nil +} + +func (pm *PolicyDB) createIndexes() error { + indexMap := map[string][]mongo.IndexModel{ + entity.HostEndpoint{}.CollectionName(): { + { + Keys: bson.D{{Key: "spec.tenant_id", Value: 1}, {Key: "spec.ip", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{{Key: "uuid", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + }, + entity.GlobalNetworkSet{}.CollectionName(): { + { + Keys: bson.D{{Key: "metadata.name", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{{Key: "uuid", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + }, + entity.GlobalNetworkPolicy{}.CollectionName(): { + { + Keys: bson.D{{Key: "metadata.name", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + { + Keys: bson.D{{Key: "spec.order", Value: 1}}, + }, + { + Keys: bson.D{{Key: "uuid", Value: 1}}, + Options: options.Index().SetUnique(true), + }, + }, + } + for collectName, indexes := range indexMap { + _, err := pm.Database.Collection(collectName).Indexes().CreateMany(context.TODO(), indexes) + if err != nil { + return fmt.Errorf("create index: %w", err) + } + } + + return nil } func (pm *PolicyDB) Stop(ctx context.Context) error { diff --git a/cmd/server/pkg/validator/validator.go b/cmd/server/pkg/validator/validator.go index f902336..81fa909 100644 --- a/cmd/server/pkg/validator/validator.go +++ b/cmd/server/pkg/validator/validator.go @@ -47,7 +47,7 @@ func Init() { registerStructValidation(validateGNSSpecInput, dto.GNSSpecInput{}) } -var nameRegex = regexp.MustCompile(`^[a-z_-]+$`) +var nameRegex = regexp.MustCompile(`^[-a-zA-Z0-9_\\.]+$`) func validateName(fl validator.FieldLevel) bool { return nameRegex.MatchString(fl.Field().String()) @@ -72,12 +72,20 @@ func validateAction(fl validator.FieldLevel) bool { func validateIPVersion(fl validator.FieldLevel) bool { ipVersion := fl.Field().Interface().(int) - return slices.Contains([]entity.IPVersion{entity.IPVersion4, entity.IPVersion6}, entity.IPVersion(ipVersion)) + return slices.Contains([]int{entity.IPVersion4, entity.IPVersion6}, ipVersion) } func validateProtocol(fl validator.FieldLevel) bool { - protocol := fl.Field().Interface().(string) - return slices.Contains([]entity.Protocol{entity.ProtocolTCP, entity.ProtocolUDP, entity.ProtocolICMP, entity.ProtocolSCTP}, entity.Protocol(strings.ToLower(protocol))) + protocol := fl.Field().Interface() + switch protocol.(type) { + case string: + return slices.Contains([]string{entity.ProtocolTCP, entity.ProtocolUDP, entity.ProtocolICMP, entity.ProtocolSCTP, entity.ProtocolUDPLite}, strings.ToLower(protocol.(string))) + case float64: + protocolNum := uint8(protocol.(float64)) + return protocolNum != 0 + default: + return false + } } func validateCIDR(fl validator.FieldLevel) bool { @@ -143,11 +151,11 @@ func validateGNPSpecInput(sl validator.StructLevel) { func validateGNPSpecRuleInput(sl validator.StructLevel) { input := sl.Current().Interface().(dto.GNPSpecRuleInput) - if input.Protocol != "" && input.NotProtocol != "" { + if input.Protocol != nil && input.NotProtocol != nil { sl.ReportError(input.NotProtocol, "notProtocol", "NotProtocol", "cannot use notProtocol with protocol", "") } - if input.Protocol != "" || input.NotProtocol != "" { - if (input.Protocol != "" && !isProtocolSupportPort(input.Protocol)) || (input.NotProtocol != "" && !isProtocolSupportPort(input.NotProtocol)) { + if input.Protocol != nil || input.NotProtocol != nil { + if (input.Protocol != nil && !isProtocolSupportPort(input.Protocol)) || (input.NotProtocol != nil && !isProtocolSupportPort(input.NotProtocol)) { if input.Source != nil { if len(input.Source.Ports) > 0 { sl.ReportError(input.Source.Ports, "notPorts", "NotPorts", "protocol not support ports", "") @@ -168,33 +176,55 @@ func validateGNPSpecRuleInput(sl validator.StructLevel) { } } + var ( + seenV4, seenV6 bool + ) + var scanNets = func(nets []string, fieldName string) { + var v4, v6 bool + for i, ipNetwork := range nets { + ip, ipnet, err := net.ParseCIDR(ipNetwork) + if err != nil { + sl.ReportError(ipNetwork, fmt.Sprintf("%s[%d]", fieldName, i), "", "net", "") + continue + } + if ip.String() != ipnet.IP.String() { + sl.ReportError(ipNetwork, fmt.Sprintf("%s[%d]", fieldName, i), "", "ip network is invalid", "") + } + if input.IPVersion != nil && ip.Version() != *input.IPVersion { + sl.ReportError(ipNetwork, fmt.Sprintf("%s[%d]", fieldName, i), "", "not match with ipVersion", "") + } + + v4 = v4 || ip.Version() == entity.IPVersion4 + v6 = v6 || ip.Version() == entity.IPVersion6 + } + + if v4 && seenV6 || v6 && seenV4 || v4 && v6 { + sl.ReportError(nets, fieldName, "", "cannot use ipV4 and ipV6 together", "") + } + + seenV4 = seenV4 || v4 + seenV6 = seenV6 || v6 + } + if input.Source != nil { - isNetSameIPVersion(sl, input.IPVersion, input.Source.Nets) - isNetSameIPVersion(sl, input.IPVersion, input.Source.NotNets) + scanNets(input.Source.Nets, "Source.Nets") + scanNets(input.Source.NotNets, "Source.NotNets") } if input.Destination != nil { - isNetSameIPVersion(sl, input.IPVersion, input.Destination.Nets) - isNetSameIPVersion(sl, input.IPVersion, input.Destination.NotNets) + scanNets(input.Destination.Nets, "Destination.Nets") + scanNets(input.Destination.NotNets, "Destination.NotNets") } } -func isProtocolSupportPort(protocol string) bool { - return slices.Contains([]entity.Protocol{entity.ProtocolTCP, entity.ProtocolUDP, entity.ProtocolSCTP}, entity.Protocol(strings.ToLower(protocol))) -} - -func isNetSameIPVersion(sl validator.StructLevel, ipVersion int, nets []string) { - for i, ipNetwork := range nets { - ip, ipnet, err := net.ParseCIDROrIP(ipNetwork) - if err != nil { - sl.ReportError(ipNetwork, fmt.Sprintf("nets[%d]", i), "", "net", "") - continue - } - if ip.String() != ipnet.IP.String() { - sl.ReportError(ipNetwork, fmt.Sprintf("nets[%d]", i), "", "ip network is invalid", "") - } - if ip.Version() != ipVersion { - sl.ReportError(ipNetwork, fmt.Sprintf("nets[%d]", i), "", "not match with ipVersion", "") - } +func isProtocolSupportPort(protocol interface{}) bool { + switch protocol.(type) { + case string: + return slices.Contains([]string{entity.ProtocolTCP, entity.ProtocolUDP, entity.ProtocolSCTP}, strings.ToLower(protocol.(string))) + case float64: + protocolNum := uint8(protocol.(float64)) + return protocolNum == entity.ProtocolNumTCP || protocolNum == entity.ProtocolNumUDP || protocolNum == entity.ProtocolNumSCTP + default: + return false } } diff --git a/cmd/server/route/route.go b/cmd/server/route/route.go index 94c6843..c427bb8 100644 --- a/cmd/server/route/route.go +++ b/cmd/server/route/route.go @@ -16,20 +16,23 @@ func RegisterHandler(repo *repository.PolicyDB) http.Handler { router.Use(gin.Recovery()) router.Use(middleware.CORS()) + router.Use(gin.LoggerWithFormatter(middleware.LogFormatterMiddleware)) router.GET("/api/v1/ping", handler.Ping) { hepHandler := handler.NewHEP(service.NewHEP(repo)) router.POST("/api/v1/hostEndpoints", hepHandler.Create) - router.GET("/api/v1/hostEndpoints/byName/:name", hepHandler.Get) + router.GET("/api/v1/hostEndpoints", hepHandler.List) + router.GET("/api/v1/hostEndpoints/byTenantID/:tenantID/byIP/:ip", hepHandler.Get) router.DELETE("/api/v1/hostEndpoints", hepHandler.Delete) - router.GET("/api/internal/v1/hostEndpoints/byName/:name/fetchPolicies", hepHandler.FetchPolicies) + router.GET("/api/internal/v1/hostEndpoints/fetchPolicies", hepHandler.FetchPolicies) } { gnpHandler := handler.NewGNP(service.NewGNP(repo)) router.POST("/api/v1/globalNetworkPolicies", gnpHandler.Create) + router.GET("/api/v1/globalNetworkPolicies", gnpHandler.List) router.GET("/api/v1/globalNetworkPolicies/byName/:name", gnpHandler.Get) router.DELETE("/api/v1/globalNetworkPolicies", gnpHandler.Delete) } @@ -37,6 +40,7 @@ func RegisterHandler(repo *repository.PolicyDB) http.Handler { { gnsHandler := handler.NewGNS(service.NewGNS(repo)) router.POST("/api/v1/globalNetworkSets", gnsHandler.Create) + router.GET("/api/v1/globalNetworkSets", gnsHandler.List) router.GET("/api/v1/globalNetworkSets/byName/:name", gnsHandler.Get) router.DELETE("/api/v1/globalNetworkSets", gnsHandler.Delete) } diff --git a/domain/model/gnp.go b/domain/model/gnp.go index 8d88bc2..8dc47c4 100644 --- a/domain/model/gnp.go +++ b/domain/model/gnp.go @@ -12,6 +12,7 @@ type GNPMetadataInput struct { } type GNPSpecInput struct { + Order *uint32 Selector string Ingress []GNPSpecRuleInput Egress []GNPSpecRuleInput @@ -20,9 +21,9 @@ type GNPSpecInput struct { type GNPSpecRuleInput struct { Metadata map[string]string Action string - Protocol string - NotProtocol string - IPVersion int + Protocol interface{} + NotProtocol interface{} + IPVersion *int Source *GNPSpecRuleEntityInput Destination *GNPSpecRuleEntityInput } @@ -34,3 +35,7 @@ type GNPSpecRuleEntityInput struct { Ports []interface{} NotPorts []interface{} } + +type ListGNPsInput struct { + IsOrder bool +} diff --git a/domain/model/hep.go b/domain/model/hep.go index 3dd9e53..11c46b4 100644 --- a/domain/model/hep.go +++ b/domain/model/hep.go @@ -15,6 +15,8 @@ type HostEndpointMetadataInput struct { type HostEndpointSpecInput struct { InterfaceName string + IP string + TenantID uint64 IPs []string Ports []HostEndpointSpecPortInput } @@ -25,19 +27,31 @@ type HostEndpointSpecPortInput struct { Protocol string } -type FetchHostEndpointPolicyInput struct { - Name string +type ListHostEndpointsInput struct { + TenantID *uint64 + IP *uint32 } -type HostEndPointPolicy struct { - MetaData HostEndPointPolicyMetadata +type GetHostEndpointInput struct { + TenantID uint64 + IP uint32 +} + +type DeleteHostEndpointInput struct { + TenantID uint64 + IP string + IPs []string +} + +type HostEndpointPolicy struct { + MetaData HostEndpointPolicyMetadata HEP *entity.HostEndpoint ParsedGNPs []*ParsedGNP ParsedHEPs []*ParsedHEP ParsedGNSs []*ParsedGNS } -type HostEndPointPolicyMetadata struct { +type HostEndpointPolicyMetadata struct { GNPVersions map[string]uint HEPVersions map[string]uint GNSVersions map[string]uint @@ -53,8 +67,8 @@ type ParsedGNP struct { type ParsedRule struct { Action string - IPVersion int - Protocol string + IPVersion *int + Protocol interface{} IsProtocolNegative bool SrcNets []string IsSrcNetNegative bool @@ -71,10 +85,12 @@ type ParsedRule struct { } type ParsedHEP struct { - UUID string - Name string - IPsV4 []string - IPsV6 []string + UUID string + Name string + TenantID uint64 + IP string + IPsV4 []string + IPsV6 []string } type ParsedGNS struct { diff --git a/domain/service/gnp.go b/domain/service/gnp.go index 4a667d6..c320e7d 100644 --- a/domain/service/gnp.go +++ b/domain/service/gnp.go @@ -5,7 +5,6 @@ import ( "errors" "time" - "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson/primitive" "github.com/bamboo-firewall/be" @@ -13,6 +12,7 @@ import ( "github.com/bamboo-firewall/be/cmd/server/pkg/entity" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" + "github.com/bamboo-firewall/be/cmd/server/pkg/net" "github.com/bamboo-firewall/be/cmd/server/pkg/repository" "github.com/bamboo-firewall/be/domain/model" ) @@ -28,31 +28,31 @@ type gnp struct { } func (ds *gnp) Create(ctx context.Context, input *model.CreateGlobalNetworkPolicyInput) (*entity.GlobalNetworkPolicy, *ierror.Error) { - // ToDo: use transaction and lock row - gnpExisted, coreErr := ds.storage.GetGNPByName(ctx, input.Metadata.Name) - if coreErr != nil && !errors.Is(coreErr, errlist.ErrNotFoundGlobalNetworkPolicy) { - return nil, httpbase.ErrDatabase(ctx, "get global network policy failed").SetSubError(coreErr) - } - var specIngress []entity.GNPSpecRule for _, rule := range input.Spec.Ingress { specIngress = append(specIngress, modelToRule(rule)) } + var order uint32 + if input.Spec.Order != nil { + order = *input.Spec.Order + } else { + order = entity.PolicyOrderLowest + } var specEgress []entity.GNPSpecRule for _, rule := range input.Spec.Egress { specEgress = append(specEgress, modelToRule(rule)) } gnpEntity := &entity.GlobalNetworkPolicy{ - ID: primitive.NewObjectID(), - UUID: uuid.New().String(), - Version: 1, + ID: primitive.NewObjectID(), + UUID: entity.NewMinifyUUID(), Metadata: entity.GNPMetadata{ Name: input.Metadata.Name, Labels: input.Metadata.Labels, }, Spec: entity.GNPSpec{ + Order: order, Selector: input.Spec.Selector, Ingress: specIngress, Egress: specEgress, @@ -61,15 +61,13 @@ func (ds *gnp) Create(ctx context.Context, input *model.CreateGlobalNetworkPolic CreatedAt: time.Now(), UpdatedAt: time.Now(), } - if gnpExisted != nil { - gnpEntity.ID = gnpExisted.ID - gnpEntity.UUID = gnpExisted.UUID - gnpEntity.Version = gnpExisted.Version + 1 - gnpEntity.CreatedAt = gnpExisted.CreatedAt - } - if coreErr = ds.storage.UpsertGroupPolicy(ctx, gnpEntity); coreErr != nil { - return nil, httpbase.ErrDatabase(ctx, "create global network failed").SetSubError(coreErr) + gnpEntity, coreErr := ds.storage.UpsertGroupPolicy(ctx, gnpEntity) + if coreErr != nil { + if errors.Is(coreErr, errlist.ErrDuplicateGlobalNetworkPolicy) { + return nil, httpbase.ErrBadRequest(ctx, "duplicate global network policy").SetSubError(coreErr) + } + return nil, httpbase.ErrDatabase(ctx, "create global network policy failed").SetSubError(coreErr) } return gnpEntity, nil } @@ -92,13 +90,21 @@ func (ds *gnp) Delete(ctx context.Context, name string) *ierror.Error { return nil } +func (ds *gnp) List(ctx context.Context, input *model.ListGNPsInput) ([]*entity.GlobalNetworkPolicy, *ierror.Error) { + gnpsEntity, coreErr := ds.storage.ListGNPs(ctx, input) + if coreErr != nil { + return nil, httpbase.ErrDatabase(ctx, "list global network policies failed").SetSubError(coreErr) + } + return gnpsEntity, nil +} + func modelToRule(rule model.GNPSpecRuleInput) entity.GNPSpecRule { return entity.GNPSpecRule{ Metadata: rule.Metadata, Action: rule.Action, Protocol: rule.Protocol, NotProtocol: rule.NotProtocol, - IPVersion: entity.IPVersion(rule.IPVersion), + IPVersion: rule.IPVersion, Source: modelToRuleEntity(rule.Source), Destination: modelToRuleEntity(rule.Destination), } @@ -110,9 +116,20 @@ func modelToRuleEntity(ruleEntity *model.GNPSpecRuleEntityInput) *entity.GNPSpec } return &entity.GNPSpecRuleEntity{ Selector: ruleEntity.Selector, - Nets: ruleEntity.Nets, - NotNets: ruleEntity.NotNets, + Nets: parseNets(ruleEntity.Nets), + NotNets: parseNets(ruleEntity.NotNets), Ports: ruleEntity.Ports, NotPorts: ruleEntity.NotPorts, } } + +func parseNets(nets []string) []string { + var netResults []string + for _, n := range nets { + _, ipnet, err := net.ParseCIDROrIP(n) + if err == nil { + netResults = append(netResults, ipnet.String()) + } + } + return netResults +} diff --git a/domain/service/gns.go b/domain/service/gns.go index 3ea16e7..f6d56fb 100644 --- a/domain/service/gns.go +++ b/domain/service/gns.go @@ -6,7 +6,6 @@ import ( "log/slog" "time" - "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson/primitive" "github.com/bamboo-firewall/be" @@ -30,17 +29,10 @@ type gns struct { } func (ds *gns) Create(ctx context.Context, input *model.CreateGlobalNetworkSetInput) (*entity.GlobalNetworkSet, *ierror.Error) { - // ToDo: use transaction and lock row - gnsExisted, coreErr := ds.storage.GetGNSByName(ctx, input.Metadata.Name) - if coreErr != nil && !errors.Is(coreErr, errlist.ErrNotFoundGlobalNetworkSet) { - return nil, httpbase.ErrDatabase(ctx, "get global network set failed").SetSubError(coreErr) - } - netsV4, netsV6 := exactNets(input.Spec.Nets) gnsEntity := &entity.GlobalNetworkSet{ - ID: primitive.NewObjectID(), - UUID: uuid.New().String(), - Version: 1, + ID: primitive.NewObjectID(), + UUID: entity.NewMinifyUUID(), Metadata: entity.GNSMetadata{ Name: input.Metadata.Name, Labels: input.Metadata.Labels, @@ -54,14 +46,12 @@ func (ds *gns) Create(ctx context.Context, input *model.CreateGlobalNetworkSetIn CreatedAt: time.Now(), UpdatedAt: time.Now(), } - if gnsExisted != nil { - gnsEntity.ID = gnsExisted.ID - gnsEntity.UUID = gnsExisted.UUID - gnsEntity.Version = gnsExisted.Version + 1 - gnsEntity.CreatedAt = gnsExisted.CreatedAt - } - if coreErr = ds.storage.UpsertGNS(ctx, gnsEntity); coreErr != nil { + gnsEntity, coreErr := ds.storage.UpsertGNS(ctx, gnsEntity) + if coreErr != nil { + if errors.Is(coreErr, errlist.ErrDuplicateGlobalNetworkSet) { + return nil, httpbase.ErrBadRequest(ctx, "duplicate global network set").SetSubError(coreErr) + } return nil, httpbase.ErrDatabase(ctx, "create global network set failed").SetSubError(coreErr) } return gnsEntity, nil @@ -80,9 +70,9 @@ func exactNets(nets []string) (netsV4 []string, netsV6 []string) { } else { netV4V6 = ip.Network().String() } - if ip.Version() == int(entity.IPVersion4) { + if ip.Version() == entity.IPVersion4 { netsV4 = append(netsV4, netV4V6) - } else if ip.Version() == int(entity.IPVersion6) { + } else if ip.Version() == entity.IPVersion6 { netsV6 = append(netsV6, netV4V6) } } @@ -100,6 +90,14 @@ func (ds *gns) Get(ctx context.Context, name string) (*entity.GlobalNetworkSet, return gnsEntity, nil } +func (ds *gns) List(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.Error) { + gnssEntity, coreErr := ds.storage.ListGNSs(ctx) + if coreErr != nil { + return nil, httpbase.ErrDatabase(ctx, "list global network sets failed").SetSubError(coreErr) + } + return gnssEntity, nil +} + func (ds *gns) Delete(ctx context.Context, name string) *ierror.Error { if coreErr := ds.storage.DeleteGNSByName(ctx, name); coreErr != nil { return httpbase.ErrDatabase(ctx, "delete global network set failed").SetSubError(coreErr) diff --git a/domain/service/hep.go b/domain/service/hep.go index 9dbbd96..9ed20de 100644 --- a/domain/service/hep.go +++ b/domain/service/hep.go @@ -7,7 +7,6 @@ import ( "log/slog" "time" - "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson/primitive" "github.com/bamboo-firewall/be" @@ -32,23 +31,33 @@ type hep struct { } func (ds *hep) Create(ctx context.Context, input *model.CreateHostEndpointInput) (*entity.HostEndpoint, *ierror.Error) { - // ToDo: use transaction and lock row - hepExisted, coreErr := ds.storage.GetHostEndpointByName(ctx, input.Metadata.Name) - if coreErr != nil && !errors.Is(coreErr, errlist.ErrNotFoundHostEndpoint) { - return nil, httpbase.ErrDatabase(ctx, "get host endpoint failed").SetSubError(coreErr) + ipsV4, ipsV6 := exactIPs(input.Spec.IPs) + if len(ipsV4) == 0 { + return nil, httpbase.ErrBadRequest(ctx, "required at least one ip version 4") + } + if input.Spec.TenantID == 0 { + input.Spec.TenantID = entity.DefaultTenantID + } + var ipString string + if input.Spec.IP == "" { + ipString = ipsV4[0] + } else { + ipString = input.Spec.IP } - ipsV4, ipsV6 := exactIPs(input.Spec.IPs) + ip := net.ParseIP(ipString) + hepEntity := &entity.HostEndpoint{ - ID: primitive.NewObjectID(), - UUID: uuid.New().String(), - Version: 1, + ID: primitive.NewObjectID(), + UUID: entity.NewMinifyUUID(), Metadata: entity.HostEndpointMetadata{ Name: input.Metadata.Name, Labels: input.Metadata.Labels, }, Spec: entity.HostEndpointSpec{ InterfaceName: input.Spec.InterfaceName, + IP: net.IPToInt(*ip), + TenantID: input.Spec.TenantID, IPs: input.Spec.IPs, IPsV4: ipsV4, IPsV6: ipsV6, @@ -57,14 +66,11 @@ func (ds *hep) Create(ctx context.Context, input *model.CreateHostEndpointInput) CreatedAt: time.Now(), UpdatedAt: time.Now(), } - if hepExisted != nil { - hepEntity.ID = hepExisted.ID - hepEntity.UUID = hepExisted.UUID - hepEntity.Version = hepExisted.Version + 1 - hepEntity.CreatedAt = hepExisted.CreatedAt - } - - if coreErr = ds.storage.UpsertHostEndpoint(ctx, hepEntity); coreErr != nil { + hepEntity, coreErr := ds.storage.UpsertHostEndpoint(ctx, hepEntity) + if coreErr != nil { + if errors.Is(coreErr, errlist.ErrDuplicateHostEndpoint) { + return nil, httpbase.ErrBadRequest(ctx, "duplicate host endpoint").SetSubError(coreErr) + } return nil, httpbase.ErrDatabase(ctx, "create host endpoint failed").SetSubError(coreErr) } return hepEntity, nil @@ -86,8 +92,8 @@ func exactIPs(ips []string) (ipsV4, ipsV6 []string) { return } -func (ds *hep) Get(ctx context.Context, name string) (*entity.HostEndpoint, *ierror.Error) { - hepEntity, coreErr := ds.storage.GetHostEndpointByName(ctx, name) +func (ds *hep) Get(ctx context.Context, input *model.GetHostEndpointInput) (*entity.HostEndpoint, *ierror.Error) { + hepEntity, coreErr := ds.storage.GetHostEndpoint(ctx, input) if coreErr != nil { if errors.Is(coreErr, errlist.ErrNotFoundHostEndpoint) { return nil, httpbase.ErrNotFound(ctx, "not found").SetSubError(coreErr) @@ -97,88 +103,117 @@ func (ds *hep) Get(ctx context.Context, name string) (*entity.HostEndpoint, *ier return hepEntity, nil } -func (ds *hep) Delete(ctx context.Context, name string) *ierror.Error { - if coreErr := ds.storage.DeleteHostEndpointByName(ctx, name); coreErr != nil { - return httpbase.ErrDatabase(ctx, "delete host endpoint failed").SetSubError(coreErr) +func (ds *hep) List(ctx context.Context, input *model.ListHostEndpointsInput) ([]*entity.HostEndpoint, *ierror.Error) { + hepsEntity, coreErr := ds.storage.ListHostEndpoints(ctx, input) + if coreErr != nil { + return nil, httpbase.ErrDatabase(ctx, "list host endpoints failed").SetSubError(coreErr) } - return nil + return hepsEntity, nil } -func (ds *hep) FetchPolicies(ctx context.Context, input *model.FetchHostEndpointPolicyInput) (*model.HostEndPointPolicy, *ierror.Error) { - hepEntity, coreErr := ds.storage.GetHostEndpointByName(ctx, input.Name) - if coreErr != nil { - if errors.Is(coreErr, errlist.ErrNotFoundHostEndpoint) { - return nil, httpbase.ErrNotFound(ctx, "not found").SetSubError(coreErr) +func (ds *hep) Delete(ctx context.Context, input *model.DeleteHostEndpointInput) *ierror.Error { + if input.TenantID == 0 { + input.TenantID = entity.DefaultTenantID + } + var ipString string + if input.IP == "" { + ipsV4, _ := exactIPs(input.IPs) + if len(ipsV4) == 0 { + return httpbase.ErrBadRequest(ctx, "required at least one ip version 4") } - return nil, httpbase.ErrDatabase(ctx, "get host endpoint failed").SetSubError(coreErr) + ipString = ipsV4[0] + } else { + ipString = input.IP } - heps, coreErr := ds.storage.ListHostEndpoints(ctx) + ip := net.ParseIP(ipString) + if coreErr := ds.storage.DeleteHostEndpoint(ctx, input.TenantID, net.IPToInt(*ip)); coreErr != nil { + return httpbase.ErrDatabase(ctx, "delete host endpoint failed").SetSubError(coreErr) + } + return nil +} + +func (ds *hep) FetchPolicies(ctx context.Context, input *model.ListHostEndpointsInput) ([]*model.HostEndpointPolicy, *ierror.Error) { + heps, coreErr := ds.storage.ListHostEndpoints(ctx, nil) if coreErr != nil { return nil, httpbase.ErrDatabase(ctx, "list host endpoint failed").SetSubError(coreErr) } - gnps, err := ds.storage.ListGNP(ctx) - if err != nil { + gnps, coreErr := ds.storage.ListGNPs(ctx, &model.ListGNPsInput{IsOrder: true}) + if coreErr != nil { return nil, httpbase.ErrDatabase(ctx, "list global network policy failed").SetSubError(coreErr) } - gnss, err := ds.storage.ListGNS(ctx) - if err != nil { + gnss, coreErr := ds.storage.ListGNSs(ctx) + if coreErr != nil { return nil, httpbase.ErrDatabase(ctx, "list global network set failed").SetSubError(coreErr) } var ( - parsedGNPs []*model.ParsedGNP - gnpVersions = make(map[string]uint) + hepPolicies []*model.HostEndpointPolicy ) - rp := &ruleParser{ - parsedHEPsMap: make(map[string]struct{}), - hepVersions: make(map[string]uint), - parsedGNSsMap: make(map[string]struct{}), - gnsVersions: make(map[string]uint), - } - - for _, policy := range gnps { - sel, errParse := selector.Parse(policy.Spec.Selector) - if errParse != nil { - slog.Warn("malformed selector", "policy_uuid", policy.UUID, "selector", policy.Spec.Selector, "err", errParse) - continue - } - if !sel.Evaluate(hepEntity.Metadata.Labels) { - continue + for _, hepEntity := range heps { + if input != nil { + if input.TenantID != nil && input.IP != nil { + if hepEntity.Spec.TenantID != *input.TenantID || hepEntity.Spec.IP != *input.IP { + continue + } + } } - gnpVersions[policy.UUID] = policy.Version - inboundRules := make([]*model.ParsedRule, 0) - outboundRules := make([]*model.ParsedRule, 0) - for _, rule := range policy.Spec.Ingress { - inboundRules = append(inboundRules, rp.parseRule(policy, &rule, heps, gnss)) + rp := &ruleParser{ + parsedHEPsMap: make(map[string]struct{}), + hepVersions: make(map[string]uint), + parsedGNSsMap: make(map[string]struct{}), + gnsVersions: make(map[string]uint), } - for _, rule := range policy.Spec.Egress { - outboundRules = append(outboundRules, rp.parseRule(policy, &rule, heps, gnss)) + + var ( + parsedGNPs []*model.ParsedGNP + gnpVersions = make(map[string]uint) + ) + for _, policy := range gnps { + sel, errParse := selector.Parse(policy.Spec.Selector) + if errParse != nil { + slog.Warn("malformed selector", "policy_uuid", policy.UUID, "selector", policy.Spec.Selector, "err", errParse) + continue + } + if !sel.Evaluate(hepEntity.Metadata.Labels) { + continue + } + gnpVersions[policy.UUID] = policy.Version + + inboundRules := make([]*model.ParsedRule, 0) + outboundRules := make([]*model.ParsedRule, 0) + for _, rule := range policy.Spec.Ingress { + inboundRules = append(inboundRules, rp.parseRule(policy, &rule, heps, gnss)) + } + for _, rule := range policy.Spec.Egress { + outboundRules = append(outboundRules, rp.parseRule(policy, &rule, heps, gnss)) + } + parsedGNPs = append(parsedGNPs, &model.ParsedGNP{ + UUID: policy.UUID, + Version: policy.Version, + Name: policy.Metadata.Name, + InboundRules: inboundRules, + OutboundRules: outboundRules, + }) } - parsedGNPs = append(parsedGNPs, &model.ParsedGNP{ - UUID: policy.UUID, - Version: policy.Version, - Name: policy.Metadata.Name, - InboundRules: inboundRules, - OutboundRules: outboundRules, + hepPolicies = append(hepPolicies, &model.HostEndpointPolicy{ + MetaData: model.HostEndpointPolicyMetadata{ + GNPVersions: gnpVersions, + HEPVersions: rp.hepVersions, + GNSVersions: rp.gnsVersions, + }, + HEP: hepEntity, + ParsedGNPs: parsedGNPs, + ParsedHEPs: rp.parsedHEPs, + ParsedGNSs: rp.parsedGNSs, }) } - return &model.HostEndPointPolicy{ - MetaData: model.HostEndPointPolicyMetadata{ - GNPVersions: gnpVersions, - HEPVersions: rp.hepVersions, - GNSVersions: rp.gnsVersions, - }, - HEP: hepEntity, - ParsedGNPs: parsedGNPs, - ParsedHEPs: rp.parsedHEPs, - ParsedGNSs: rp.parsedGNSs, - }, nil + return hepPolicies, nil } type ruleParser struct { @@ -190,9 +225,10 @@ type ruleParser struct { gnsVersions map[string]uint } -func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity.GNPSpecRule, heps []*entity.HostEndpoint, gnss []*entity.GlobalNetworkSet) *model.ParsedRule { +func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity.GNPSpecRule, heps []*entity.HostEndpoint, + gnss []*entity.GlobalNetworkSet) *model.ParsedRule { var ( - protocol string + protocol interface{} isProtocolNegative bool srcGNSUUIDs []string srcHEPUUIDs []string @@ -207,62 +243,72 @@ func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity. dstPorts []string isDstPortNegative bool ) - if rule.Protocol != "" { + if rule.Protocol != nil { protocol = rule.Protocol isProtocolNegative = false - } else if rule.NotProtocol != "" { + } else if rule.NotProtocol != nil { protocol = rule.NotProtocol isProtocolNegative = true } - // get global network set match if selector is available + var ruleIPVersion int if rule.Source != nil { - if len(rule.Source.Selector) > 0 { - for { - selSource, errParseSource := selector.Parse(rule.Source.Selector) - if errParseSource != nil { - slog.Warn("malformed selector in source", "policy_uuid", policy.UUID, "selector", rule.Source.Selector, "err", errParseSource) - break - } - for _, ep := range heps { - if !selSource.Evaluate(ep.Metadata.Labels) { - continue - } - if !((rule.IPVersion == entity.IPVersion4 && len(ep.Spec.IPsV4) > 0) || (rule.IPVersion == entity.IPVersion6 && len(ep.Spec.IPsV6) > 0)) { - continue - } - srcHEPUUIDs = append(srcHEPUUIDs, ep.UUID) - if _, ok := r.parsedHEPsMap[ep.UUID]; !ok { - r.parsedHEPsMap[ep.UUID] = struct{}{} - r.hepVersions[ep.UUID] = ep.Version - r.parsedHEPs = append(r.parsedHEPs, entityToParsedHEP(ep)) - } - } - for _, set := range gnss { - if !selSource.Evaluate(set.Metadata.Labels) { - continue - } - if !((rule.IPVersion == entity.IPVersion4 && len(set.Spec.NetsV4) > 0) || (rule.IPVersion == entity.IPVersion6 && len(set.Spec.NetsV6) > 0)) { - continue - } - srcGNSUUIDs = append(srcGNSUUIDs, set.UUID) - if _, ok := r.parsedGNSsMap[set.UUID]; !ok { - r.parsedGNSsMap[set.UUID] = struct{}{} - r.gnsVersions[set.UUID] = set.Version - r.parsedGNSs = append(r.parsedGNSs, entityToParsedGNS(set)) - } - } - break + if len(rule.Source.Nets) > 0 { + ip, _, err := net.ParseCIDR(rule.Source.Nets[0]) + if err == nil { + ruleIPVersion = ip.Version() } - } - if len(rule.Source.Nets) > 0 { srcNets = rule.Source.Nets isSrcNetNegative = false - } else if len(rule.Source.Nets) > 0 { + } else if len(rule.Source.NotNets) > 0 { + ip, _, err := net.ParseCIDR(rule.Source.NotNets[0]) + if err == nil { + ruleIPVersion = ip.Version() + } + srcNets = rule.Source.NotNets isSrcNetNegative = true } + } + if rule.Destination != nil { + if len(rule.Destination.Nets) > 0 { + ip, _, err := net.ParseCIDR(rule.Destination.Nets[0]) + if err == nil { + ruleIPVersion = ip.Version() + } + + dstNets = rule.Destination.Nets + isDstNetNegative = false + } else if len(rule.Destination.NotNets) > 0 { + ip, _, err := net.ParseCIDR(rule.Destination.NotNets[0]) + if err == nil { + ruleIPVersion = ip.Version() + } + + dstNets = rule.Destination.NotNets + isDstNetNegative = true + } + } + if rule.IPVersion == nil && ruleIPVersion > 0 { + rule.IPVersion = &ruleIPVersion + } + + // get host endpoint and global network set match if selector is available + if rule.Source != nil { + if len(rule.Source.Selector) > 0 { + hepUUIDs, gnsUUIDs, err := r.handleSelector(rule.Source.Selector, rule.IPVersion, heps, gnss) + if err != nil { + slog.Warn("malformed selector in source", "policy_uuid", policy.UUID, "selector", rule.Source.Selector, "err", err) + } + if len(hepUUIDs) > 0 { + srcHEPUUIDs = append(srcHEPUUIDs, hepUUIDs...) + } + if len(gnsUUIDs) > 0 { + srcGNSUUIDs = append(srcGNSUUIDs, gnsUUIDs...) + } + } + if len(rule.Source.Ports) > 0 { srcPorts = convertPorts(rule.Source.Ports) isSrcPortNegative = false @@ -274,51 +320,18 @@ func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity. // get global network set match if selector is available if rule.Destination != nil { if len(rule.Destination.Selector) > 0 { - for { - selDst, errParseDst := selector.Parse(rule.Destination.Selector) - if errParseDst != nil { - slog.Warn("malformed selector in destination", "policy_uuid", policy.UUID, "selector", rule.Source.Selector, "err", errParseDst) - break - } - for _, ep := range heps { - if !selDst.Evaluate(ep.Metadata.Labels) { - continue - } - if !((rule.IPVersion == entity.IPVersion4 && len(ep.Spec.IPsV4) > 0) || (rule.IPVersion == entity.IPVersion6 && len(ep.Spec.IPsV6) > 0)) { - continue - } - dstHEPUUIDs = append(dstHEPUUIDs, ep.UUID) - if _, ok := r.parsedHEPsMap[ep.UUID]; !ok { - r.parsedHEPsMap[ep.UUID] = struct{}{} - r.hepVersions[ep.UUID] = ep.Version - r.parsedHEPs = append(r.parsedHEPs, entityToParsedHEP(ep)) - } - } - for _, set := range gnss { - if !selDst.Evaluate(set.Metadata.Labels) { - continue - } - if !((rule.IPVersion == entity.IPVersion4 && len(set.Spec.NetsV4) > 0) || (rule.IPVersion == entity.IPVersion6 && len(set.Spec.NetsV6) > 0)) { - continue - } - dstGNSUUIDs = append(dstGNSUUIDs, set.Metadata.Name) - if _, ok := r.parsedGNSsMap[set.UUID]; !ok { - r.parsedGNSsMap[set.UUID] = struct{}{} - r.gnsVersions[set.UUID] = set.Version - r.parsedGNSs = append(r.parsedGNSs, entityToParsedGNS(set)) - } - } - break + hepUUIDs, gnsUUIDs, err := r.handleSelector(rule.Destination.Selector, rule.IPVersion, heps, gnss) + if err != nil { + slog.Warn("malformed selector in destination", "policy_uuid", policy.UUID, "selector", rule.Source.Selector, "err", err) + } + if len(hepUUIDs) > 0 { + dstHEPUUIDs = append(dstHEPUUIDs, hepUUIDs...) + } + if len(gnsUUIDs) > 0 { + dstGNSUUIDs = append(dstGNSUUIDs, gnsUUIDs...) } } - if len(rule.Destination.Nets) > 0 { - dstNets = rule.Destination.Nets - isDstNetNegative = false - } else if len(rule.Destination.NotNets) > 0 { - dstNets = rule.Destination.NotNets - isDstNetNegative = true - } if len(rule.Destination.Ports) > 0 { dstPorts = convertPorts(rule.Destination.Ports) isDstPortNegative = false @@ -329,7 +342,7 @@ func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity. } return &model.ParsedRule{ Action: rule.Action, - IPVersion: int(rule.IPVersion), + IPVersion: rule.IPVersion, Protocol: protocol, IsProtocolNegative: isProtocolNegative, SrcGNSUUIDs: srcGNSUUIDs, @@ -347,12 +360,72 @@ func (r *ruleParser) parseRule(policy *entity.GlobalNetworkPolicy, rule *entity. } } +func (r *ruleParser) handleSelector(selectorString string, ruleIPVersion *int, heps []*entity.HostEndpoint, + gnss []*entity.GlobalNetworkSet) ([]string, []string, error) { + var ( + hepUUIDs []string + gnsUUIDs []string + ) + sel, errParse := selector.Parse(selectorString) + if errParse != nil { + return nil, nil, fmt.Errorf("parse selector for rule failed: %w", errParse) + } + for _, ep := range heps { + if !sel.Evaluate(ep.Metadata.Labels) { + continue + } + if ruleIPVersion != nil { + if !((*ruleIPVersion == entity.IPVersion4 && len(ep.Spec.IPsV4) > 0) || (*ruleIPVersion == entity.IPVersion6 && len(ep.Spec.IPsV6) > 0)) { + continue + } + } + hepUUIDs = append(hepUUIDs, ep.UUID) + if _, ok := r.parsedHEPsMap[ep.UUID]; !ok { + r.parsedHEPsMap[ep.UUID] = struct{}{} + r.hepVersions[ep.UUID] = ep.Version + r.parsedHEPs = append(r.parsedHEPs, entityToParsedHEP(ep)) + } + } + + for _, set := range gnss { + if !sel.Evaluate(set.Metadata.Labels) { + continue + } + if ruleIPVersion != nil { + if !((*ruleIPVersion == entity.IPVersion4 && len(set.Spec.NetsV4) > 0) || (*ruleIPVersion == entity.IPVersion6 && len(set.Spec.NetsV6) > 0)) { + continue + } + } + gnsUUIDs = append(gnsUUIDs, set.UUID) + if _, ok := r.parsedGNSsMap[set.UUID]; !ok { + r.parsedGNSsMap[set.UUID] = struct{}{} + r.gnsVersions[set.UUID] = set.Version + r.parsedGNSs = append(r.parsedGNSs, entityToParsedGNS(set)) + } + } + + // if selector not match any hep and gns. Using match empty to prevent + if len(gnsUUIDs) == 0 && len(hepUUIDs) == 0 { + set := entity.GNSEmpty + gnsUUIDs = append(gnsUUIDs, set.UUID) + if _, ok := r.parsedGNSsMap[set.UUID]; !ok { + r.parsedGNSsMap[set.UUID] = struct{}{} + r.gnsVersions[set.UUID] = set.Version + r.parsedGNSs = append(r.parsedGNSs, entityToParsedGNS(&set)) + } + } + + return hepUUIDs, gnsUUIDs, nil +} + func entityToParsedHEP(hep *entity.HostEndpoint) *model.ParsedHEP { return &model.ParsedHEP{ - UUID: hep.UUID, - Name: hep.Metadata.Name, - IPsV4: hep.Spec.IPsV4, - IPsV6: hep.Spec.IPsV6, + UUID: hep.UUID, + Name: hep.Metadata.Name, + TenantID: hep.Spec.TenantID, + IP: net.IntToIP(hep.Spec.IP).String(), + IPsV4: hep.Spec.IPsV4, + IPsV6: hep.Spec.IPsV6, } } diff --git a/hep.go b/hep.go deleted file mode 100644 index 707ee9a..0000000 --- a/hep.go +++ /dev/null @@ -1,4 +0,0 @@ -package be - -type HostEndPoint struct { -} diff --git a/pkg/client/gnp.go b/pkg/client/gnp.go index 7e73810..52c8b63 100644 --- a/pkg/client/gnp.go +++ b/pkg/client/gnp.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "github.com/bamboo-firewall/be/api/v1/dto" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase" @@ -31,6 +32,28 @@ func (c *apiServer) CreateGNP(ctx context.Context, input *dto.CreateGlobalNetwor return nil } +func (c *apiServer) ListGNPs(ctx context.Context, input *dto.ListGNPsInput) ([]*dto.GlobalNetworkPolicy, error) { + res := c.client.NewRequest(). + SetSubURL("/api/v1/globalNetworkPolicies"). + SetParam("isOrder", strconv.FormatBool(input.IsOrder)). + SetMethod(http.MethodGet). + DoRequest(ctx) + + if res.Err != nil { + return nil, fmt.Errorf("failed to list gnp by name: %w", res.Err) + } + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code when list gnp, status code: %d, response: %s", res.StatusCode, res.Body) + } + + var gnps []*dto.GlobalNetworkPolicy + if err := json.Unmarshal(res.Body, &gnps); err != nil { + return nil, fmt.Errorf("failed to unmarshal when list gnp, response: %s, err: %w", string(res.Body), err) + } + return gnps, nil +} + func (c *apiServer) GetGNP(ctx context.Context, input *dto.GetGNPInput) (*dto.GlobalNetworkPolicy, error) { res := c.client.NewRequest(). SetSubURL(fmt.Sprintf("/api/v1/globalNetworkPolicies/byName/%s", input.Name)). diff --git a/pkg/client/gns.go b/pkg/client/gns.go index fc5a558..5c93157 100644 --- a/pkg/client/gns.go +++ b/pkg/client/gns.go @@ -31,6 +31,27 @@ func (c *apiServer) CreateGNS(ctx context.Context, input *dto.CreateGlobalNetwor return nil } +func (c *apiServer) ListGNSs(ctx context.Context) ([]*dto.GlobalNetworkSet, error) { + res := c.client.NewRequest(). + SetSubURL("/api/v1/globalNetworkSets"). + SetMethod(http.MethodGet). + DoRequest(ctx) + + if res.Err != nil { + return nil, fmt.Errorf("failed to list gnss by name: %w", res.Err) + } + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code when list gnss, status code: %d, response: %s", res.StatusCode, res.Body) + } + + var gnss []*dto.GlobalNetworkSet + if err := json.Unmarshal(res.Body, &gnss); err != nil { + return nil, fmt.Errorf("failed to unmarshal when list gnss, response: %s, err: %w", string(res.Body), err) + } + return gnss, nil +} + func (c *apiServer) GetGNS(ctx context.Context, input *dto.GetGNSInput) (*dto.GlobalNetworkSet, error) { res := c.client.NewRequest(). SetSubURL(fmt.Sprintf("/api/v1/globalNetworkSets/byName/%s", input.Name)). diff --git a/pkg/client/hep.go b/pkg/client/hep.go index 3b1112c..c175189 100644 --- a/pkg/client/hep.go +++ b/pkg/client/hep.go @@ -31,23 +31,54 @@ func (c *apiServer) CreateHEP(ctx context.Context, input *dto.CreateHostEndpoint return nil } +func (c *apiServer) ListHEPs(ctx context.Context, input *dto.ListHostEndpointsInput) ([]*dto.HostEndpoint, error) { + params := make(map[string]string) + if input != nil { + if input.TenantID != nil { + params["tenantID"] = fmt.Sprint(*input.TenantID) + } + if input.IP != nil { + params["ip"] = *input.IP + } + } + res := c.client.NewRequest(). + SetSubURL("/api/v1/hostEndpoints"). + SetParams(params). + SetMethod(http.MethodGet). + DoRequest(ctx) + + if res.Err != nil { + return nil, fmt.Errorf("failed to list hostendpoint: %w", res.Err) + } + + if res.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code when list hostendpoint, status code: %d, response: %s", res.StatusCode, res.Body) + } + + var heps []*dto.HostEndpoint + if err := json.Unmarshal(res.Body, &heps); err != nil { + return nil, fmt.Errorf("failed to unmarshal when list hostendpoint, response: %s, err: %w", string(res.Body), err) + } + return heps, nil +} + func (c *apiServer) GetHEP(ctx context.Context, input *dto.GetHostEndpointInput) (*dto.HostEndpoint, error) { res := c.client.NewRequest(). - SetSubURL(fmt.Sprintf("/api/v1/hostEndpoints/byName/%s", input.Name)). + SetSubURL(fmt.Sprintf("/api/v1/hostEndpoints/byTenantID/%d/byIP/%s", input.TenantID, input.IP)). SetMethod(http.MethodGet). DoRequest(ctx) if res.Err != nil { - return nil, fmt.Errorf("failed to get hostendpoint by name: %w", res.Err) + return nil, fmt.Errorf("failed to get hostendpoint: %w", res.Err) } if res.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected status code when get hostendpoint by name, status code: %d, response: %s", res.StatusCode, res.Body) + return nil, fmt.Errorf("unexpected status code when get hostendpoint, status code: %d, response: %s", res.StatusCode, res.Body) } var hep *dto.HostEndpoint if err := json.Unmarshal(res.Body, &hep); err != nil { - return nil, fmt.Errorf("failed to unmarshal when get hostendpoint by name, response: %s, err: %w", string(res.Body), err) + return nil, fmt.Errorf("failed to unmarshal when get hostendpoint, response: %s, err: %w", string(res.Body), err) } return hep, nil } diff --git a/storage.go b/storage.go index 8fadcd0..200e4ff 100644 --- a/storage.go +++ b/storage.go @@ -5,19 +5,20 @@ import ( "github.com/bamboo-firewall/be/cmd/server/pkg/entity" "github.com/bamboo-firewall/be/cmd/server/pkg/httpbase/ierror" + "github.com/bamboo-firewall/be/domain/model" ) type Storage interface { - UpsertHostEndpoint(ctx context.Context, hep *entity.HostEndpoint) *ierror.CoreError - GetHostEndpointByName(ctx context.Context, name string) (*entity.HostEndpoint, *ierror.CoreError) - DeleteHostEndpointByName(ctx context.Context, name string) *ierror.CoreError - ListHostEndpoints(ctx context.Context) ([]*entity.HostEndpoint, *ierror.CoreError) - UpsertGroupPolicy(ctx context.Context, gnp *entity.GlobalNetworkPolicy) *ierror.CoreError + UpsertHostEndpoint(ctx context.Context, hep *entity.HostEndpoint) (*entity.HostEndpoint, *ierror.CoreError) + GetHostEndpoint(ctx context.Context, input *model.GetHostEndpointInput) (*entity.HostEndpoint, *ierror.CoreError) + DeleteHostEndpoint(ctx context.Context, tenantID uint64, ip uint32) *ierror.CoreError + ListHostEndpoints(ctx context.Context, input *model.ListHostEndpointsInput) ([]*entity.HostEndpoint, *ierror.CoreError) + UpsertGroupPolicy(ctx context.Context, gnp *entity.GlobalNetworkPolicy) (*entity.GlobalNetworkPolicy, *ierror.CoreError) GetGNPByName(ctx context.Context, name string) (*entity.GlobalNetworkPolicy, *ierror.CoreError) DeleteGNPByName(ctx context.Context, name string) *ierror.CoreError - ListGNP(ctx context.Context) ([]*entity.GlobalNetworkPolicy, *ierror.CoreError) - UpsertGNS(ctx context.Context, gns *entity.GlobalNetworkSet) *ierror.CoreError + ListGNPs(ctx context.Context, input *model.ListGNPsInput) ([]*entity.GlobalNetworkPolicy, *ierror.CoreError) + UpsertGNS(ctx context.Context, gns *entity.GlobalNetworkSet) (*entity.GlobalNetworkSet, *ierror.CoreError) GetGNSByName(ctx context.Context, name string) (*entity.GlobalNetworkSet, *ierror.CoreError) DeleteGNSByName(ctx context.Context, name string) *ierror.CoreError - ListGNS(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.CoreError) + ListGNSs(ctx context.Context) ([]*entity.GlobalNetworkSet, *ierror.CoreError) }