diff --git a/go.mod b/go.mod index 0e514a7..fbf052c 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/free5gc/udr go 1.17 require ( + github.com/antihax/optional v1.0.0 github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/evanphx/json-patch v0.5.2 github.com/free5gc/openapi v1.0.7-0.20230802173229-2b3ded4db293 @@ -14,11 +15,11 @@ require ( github.com/stretchr/testify v1.8.3 github.com/urfave/cli v1.22.5 go.mongodb.org/mongo-driver v1.8.4 + golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 gopkg.in/yaml.v2 v2.4.0 ) require ( - github.com/antihax/optional v1.0.0 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect @@ -56,8 +57,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.14.0 // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index 29d5750..4d7eb7e 100644 --- a/go.sum +++ b/go.sum @@ -339,7 +339,6 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/internal/context/context.go b/internal/context/context.go index 5436cff..22e9f64 100644 --- a/internal/context/context.go +++ b/internal/context/context.go @@ -56,6 +56,8 @@ type UDRContext struct { InfluenceDataSubscriptions sync.Map appDataInfluDataSubscriptionIdGenerator uint64 mtx sync.RWMutex + ClientMap sync.Map + TokenMap sync.Map } type UESubsData struct { diff --git a/internal/sbi/consumer/nf_accesstoken.go b/internal/sbi/consumer/nf_accesstoken.go new file mode 100644 index 0000000..e3f153a --- /dev/null +++ b/internal/sbi/consumer/nf_accesstoken.go @@ -0,0 +1,87 @@ +package consumer + +import ( + "context" + "time" + + udr_context "github.com/free5gc/udr/internal/context" + "github.com/free5gc/udr/internal/logger" + "github.com/free5gc/udr/pkg/factory" + + "github.com/free5gc/openapi" + "github.com/free5gc/openapi/Nnrf_AccessToken" + "github.com/free5gc/openapi/models" + + "github.com/antihax/optional" + "golang.org/x/oauth2" +) + +func GetTokenCtx(scope, targetNF string) (context.Context, *models.ProblemDetails, error) { + if factory.UdrConfig.GetOAuth() { + tok, pd, err := sendAccTokenReq(scope, targetNF) + if err != nil { + return nil, pd, err + } + return context.WithValue(context.Background(), + openapi.ContextOAuth2, tok), pd, nil + } + return context.TODO(), nil, nil +} + +func sendAccTokenReq(scope, targetNF string) (oauth2.TokenSource, *models.ProblemDetails, error) { + logger.ConsumerLog.Infof("Send Access Token Request") + var client *Nnrf_AccessToken.APIClient + udrSelf := udr_context.GetSelf() + // Set client and set url + configuration := Nnrf_AccessToken.NewConfiguration() + configuration.SetBasePath(udrSelf.NrfUri) + if val, ok := udrSelf.ClientMap.Load(configuration); ok { + client = val.(*Nnrf_AccessToken.APIClient) + } else { + client = Nnrf_AccessToken.NewAPIClient(configuration) + udrSelf.ClientMap.Store(configuration, client) + } + + var tok models.AccessTokenRsp + + if val, ok := udrSelf.TokenMap.Load(scope); ok { + tok = val.(models.AccessTokenRsp) + if int32(time.Now().Unix()) < tok.ExpiresIn { + logger.ConsumerLog.Infof("Token is not expired") + token := &oauth2.Token{ + AccessToken: tok.AccessToken, + TokenType: tok.TokenType, + Expiry: time.Unix(int64(tok.ExpiresIn), 0), + } + return oauth2.StaticTokenSource(token), nil, nil + } + } + + tok, res, err := client.AccessTokenRequestApi.AccessTokenRequest(context.Background(), "client_credentials", + udrSelf.NfId, scope, &Nnrf_AccessToken.AccessTokenRequestParamOpts{ + NfType: optional.NewInterface(models.NfType_UDR), + TargetNfType: optional.NewInterface(targetNF), + }) + if err == nil { + udrSelf.TokenMap.Store(scope, tok) + token := &oauth2.Token{ + AccessToken: tok.AccessToken, + TokenType: tok.TokenType, + Expiry: time.Unix(int64(tok.ExpiresIn), 0), + } + return oauth2.StaticTokenSource(token), nil, err + } else if res != nil { + defer func() { + if resCloseErr := res.Body.Close(); resCloseErr != nil { + logger.ConsumerLog.Errorf("AccessTokenRequestApi response body cannot close: %+v", resCloseErr) + } + }() + if res.Status != err.Error() { + return nil, nil, err + } + problem := err.(openapi.GenericOpenAPIError).Model().(models.ProblemDetails) + return nil, &problem, err + } else { + return nil, nil, openapi.ReportError("server no response") + } +} diff --git a/internal/sbi/consumer/nf_discovery.go b/internal/sbi/consumer/nf_discovery.go index 19455fd..62bc39b 100644 --- a/internal/sbi/consumer/nf_discovery.go +++ b/internal/sbi/consumer/nf_discovery.go @@ -1,7 +1,6 @@ package consumer import ( - "context" "fmt" "net/http" @@ -12,14 +11,19 @@ import ( func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfType, param Nnrf_NFDiscovery.SearchNFInstancesParamOpts, -) (models.SearchResult, error) { +) (*models.SearchResult, error) { // Set client and set url configuration := Nnrf_NFDiscovery.NewConfiguration() configuration.SetBasePath(nrfUri) client := Nnrf_NFDiscovery.NewAPIClient(configuration) + ctx, _, err := GetTokenCtx("nnrf-disc", "NRF") + if err != nil { + return nil, err + } + var res *http.Response - result, res, err := client.NFInstancesStoreApi.SearchNFInstances(context.TODO(), targetNfType, requestNfType, ¶m) + result, res, err := client.NFInstancesStoreApi.SearchNFInstances(ctx, targetNfType, requestNfType, ¶m) if res != nil && res.StatusCode == http.StatusTemporaryRedirect { err = fmt.Errorf("Temporary Redirect For Non NRF Consumer") } @@ -29,5 +33,5 @@ func SendSearchNFInstances(nrfUri string, targetNfType, requestNfType models.NfT } }() - return result, err + return &result, err } diff --git a/internal/sbi/consumer/nf_managemant.go b/internal/sbi/consumer/nf_managemant.go index 953695e..ae23b4e 100644 --- a/internal/sbi/consumer/nf_managemant.go +++ b/internal/sbi/consumer/nf_managemant.go @@ -106,6 +106,11 @@ func SendRegisterNFInstance(nrfUri, nfInstanceId string, profile models.NfProfil func SendDeregisterNFInstance() (problemDetails *models.ProblemDetails, err error) { logger.ConsumerLog.Infof("Send Deregister NFInstance") + ctx, pd, err := GetTokenCtx("nnrf-nfm", "NRF") + if err != nil { + return pd, err + } + udrSelf := udr_context.GetSelf() // Set client and set url configuration := Nnrf_NFManagement.NewConfiguration() @@ -114,7 +119,7 @@ func SendDeregisterNFInstance() (problemDetails *models.ProblemDetails, err erro var res *http.Response - res, err = client.NFInstanceIDDocumentApi.DeregisterNFInstance(context.Background(), udrSelf.NfId) + res, err = client.NFInstanceIDDocumentApi.DeregisterNFInstance(ctx, udrSelf.NfId) if err == nil { return } else if res != nil { diff --git a/pkg/factory/config.go b/pkg/factory/config.go index 64e7138..d466671 100644 --- a/pkg/factory/config.go +++ b/pkg/factory/config.go @@ -81,6 +81,13 @@ type Sbi struct { BindingIPv4 string `yaml:"bindingIPv4,omitempty" valid:"host,optional"` // IP used to run the server in the node. Port int `yaml:"port" valid:"port,required"` Tls *Tls `yaml:"tls,omitempty" valid:"optional"` + OAuth bool `yaml:"oauth,omitempty" valid:"optional"` +} + +func (c *Config) GetOAuth() bool { + c.RLock() + defer c.RUnlock() + return c.Configuration.Sbi.OAuth } type Tls struct {