diff --git a/plugins/pxeboot/plugin.go b/plugins/pxeboot/plugin.go index 28f669e..296ae30 100644 --- a/plugins/pxeboot/plugin.go +++ b/plugins/pxeboot/plugin.go @@ -49,14 +49,24 @@ func parseArgs(args ...string) (*url.URL, *url.URL, error) { if len(args) != 2 { return nil, nil, fmt.Errorf("exactly two arguments must be passed to PXEBOOT plugin, got %d", len(args)) } + tftp, err := url.Parse(args[0]) if err != nil { return nil, nil, err } + ipxe, err := url.Parse(args[1]) if err != nil { return nil, nil, err } + + if tftp.Scheme != "tftp" || tftp.Host == "" || tftp.Path == "" || tftp.Path[0] != '/' || tftp.Path[1:] == "" { + return nil, nil, fmt.Errorf("Malformed FTTP parameter, should be a valid URL") + } + + if (ipxe.Scheme != "http" && ipxe.Scheme != "https") || ipxe.Host == "" || ipxe.Path == "" { + return nil, nil, fmt.Errorf("Malformed iPXE parameter, should be a valid URL") + } return tftp, ipxe, nil } @@ -127,6 +137,7 @@ func setup6(args ...string) (handler.Handler6, error) { if err != nil { return nil, err } + tftpOption = dhcpv6.OptBootFileURL(tftp.String()) ipxeOption = dhcpv6.OptBootFileURL(ipxe.String()) diff --git a/plugins/pxeboot/plugin_test.go b/plugins/pxeboot/plugin_test.go index 2cdbcb2..5cc9782 100644 --- a/plugins/pxeboot/plugin_test.go +++ b/plugins/pxeboot/plugin_test.go @@ -39,6 +39,79 @@ func Init6(numOptBoot int) { } } +/* parametrization */ + +func TestWrongNumberArgs(t *testing.T) { + _, _, err := parseArgs(tftpPath, ipxePath, "not-needed-arg") + if err == nil { + t.Fatal("no error occurred when providing wrong number of args (3), but it should have") + } + + _, _, err = parseArgs("only-one-arg") + if err == nil { + t.Fatal("no error occurred when providing wrong number of args (1), but it should have") + } +} + +func TestWrongArgs(t *testing.T) { + malformedTFTPPath := []string{"tftp://1.2.3.4/", "foo://1.2.3.4/boot.efi"} + malformedIPXEPath := []string{"httpfoo://www.example.com", "https:/1.2.3"} + + for _, wrongTFTP := range malformedTFTPPath { + _, err := setup4(wrongTFTP, ipxePath) + if err == nil { + t.Fatalf("no error occurred when providing wrong TFTP path %s, but it should have", wrongTFTP) + } + if tftpBootFileOption != nil { + t.Fatalf("TFTP boot file was set when providing wrong TFTP path %s, but it should be empty", wrongTFTP) + } + if tftpServerNameOption != nil { + t.Fatalf("TFTP server name was set when providing wrong TFTP path %s, but it should be empty", wrongTFTP) + } + if ipxeBootFileOption != nil { + t.Fatalf("IPXE boot file was set when providing wrong TFTP path %s, but it should be empty", wrongTFTP) + } + + _, err = setup6(wrongTFTP, ipxePath) + if err == nil { + t.Fatalf("no error occurred when providing wrong TFTP path %s, but it should have", wrongTFTP) + } + if tftpOption != nil { + t.Fatalf("TFTP boot file was set when providing wrong TFTP path %s, but it should be empty", wrongTFTP) + } + if ipxeOption != nil { + t.Fatalf("IPXE boot file was set when providing wrong TFTP path %s, but it should be empty", wrongTFTP) + } + } + + for _, wrongIPXE := range malformedIPXEPath { + _, err := setup4(tftpPath, wrongIPXE) + if err == nil { + t.Fatalf("no error occurred when providing wrong IPXE path %s, but it should have", wrongIPXE) + } + if tftpBootFileOption != nil { + t.Fatalf("TFTP boot file was set when providing wrong IPXE path %s, but it should be empty", wrongIPXE) + } + if tftpServerNameOption != nil { + t.Fatalf("TFTP server name set when providing wrong IPXE path %s, but it should be empty", wrongIPXE) + } + if ipxeBootFileOption != nil { + t.Fatalf("IPXE boot file was set when providing wrong IPXE path %s, but it should be empty", wrongIPXE) + } + + _, err = setup6(tftpPath, wrongIPXE) + if err == nil { + t.Fatalf("no error occurred when providing wrong IPXE path %s, but it should have", wrongIPXE) + } + if tftpOption != nil { + t.Fatalf("TFTP boot file was set when providing wrong IPXE path %s, but it should be empty", wrongIPXE) + } + if ipxeOption != nil { + t.Fatalf("IPXE boot file was set when providing wrong IPXE path %s, but it should be empty", wrongIPXE) + } + } +} + /* IPv6 */ func TestPXERequested6(t *testing.T) {