1
1
package sdl
2
2
3
3
import (
4
+ "errors"
4
5
"fmt"
5
6
"sort"
6
7
@@ -9,25 +10,27 @@ import (
9
10
types "github.com/akash-network/akash-api/go/node/types/v1beta3"
10
11
)
11
12
12
- type v2GPUNvidia struct {
13
+ var (
14
+ ErrResourceGPUEmptyVendors = errors .New ("sdl: invalid GPU attributes. at least one vendor must be set" )
15
+ )
16
+
17
+ type v2GPU struct {
13
18
Model string `yaml:"model"`
14
19
RAM * memoryQuantity `yaml:"ram,omitempty"`
15
20
}
16
21
17
- func (sdl * v2GPUNvidia ) String () string {
22
+ func (sdl * v2GPU ) String () string {
18
23
key := sdl .Model
19
24
if sdl .RAM != nil {
20
- key += "/" + sdl .RAM .StringWithSuffix ("Gi" )
25
+ key += "/ram/ " + sdl .RAM .StringWithSuffix ("Gi" )
21
26
}
22
27
23
28
return key
24
29
}
25
30
26
- type v2GPUsNvidia []v2GPUNvidia
31
+ type v2GPUs []v2GPU
27
32
28
- type gpuVendor struct {
29
- Nvidia v2GPUsNvidia `yaml:"nvidia,omitempty"`
30
- }
33
+ type gpuVendors map [string ]v2GPUs
31
34
32
35
type v2GPUAttributes types.Attributes
33
36
@@ -66,37 +69,54 @@ func (sdl *v2ResourceGPU) UnmarshalYAML(node *yaml.Node) error {
66
69
func (sdl * v2GPUAttributes ) UnmarshalYAML (node * yaml.Node ) error {
67
70
var res types.Attributes
68
71
69
- var vendor * gpuVendor
72
+ vendors := make ( gpuVendors )
70
73
71
74
for i := 0 ; i < len (node .Content ); i += 2 {
72
75
switch node .Content [i ].Value {
73
76
case "vendor" :
74
- if err := node .Content [i + 1 ].Decode (& vendor ); err != nil {
77
+ if err := node .Content [i + 1 ].Decode (& vendors ); err != nil {
75
78
return err
76
79
}
77
80
default :
78
81
return fmt .Errorf ("sdl: unsupported attribute (%s) for GPU resource" , node .Content [i ].Value )
79
82
}
80
83
}
81
84
82
- if vendor == nil {
83
- return fmt . Errorf ( "sdl: invalid GPU attributes. at least one vendor must be set" )
85
+ if len ( vendors ) == 0 {
86
+ return ErrResourceGPUEmptyVendors
84
87
}
85
88
86
- res = make (types. Attributes , 0 , len ( vendor . Nvidia ))
89
+ resPrealloc := 0
87
90
88
- for _ , model := range vendor .Nvidia {
89
- res = append (res , types.Attribute {
90
- Key : fmt .Sprintf ("vendor/nvidia/model/%s" , model .String ()),
91
- Value : "true" ,
92
- })
91
+ for _ , models := range vendors {
92
+ if len (models ) == 0 {
93
+ resPrealloc += 1
94
+ } else {
95
+ resPrealloc += len (models )
96
+ }
93
97
}
94
98
95
- if len (res ) == 0 {
96
- res = append (res , types.Attribute {
97
- Key : "vendor/nvidia/model/*" ,
98
- Value : "true" ,
99
- })
99
+ for vendor , models := range vendors {
100
+ switch vendor {
101
+ case "nvidia" :
102
+ case "amd" :
103
+ default :
104
+ return fmt .Errorf ("sdl: unsupported GPU vendor (%s)" , vendor )
105
+ }
106
+
107
+ for _ , model := range models {
108
+ res = append (res , types.Attribute {
109
+ Key : fmt .Sprintf ("vendor/%s/model/%s" , vendor , model .String ()),
110
+ Value : "true" ,
111
+ })
112
+ }
113
+
114
+ if len (models ) == 0 {
115
+ res = append (res , types.Attribute {
116
+ Key : fmt .Sprintf ("vendor/%s/model/*" , vendor ),
117
+ Value : "true" ,
118
+ })
119
+ }
100
120
}
101
121
102
122
sort .Sort (res )
0 commit comments