From cf9d9bd24855219ac24cf6222d73f934c47d40ad Mon Sep 17 00:00:00 2001 From: Darc Z <53445798+DarcJC@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:49:28 +0800 Subject: [PATCH] Fix DDNS bugs and split up ddns module (#326) * fix: webhook retry logic fix: adjust record type based on ipv4 ipv6 refract: move ddns providers to a new module * refract: move ddns module to pkg/ --- pkg/ddns/cloudflare.go | 177 ++++++++++++++++++++++++ pkg/ddns/ddns.go | 14 ++ pkg/ddns/dummy.go | 7 + pkg/ddns/helper.go | 61 +++++++++ pkg/ddns/webhook.go | 59 ++++++++ service/rpc/nezha.go | 3 +- service/singleton/ddns.go | 274 +------------------------------------- 7 files changed, 327 insertions(+), 268 deletions(-) create mode 100644 pkg/ddns/cloudflare.go create mode 100644 pkg/ddns/ddns.go create mode 100644 pkg/ddns/dummy.go create mode 100644 pkg/ddns/helper.go create mode 100644 pkg/ddns/webhook.go diff --git a/pkg/ddns/cloudflare.go b/pkg/ddns/cloudflare.go new file mode 100644 index 0000000000..8d14c109c7 --- /dev/null +++ b/pkg/ddns/cloudflare.go @@ -0,0 +1,177 @@ +package ddns + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" +) + +type ProviderCloudflare struct { + Secret string +} + +func (provider ProviderCloudflare) UpdateDomain(domainConfig *DomainConfig) bool { + if domainConfig == nil { + return false + } + + zoneID, err := provider.getZoneID(domainConfig.FullDomain) + if err != nil { + log.Printf("无法获取 zone ID: %s\n", err) + return false + } + + // 当IPv4和IPv6同时成功才算作成功 + var resultV4 = true + var resultV6 = true + if domainConfig.EnableIPv4 { + if !provider.addDomainRecord(zoneID, domainConfig, true) { + resultV4 = false + } + } + + if domainConfig.EnableIpv6 { + if !provider.addDomainRecord(zoneID, domainConfig, false) { + resultV6 = false + } + } + + return resultV4 && resultV6 +} + +func (provider ProviderCloudflare) addDomainRecord(zoneID string, domainConfig *DomainConfig, isIpv4 bool) bool { + record, err := provider.findDNSRecord(zoneID, domainConfig.FullDomain, isIpv4) + if err != nil { + log.Printf("查找 DNS 记录时出错: %s\n", err) + return false + } + + if record == nil { + // 添加 DNS 记录 + return provider.createDNSRecord(zoneID, domainConfig, isIpv4) + } else { + // 更新 DNS 记录 + return provider.updateDNSRecord(zoneID, record["id"].(string), domainConfig, isIpv4) + } +} + +func (provider ProviderCloudflare) getZoneID(domain string) (string, error) { + _, realDomain := SplitDomain(domain) + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones?name=%s", realDomain) + body, err := provider.sendRequest("GET", url, nil) + if err != nil { + return "", err + } + + var res map[string]interface{} + err = json.Unmarshal(body, &res) + if err != nil { + return "", err + } + + result := res["result"].([]interface{}) + if len(result) > 0 { + zoneID := result[0].(map[string]interface{})["id"].(string) + return zoneID, nil + } + + return "", fmt.Errorf("找不到 Zone ID") +} + +func (provider ProviderCloudflare) findDNSRecord(zoneID string, domain string, isIPv4 bool) (map[string]interface{}, error) { + var ipType = "A" + if !isIPv4 { + ipType = "AAAA" + } + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records?type=%s&name=%s", zoneID, ipType, domain) + body, err := provider.sendRequest("GET", url, nil) + if err != nil { + return nil, err + } + + var res map[string]interface{} + err = json.Unmarshal(body, &res) + if err != nil { + return nil, err + } + + result := res["result"].([]interface{}) + if len(result) > 0 { + return result[0].(map[string]interface{}), nil + } + + return nil, nil // 没有找到 DNS 记录 +} + +func (provider ProviderCloudflare) createDNSRecord(zoneID string, domainConfig *DomainConfig, isIPv4 bool) bool { + var ipType = "A" + var ipAddr = domainConfig.Ipv4Addr + if !isIPv4 { + ipType = "AAAA" + ipAddr = domainConfig.Ipv6Addr + } + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneID) + data := map[string]interface{}{ + "type": ipType, + "name": domainConfig.FullDomain, + "content": ipAddr, + "ttl": 60, + "proxied": false, + } + jsonData, _ := json.Marshal(data) + _, err := provider.sendRequest("POST", url, jsonData) + return err == nil +} + +func (provider ProviderCloudflare) updateDNSRecord(zoneID string, recordID string, domainConfig *DomainConfig, isIPv4 bool) bool { + var ipType = "A" + var ipAddr = domainConfig.Ipv4Addr + if !isIPv4 { + ipType = "AAAA" + ipAddr = domainConfig.Ipv6Addr + } + url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneID, recordID) + data := map[string]interface{}{ + "type": ipType, + "name": domainConfig.FullDomain, + "content": ipAddr, + "ttl": 60, + "proxied": false, + } + jsonData, _ := json.Marshal(data) + _, err := provider.sendRequest("PATCH", url, jsonData) + return err == nil +} + +// 以下为辅助方法,如发送 HTTP 请求等 +func (provider ProviderCloudflare) sendRequest(method string, url string, data []byte) ([]byte, error) { + client := &http.Client{} + req, err := http.NewRequest(method, url, bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", provider.Secret)) + req.Header.Add("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + log.Printf("NEZHA>> 无法关闭HTTP响应体流: %s\n", err.Error()) + } + }(resp.Body) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return body, nil +} diff --git a/pkg/ddns/ddns.go b/pkg/ddns/ddns.go new file mode 100644 index 0000000000..3e80d6ee61 --- /dev/null +++ b/pkg/ddns/ddns.go @@ -0,0 +1,14 @@ +package ddns + +type DomainConfig struct { + EnableIPv4 bool + EnableIpv6 bool + FullDomain string + Ipv4Addr string + Ipv6Addr string +} + +type Provider interface { + // UpdateDomain Return is updated + UpdateDomain(domainConfig *DomainConfig) bool +} diff --git a/pkg/ddns/dummy.go b/pkg/ddns/dummy.go new file mode 100644 index 0000000000..5843e720c8 --- /dev/null +++ b/pkg/ddns/dummy.go @@ -0,0 +1,7 @@ +package ddns + +type ProviderDummy struct{} + +func (provider ProviderDummy) UpdateDomain(domainConfig *DomainConfig) bool { + return false +} diff --git a/pkg/ddns/helper.go b/pkg/ddns/helper.go new file mode 100644 index 0000000000..e16314f624 --- /dev/null +++ b/pkg/ddns/helper.go @@ -0,0 +1,61 @@ +package ddns + +import ( + "fmt" + "net/http" + "strings" +) + +func (provider ProviderWebHook) FormatWebhookString(s string, config *DomainConfig, ipType string) string { + if config == nil { + return s + } + + result := strings.TrimSpace(s) + result = strings.Replace(s, "{ip}", config.Ipv4Addr, -1) + result = strings.Replace(result, "{domain}", config.FullDomain, -1) + result = strings.Replace(result, "{type}", ipType, -1) + // remove \r + result = strings.Replace(result, "\r", "", -1) + return result +} + +func SetStringHeadersToRequest(req *http.Request, headers []string) { + if req == nil { + return + } + for _, element := range headers { + kv := strings.SplitN(element, ":", 2) + if len(kv) == 2 { + req.Header.Add(kv[0], kv[1]) + } + } +} + +// SplitDomain 分割域名为前缀和一级域名 +func SplitDomain(domain string) (prefix string, topLevelDomain string) { + // 带有二级TLD的一些常见例子,需要特别处理 + secondLevelTLDs := map[string]bool{ + ".co.uk": true, ".com.cn": true, ".gov.cn": true, ".net.cn": true, ".org.cn": true, + } + + // 分割域名为"."的各部分 + parts := strings.Split(domain, ".") + + // 处理特殊情况,例如 ".co.uk" + for i := len(parts) - 2; i > 0; i-- { + potentialTLD := fmt.Sprintf(".%s.%s", parts[i], parts[i+1]) + if secondLevelTLDs[potentialTLD] { + if i > 1 { + return strings.Join(parts[:i-1], "."), strings.Join(parts[i-1:], ".") + } + return "", domain // 当域名仅为二级TLD时,无前缀 + } + } + + // 常规处理,查找最后一个"."前的所有内容作为前缀 + if len(parts) > 2 { + return strings.Join(parts[:len(parts)-2], "."), strings.Join(parts[len(parts)-2:], ".") + } + return "", domain // 当域名不包含子域名时,无前缀 +} diff --git a/pkg/ddns/webhook.go b/pkg/ddns/webhook.go new file mode 100644 index 0000000000..e290c78c34 --- /dev/null +++ b/pkg/ddns/webhook.go @@ -0,0 +1,59 @@ +package ddns + +import ( + "bytes" + "log" + "net/http" + "strings" +) + +type ProviderWebHook struct { + URL string + RequestMethod string + RequestBody string + RequestHeader string +} + +func (provider ProviderWebHook) UpdateDomain(domainConfig *DomainConfig) bool { + if domainConfig == nil { + return false + } + + if domainConfig.FullDomain == "" { + log.Println("NEZHA>> Failed to update an empty domain") + return false + } + updated := false + client := &http.Client{} + if domainConfig.EnableIPv4 && domainConfig.Ipv4Addr != "" { + url := provider.FormatWebhookString(provider.URL, domainConfig, "ipv4") + body := provider.FormatWebhookString(provider.RequestBody, domainConfig, "ipv4") + header := provider.FormatWebhookString(provider.RequestHeader, domainConfig, "ipv4") + headers := strings.Split(header, "\n") + req, err := http.NewRequest(provider.RequestMethod, url, bytes.NewBufferString(body)) + if err == nil && req != nil { + SetStringHeadersToRequest(req, headers) + if _, err := client.Do(req); err != nil { + log.Printf("NEZHA>> Failed to update a domain: %s. Cause by: %s\n", domainConfig.FullDomain, err.Error()) + } else { + updated = true + } + } + } + if domainConfig.EnableIpv6 && domainConfig.Ipv6Addr != "" { + url := provider.FormatWebhookString(provider.URL, domainConfig, "ipv6") + body := provider.FormatWebhookString(provider.RequestBody, domainConfig, "ipv6") + header := provider.FormatWebhookString(provider.RequestHeader, domainConfig, "ipv6") + headers := strings.Split(header, "\n") + req, err := http.NewRequest(provider.RequestMethod, url, bytes.NewBufferString(body)) + if err == nil && req != nil { + SetStringHeadersToRequest(req, headers) + if _, err := client.Do(req); err != nil { + log.Printf("NEZHA>> Failed to update a domain: %s. Cause by: %s\n", domainConfig.FullDomain, err.Error()) + } else { + updated = true + } + } + } + return updated +} diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index efb39a9c15..f183043214 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -3,6 +3,7 @@ package rpc import ( "context" "fmt" + "github.com/naiba/nezha/pkg/ddns" "github.com/naiba/nezha/pkg/utils" "log" "time" @@ -125,7 +126,7 @@ func (s *NezhaHandler) ReportSystemInfo(c context.Context, r *pb.Host) (*pb.Rece if err == nil && serverDomain != "" { ipv4, ipv6, _ := utils.SplitIPAddr(host.IP) maxRetries := int(singleton.Conf.DDNS.MaxRetries) - config := &singleton.DDNSDomainConfig{ + config := &ddns.DomainConfig{ EnableIPv4: true, EnableIpv6: true, FullDomain: serverDomain, diff --git a/service/singleton/ddns.go b/service/singleton/ddns.go index b40e187950..810138a795 100644 --- a/service/singleton/ddns.go +++ b/service/singleton/ddns.go @@ -1,273 +1,13 @@ package singleton import ( - "bytes" - "encoding/json" "errors" "fmt" - "io" + ddns2 "github.com/naiba/nezha/pkg/ddns" "log" - "net/http" - "strings" ) -type DDNSDomainConfig struct { - EnableIPv4 bool - EnableIpv6 bool - FullDomain string - Ipv4Addr string - Ipv6Addr string -} - -type DDNSProvider interface { - // UpdateDomain Return is updated - UpdateDomain(domainConfig *DDNSDomainConfig) bool -} - -type DDNSProviderWebHook struct { - URL string - RequestMethod string - RequestBody string - RequestHeader string -} - -func (provider DDNSProviderWebHook) UpdateDomain(domainConfig *DDNSDomainConfig) bool { - if domainConfig == nil { - return false - } - - if domainConfig.FullDomain == "" { - log.Println("NEZHA>> Failed to update an empty domain") - return false - } - updated := false - client := &http.Client{} - if domainConfig.EnableIPv4 && domainConfig.Ipv4Addr != "" { - url := provider.FormatWebhookString(provider.URL, domainConfig, "ipv4") - body := provider.FormatWebhookString(provider.RequestBody, domainConfig, "ipv4") - header := provider.FormatWebhookString(provider.RequestHeader, domainConfig, "ipv4") - headers := strings.Split(header, "\n") - req, err := http.NewRequest(provider.RequestMethod, url, bytes.NewBufferString(body)) - if err == nil && req != nil { - SetStringHeadersToRequest(req, headers) - if _, err := client.Do(req); err != nil { - log.Printf("NEZHA>> Failed to update a domain: %s. Cause by: %s\n", domainConfig.FullDomain, err.Error()) - } - updated = true - } - } - if domainConfig.EnableIpv6 && domainConfig.Ipv6Addr != "" { - url := provider.FormatWebhookString(provider.URL, domainConfig, "ipv6") - body := provider.FormatWebhookString(provider.RequestBody, domainConfig, "ipv6") - header := provider.FormatWebhookString(provider.RequestHeader, domainConfig, "ipv6") - headers := strings.Split(header, "\n") - req, err := http.NewRequest(provider.RequestMethod, url, bytes.NewBufferString(body)) - if err == nil && req != nil { - SetStringHeadersToRequest(req, headers) - if _, err := client.Do(req); err != nil { - log.Printf("NEZHA>> Failed to update a domain: %s. Cause by: %s\n", domainConfig.FullDomain, err.Error()) - } - updated = true - } - } - return updated -} - -type DDNSProviderDummy struct{} - -func (provider DDNSProviderDummy) UpdateDomain(domainConfig *DDNSDomainConfig) bool { - return false -} - -type DDNSProviderCloudflare struct { - Secret string -} - -func (provider DDNSProviderCloudflare) UpdateDomain(domainConfig *DDNSDomainConfig) bool { - if domainConfig == nil { - return false - } - - zoneID, err := provider.getZoneID(domainConfig.FullDomain) - if err != nil { - log.Printf("无法获取 zone ID: %s\n", err) - return false - } - - record, err := provider.findDNSRecord(zoneID, domainConfig.FullDomain) - if err != nil { - log.Printf("查找 DNS 记录时出错: %s\n", err) - return false - } - - if record == nil { - // 添加 DNS 记录 - return provider.createDNSRecord(zoneID, domainConfig) - } else { - // 更新 DNS 记录 - return provider.updateDNSRecord(zoneID, record["id"].(string), domainConfig) - } -} - -func (provider DDNSProviderCloudflare) getZoneID(domain string) (string, error) { - _, realDomain := SplitDomain(domain) - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones?name=%s", realDomain) - body, err := provider.sendRequest("GET", url, nil) - if err != nil { - return "", err - } - - var res map[string]interface{} - err = json.Unmarshal(body, &res) - if err != nil { - return "", err - } - - result := res["result"].([]interface{}) - if len(result) > 0 { - zoneID := result[0].(map[string]interface{})["id"].(string) - return zoneID, nil - } - - return "", fmt.Errorf("找不到 Zone ID") -} - -func (provider DDNSProviderCloudflare) findDNSRecord(zoneID string, domain string) (map[string]interface{}, error) { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records?type=A&name=%s", zoneID, domain) - body, err := provider.sendRequest("GET", url, nil) - if err != nil { - return nil, err - } - - var res map[string]interface{} - err = json.Unmarshal(body, &res) - if err != nil { - return nil, err - } - - result := res["result"].([]interface{}) - if len(result) > 0 { - return result[0].(map[string]interface{}), nil - } - - return nil, nil // 没有找到 DNS 记录 -} - -func (provider DDNSProviderCloudflare) createDNSRecord(zoneID string, domainConfig *DDNSDomainConfig) bool { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records", zoneID) - data := map[string]interface{}{ - "type": "A", - "name": domainConfig.FullDomain, - "content": domainConfig.Ipv4Addr, - "ttl": 3600, - "proxied": false, - } - jsonData, _ := json.Marshal(data) - _, err := provider.sendRequest("POST", url, jsonData) - return err == nil -} - -func (provider DDNSProviderCloudflare) updateDNSRecord(zoneID string, recordID string, domainConfig *DDNSDomainConfig) bool { - url := fmt.Sprintf("https://api.cloudflare.com/client/v4/zones/%s/dns_records/%s", zoneID, recordID) - data := map[string]interface{}{ - "type": "A", - "name": domainConfig.FullDomain, - "content": domainConfig.Ipv4Addr, - "ttl": 3600, - "proxied": false, - } - jsonData, _ := json.Marshal(data) - _, err := provider.sendRequest("PATCH", url, jsonData) - return err == nil -} - -// 以下为辅助方法,如发送 HTTP 请求等 -func (provider DDNSProviderCloudflare) sendRequest(method string, url string, data []byte) ([]byte, error) { - client := &http.Client{} - req, err := http.NewRequest(method, url, bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", provider.Secret)) - req.Header.Add("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - return nil, err - } - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - log.Printf("NEZHA>> 无法关闭HTTP响应体流: %s\n", err.Error()) - } - }(resp.Body) - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return body, nil -} - -func (provider DDNSProviderWebHook) FormatWebhookString(s string, config *DDNSDomainConfig, ipType string) string { - if config == nil { - return s - } - - result := strings.TrimSpace(s) - result = strings.Replace(s, "{ip}", config.Ipv4Addr, -1) - result = strings.Replace(result, "{domain}", config.FullDomain, -1) - result = strings.Replace(result, "{type}", ipType, -1) - result = strings.Replace(result, "{access_id}", Conf.DDNS.AccessID, -1) - result = strings.Replace(result, "{access_secret}", Conf.DDNS.AccessSecret, -1) - // remove \r - result = strings.Replace(result, "\r", "", -1) - return result -} - -func SetStringHeadersToRequest(req *http.Request, headers []string) { - if req == nil { - return - } - for _, element := range headers { - kv := strings.SplitN(element, ":", 2) - if len(kv) == 2 { - req.Header.Add(kv[0], kv[1]) - } - } -} - -// SplitDomain 分割域名为前缀和一级域名 -func SplitDomain(domain string) (prefix string, topLevelDomain string) { - // 带有二级TLD的一些常见例子,需要特别处理 - secondLevelTLDs := map[string]bool{ - ".co.uk": true, ".com.cn": true, ".gov.cn": true, ".net.cn": true, ".org.cn": true, - } - - // 分割域名为"."的各部分 - parts := strings.Split(domain, ".") - - // 处理特殊情况,例如 ".co.uk" - for i := len(parts) - 2; i > 0; i-- { - potentialTLD := fmt.Sprintf(".%s.%s", parts[i], parts[i+1]) - if secondLevelTLDs[potentialTLD] { - if i > 1 { - return strings.Join(parts[:i-1], "."), strings.Join(parts[i-1:], ".") - } - return "", domain // 当域名仅为二级TLD时,无前缀 - } - } - - // 常规处理,查找最后一个"."前的所有内容作为前缀 - if len(parts) > 2 { - return strings.Join(parts[:len(parts)-2], "."), strings.Join(parts[len(parts)-2:], ".") - } - return "", domain // 当域名不包含子域名时,无前缀 -} - -func RetryableUpdateDomain(provider DDNSProvider, config *DDNSDomainConfig, maxRetries int) bool { +func RetryableUpdateDomain(provider ddns2.Provider, config *ddns2.DomainConfig, maxRetries int) bool { if nil == config { return false } @@ -282,21 +22,21 @@ func RetryableUpdateDomain(provider DDNSProvider, config *DDNSDomainConfig, maxR return false } -func GetDDNSProviderFromString(provider string) (DDNSProvider, error) { +func GetDDNSProviderFromString(provider string) (ddns2.Provider, error) { switch provider { case "webhook": - return DDNSProviderWebHook{ + return ddns2.ProviderWebHook{ URL: Conf.DDNS.WebhookURL, RequestMethod: Conf.DDNS.WebhookMethod, RequestBody: Conf.DDNS.WebhookRequestBody, RequestHeader: Conf.DDNS.WebhookHeaders, }, nil case "dummy": - return DDNSProviderDummy{}, nil + return ddns2.ProviderDummy{}, nil case "cloudflare": - return DDNSProviderCloudflare{ + return ddns2.ProviderCloudflare{ Secret: Conf.DDNS.AccessSecret, }, nil } - return DDNSProviderDummy{}, errors.New(fmt.Sprintf("无法找到配置的DDNS提供者%s", Conf.DDNS.Provider)) + return ddns2.ProviderDummy{}, errors.New(fmt.Sprintf("无法找到配置的DDNS提供者%s", Conf.DDNS.Provider)) }