From cd9f950dff124485df20598200ae3d805674746b Mon Sep 17 00:00:00 2001 From: Johnny Yip Date: Fri, 17 Jan 2025 18:05:29 +0000 Subject: [PATCH] Added accelerator flag to cli and cvdr.toml Signed-off-by: Johnny Yip --- pkg/cli/cli.go | 11 ++++++++++- pkg/cli/config.go | 7 ++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pkg/cli/cli.go b/pkg/cli/cli.go index 2d5bbde0..3bd4c59f 100644 --- a/pkg/cli/cli.go +++ b/pkg/cli/cli.go @@ -84,6 +84,7 @@ const ( gcpMachineTypeFlag = "gcp_machine_type" gcpMinCPUPlatformFlag = "gcp_min_cpu_platform" gcpBootDiskSizeGB = "gcp_boot_disk_size_gb" + gcpAcceleratorFlag = "gcp_accelerator" ) const ( @@ -542,6 +543,7 @@ func hostCommand(opts *subCommandOpts) *cobra.Command { } func cvdCommands(opts *subCommandOpts) []*cobra.Command { + gcpAcceleratorFlagValues := []string{} // Create command createFlags := &CreateCVDFlags{ ServiceFlags: opts.ServiceFlags, @@ -553,6 +555,11 @@ func cvdCommands(opts *subCommandOpts) []*cobra.Command { Short: "Creates a CVD", PreRunE: preRunE(createFlags, &opts.ServiceFlags.Service, &opts.InitialConfig), RunE: func(c *cobra.Command, args []string) error { + configs, err := parseAcceleratorFlag(gcpAcceleratorFlagValues) + if err != nil { + return err + } + createFlags.GCP.AcceleratorConfigs = configs return runCreateCVDCommand(c, args, createFlags, opts) }, } @@ -645,6 +652,8 @@ func cvdCommands(opts *subCommandOpts) []*cobra.Command { create.Flags().StringVar(f.ValueRef, name, f.Default, f.Desc) create.MarkFlagsMutuallyExclusive(hostFlag, name) } + create.Flags().StringSliceVar(&gcpAcceleratorFlagValues, "host_" + gcpAcceleratorFlag, + opts.InitialConfig.DefaultService().Host.GCP.AcceleratorConfigs, acceleratorFlagDesc) // List command listFlags := &ListCVDsFlags{ServiceFlags: opts.ServiceFlags} list := &cobra.Command{ @@ -1728,7 +1737,7 @@ type acceleratorConfig struct { Type string } -// Values should follow the pattern: `type=[TYPE],count=[COUNT]`, i.e: `type=nvidia-tesla-p100,count=1`. +// Values should follow the pattern: '"type=[TYPE],count=[COUNT]"', i.e: '"type=nvidia-tesla-p100,count=1"'. func parseAcceleratorFlag(values []string) ([]acceleratorConfig, error) { singleValueParser := func(value string) (*acceleratorConfig, error) { sanitized := strings.Join(strings.Fields(value), "") diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 864bb771..47a008da 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -27,9 +27,10 @@ import ( ) type GCPHostConfig struct { - MachineType string `json:"machine_type,omitempty"` - MinCPUPlatform string `json:"min_cpu_platform,omitempty"` - BootDiskSizeGB int64 `json:"boot_disk_size_gb,omitempty"` + MachineType string `json:"machine_type,omitempty"` + MinCPUPlatform string `json:"min_cpu_platform,omitempty"` + BootDiskSizeGB int64 `json:"boot_disk_size_gb,omitempty"` + AcceleratorConfigs []string `json:"accelerator_configs,omitempty"` } type HostConfig struct {