Skip to content

Commit 0e11147

Browse files
committed
Make CDI device requests consistent with other methods
Following the refactoring of device request extraction, we can now make CDI device requests consistent with other methods. This change moves to using image.VisibleDevices instead of separate calls to CDIDevicesFromMounts and VisibleDevicesFromEnvVar. The handling of annotation-based requests will be addressed in a follow-up. Signed-off-by: Evan Lezar <[email protected]>
1 parent 27f5ec8 commit 0e11147

File tree

5 files changed

+170
-54
lines changed

5 files changed

+170
-54
lines changed

internal/config/image/cuda_image.go

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ type CUDA struct {
5454
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
5555
// The process environment is read (if present) to construc the CUDA Image.
5656
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
57+
if spec == nil {
58+
return New(opts...)
59+
}
60+
5761
var env []string
58-
if spec != nil && spec.Process != nil {
62+
if spec.Process != nil {
5963
env = spec.Process.Env
6064
}
6165

@@ -212,19 +216,12 @@ func parseMajorMinorVersion(version string) (string, error) {
212216
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/
213217
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
214218
var hasCDIdevice bool
215-
for _, device := range i.VisibleDevicesFromEnvVar() {
219+
for _, device := range i.VisibleDevices() {
216220
if !parser.IsQualifiedName(device) {
217221
return false
218222
}
219223
hasCDIdevice = true
220224
}
221-
222-
for _, device := range i.DevicesFromMounts() {
223-
if !strings.HasPrefix(device, "cdi/") {
224-
return false
225-
}
226-
hasCDIdevice = true
227-
}
228225
return hasCDIdevice
229226
}
230227

@@ -276,20 +273,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
276273
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
277274
func (i CUDA) visibleDevicesFromMounts() []string {
278275
var devices []string
279-
for _, device := range i.DevicesFromMounts() {
276+
for _, device := range i.requestsFromMounts() {
280277
switch {
281-
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
282-
continue
283278
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
284279
continue
280+
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
281+
name, err := cdiDeviceMountRequest(device).qualifiedName()
282+
if err != nil {
283+
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %w", device, err)
284+
continue
285+
}
286+
devices = append(devices, name)
287+
default:
288+
devices = append(devices, device)
285289
}
286-
devices = append(devices, device)
290+
287291
}
288292
return devices
289293
}
290294

291-
// DevicesFromMounts returns a list of device specified as mounts.
292-
func (i CUDA) DevicesFromMounts() []string {
295+
// requestsFromMounts returns a list of device specified as mounts.
296+
func (i CUDA) requestsFromMounts() []string {
293297
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
294298
seen := make(map[string]bool)
295299
var devices []string
@@ -321,23 +325,30 @@ func (i CUDA) DevicesFromMounts() []string {
321325
return devices
322326
}
323327

324-
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
325-
func (i CUDA) CDIDevicesFromMounts() []string {
326-
var devices []string
327-
for _, mountDevice := range i.DevicesFromMounts() {
328-
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
329-
continue
330-
}
331-
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
332-
if len(parts) != 3 {
333-
continue
334-
}
335-
vendor := parts[0]
336-
class := parts[1]
337-
device := parts[2]
338-
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
328+
// a cdiDeviceMountRequest represents a CDI device requests as a mount.
329+
// Here the host path /dev/null is mounted to a particular path in the container.
330+
// The container path has the form:
331+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
332+
// or
333+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
334+
type cdiDeviceMountRequest string
335+
336+
// qualifiedName returns the fully-qualified name of the CDI device.
337+
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
338+
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
339+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
339340
}
340-
return devices
341+
342+
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
343+
if parser.IsQualifiedName(requestedDevice) {
344+
return requestedDevice, nil
345+
}
346+
347+
parts := strings.SplitN(requestedDevice, "/", 3)
348+
if len(parts) != 3 {
349+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
350+
}
351+
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
341352
}
342353

343354
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
@@ -352,7 +363,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
352363
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
353364
func (i CUDA) ImexChannelsFromMounts() []string {
354365
var channels []string
355-
for _, mountDevice := range i.DevicesFromMounts() {
366+
for _, mountDevice := range i.requestsFromMounts() {
356367
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
357368
continue
358369
}

internal/config/image/cuda_image_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
487487
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
488488
},
489489
{
490-
description: "cdi devices are ignored",
491-
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
492-
expectedDevices: []string{"GPU0", "GPU1"},
490+
description: "cdi devices are included",
491+
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
492+
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
493493
},
494494
{
495495
description: "imex devices are ignored",

internal/info/auto_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
184184
expectedMode: "legacy",
185185
},
186186
{
187-
description: "cdi mount and non-CDI envvar resolves to legacy",
187+
description: "cdi mount and non-CDI envvar resolves to cdi",
188188
mode: "auto",
189189
envmap: map[string]string{
190190
"NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
197197
"tegra": false,
198198
"nvgpu": false,
199199
},
200+
expectedMode: "cdi",
201+
},
202+
{
203+
description: "non-cdi mount and CDI envvar resolves to legacy",
204+
mode: "auto",
205+
envmap: map[string]string{
206+
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
207+
},
208+
mounts: []string{
209+
"/var/run/nvidia-container-devices/0",
210+
},
211+
info: map[string]bool{
212+
"nvml": true,
213+
"tegra": false,
214+
"nvgpu": false,
215+
},
200216
expectedMode: "legacy",
201217
},
202218
}
@@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
232248
image, _ := image.New(
233249
image.WithEnvMap(tc.envmap),
234250
image.WithMounts(mounts),
251+
image.WithAcceptDeviceListAsVolumeMounts(true),
252+
image.WithAcceptEnvvarUnprivileged(true),
235253
)
236254
mode := resolveMode(logger, tc.mode, image, properties)
237255
require.EqualValues(t, tc.expectedMode, mode)

internal/modifier/cdi.go

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,55 +66,72 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
6666
}
6767

6868
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
69+
cdiModifier := &cdiModifier{
70+
logger: logger,
71+
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
72+
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
73+
}
74+
return cdiModifier.getDevicesFromSpec(ociSpec)
75+
}
76+
77+
// TODO: We should rename this type.
78+
type cdiModifier struct {
79+
logger logger.Interface
80+
acceptDeviceListAsVolumeMounts bool
81+
acceptEnvvarUnprivileged bool
82+
annotationPrefixes []string
83+
defaultKind string
84+
}
85+
86+
func (c *cdiModifier) getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
6987
rawSpec, err := ociSpec.Load()
7088
if err != nil {
7189
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
7290
}
7391

74-
annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations)
75-
if err != nil {
76-
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
77-
}
78-
if len(annotationDevices) > 0 {
79-
return annotationDevices, nil
92+
if rawSpec != nil {
93+
annotationDevices, err := getAnnotationDevices(c.annotationPrefixes, rawSpec.Annotations)
94+
if err != nil {
95+
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
96+
}
97+
if len(annotationDevices) > 0 {
98+
return annotationDevices, nil
99+
}
80100
}
81101

82102
container, err := image.NewCUDAImageFromSpec(
83103
rawSpec,
84-
image.WithLogger(logger),
104+
image.WithLogger(c.logger),
105+
image.WithAcceptDeviceListAsVolumeMounts(c.acceptDeviceListAsVolumeMounts),
106+
image.WithAcceptEnvvarUnprivileged(c.acceptEnvvarUnprivileged),
85107
)
86108
if err != nil {
87109
return nil, err
88110
}
89-
if cfg.AcceptDeviceListAsVolumeMounts {
90-
mountDevices := container.CDIDevicesFromMounts()
91-
if len(mountDevices) > 0 {
92-
return mountDevices, nil
93-
}
94-
}
95111

96112
var devices []string
97113
seen := make(map[string]bool)
98-
for _, name := range container.VisibleDevicesFromEnvVar() {
114+
for _, name := range container.VisibleDevices() {
99115
if !parser.IsQualifiedName(name) {
100-
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
116+
name = fmt.Sprintf("%s=%s", c.defaultKind, name)
101117
}
102118
if seen[name] {
103-
logger.Debugf("Ignoring duplicate device %q", name)
119+
c.logger.Debugf("Ignoring duplicate device %q", name)
104120
continue
105121
}
122+
seen[name] = true
106123
devices = append(devices, name)
107124
}
108125

109126
if len(devices) == 0 {
110127
return nil, nil
111128
}
112129

113-
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
130+
if c.acceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
114131
return devices, nil
115132
}
116133

117-
logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
134+
c.logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
118135

119136
return nil, nil
120137
}

internal/modifier/cdi_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ import (
2020
"fmt"
2121
"testing"
2222

23+
"github.com/opencontainers/runtime-spec/specs-go"
24+
testlog "github.com/sirupsen/logrus/hooks/test"
2325
"github.com/stretchr/testify/require"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2428
)
2529

2630
func TestGetAnnotationDevices(t *testing.T) {
@@ -90,3 +94,69 @@ func TestGetAnnotationDevices(t *testing.T) {
9094
})
9195
}
9296
}
97+
98+
func TestGetDevicesFromSpec(t *testing.T) {
99+
logger, _ := testlog.NewNullLogger()
100+
101+
testCases := []struct {
102+
description string
103+
input cdiModifier
104+
spec *specs.Spec
105+
expectedDevices []string
106+
}{
107+
{
108+
description: "empty spec yields no devices",
109+
},
110+
{
111+
description: "cdi devices from mounts",
112+
input: cdiModifier{
113+
defaultKind: "nvidia.com/gpu",
114+
acceptEnvvarUnprivileged: true,
115+
acceptDeviceListAsVolumeMounts: true,
116+
},
117+
spec: &specs.Spec{
118+
Mounts: []specs.Mount{
119+
{
120+
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
121+
Source: "/dev/null",
122+
},
123+
{
124+
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1",
125+
Source: "/dev/null",
126+
},
127+
},
128+
},
129+
expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
130+
},
131+
{
132+
description: "cdi devices from envvar",
133+
input: cdiModifier{
134+
defaultKind: "nvidia.com/gpu",
135+
acceptEnvvarUnprivileged: true,
136+
acceptDeviceListAsVolumeMounts: true,
137+
},
138+
spec: &specs.Spec{
139+
Process: &specs.Process{
140+
Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"},
141+
},
142+
},
143+
expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"},
144+
},
145+
}
146+
147+
for _, tc := range testCases {
148+
tc.input.logger = logger
149+
150+
spec := &oci.SpecMock{
151+
LoadFunc: func() (*specs.Spec, error) {
152+
return tc.spec, nil
153+
},
154+
}
155+
156+
t.Run(tc.description, func(t *testing.T) {
157+
devices, err := tc.input.getDevicesFromSpec(spec)
158+
require.NoError(t, err)
159+
require.EqualValues(t, tc.expectedDevices, devices)
160+
})
161+
}
162+
}

0 commit comments

Comments
 (0)