Skip to content

Commit 1b0f07a

Browse files
committed
Make CDI device requests consistent with other methods
Following the refactoring of device request extraction, we can now make CDI device requests consistent with other methods. This change moves to using image.VisibleDevices instead of separate calls to CDIDevicesFromMounts and VisibleDevicesFromEnvVar. The handling of annotation-based requests will be addressed in a follow-up. Signed-off-by: Evan Lezar <[email protected]>
1 parent 27f5ec8 commit 1b0f07a

File tree

5 files changed

+171
-63
lines changed

5 files changed

+171
-63
lines changed

internal/config/image/cuda_image.go

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ type CUDA struct {
5454
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
5555
// The process environment is read (if present) to construct the CUDA Image.
5656
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
57+
if spec == nil {
58+
return New(opts...)
59+
}
60+
5761
var env []string
58-
if spec != nil && spec.Process != nil {
62+
if spec.Process != nil {
5963
env = spec.Process.Env
6064
}
6165

@@ -212,19 +216,12 @@ func parseMajorMinorVersion(version string) (string, error) {
212216
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices.
213217
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
214218
var hasCDIdevice bool
215-
for _, device := range i.VisibleDevicesFromEnvVar() {
219+
for _, device := range i.VisibleDevices() {
216220
if !parser.IsQualifiedName(device) {
217221
return false
218222
}
219223
hasCDIdevice = true
220224
}
221-
222-
for _, device := range i.DevicesFromMounts() {
223-
if !strings.HasPrefix(device, "cdi/") {
224-
return false
225-
}
226-
hasCDIdevice = true
227-
}
228225
return hasCDIdevice
229226
}
230227

@@ -276,20 +273,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
276273
// visibleDevicesFromMounts returns the devices requested as mounts: IMEX requests are skipped, CDI requests are converted to fully-qualified names (invalid ones are logged and ignored).
277274
func (i CUDA) visibleDevicesFromMounts() []string {
278275
var devices []string
279-
for _, device := range i.DevicesFromMounts() {
276+
for _, device := range i.requestsFromMounts() {
280277
switch {
281-
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
282-
continue
283278
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
284279
continue
280+
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
281+
name, err := cdiDeviceMountRequest(device).qualifiedName()
282+
if err != nil {
283+
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)
284+
continue
285+
}
286+
devices = append(devices, name)
287+
default:
288+
devices = append(devices, device)
285289
}
286-
devices = append(devices, device)
290+
287291
}
288292
return devices
289293
}
290294

291-
// DevicesFromMounts returns a list of device specified as mounts.
292-
func (i CUDA) DevicesFromMounts() []string {
295+
// requestsFromMounts returns the list of devices specified as mounts.
296+
func (i CUDA) requestsFromMounts() []string {
293297
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
294298
seen := make(map[string]bool)
295299
var devices []string
@@ -321,23 +325,30 @@ func (i CUDA) DevicesFromMounts() []string {
321325
return devices
322326
}
323327

324-
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
325-
func (i CUDA) CDIDevicesFromMounts() []string {
326-
var devices []string
327-
for _, mountDevice := range i.DevicesFromMounts() {
328-
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
329-
continue
330-
}
331-
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
332-
if len(parts) != 3 {
333-
continue
334-
}
335-
vendor := parts[0]
336-
class := parts[1]
337-
device := parts[2]
338-
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
328+
// A cdiDeviceMountRequest represents a CDI device request made as a mount.
329+
// Here the host path /dev/null is mounted to a particular path in the container.
330+
// The container path has the form:
331+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
332+
// or
333+
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
334+
type cdiDeviceMountRequest string
335+
336+
// qualifiedName returns the fully-qualified name of the CDI device.
337+
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
338+
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
339+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
339340
}
340-
return devices
341+
342+
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
343+
if parser.IsQualifiedName(requestedDevice) {
344+
return requestedDevice, nil
345+
}
346+
347+
parts := strings.SplitN(requestedDevice, "/", 3)
348+
if len(parts) != 3 {
349+
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
350+
}
351+
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
341352
}
342353

343354
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
@@ -352,7 +363,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
352363
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
353364
func (i CUDA) ImexChannelsFromMounts() []string {
354365
var channels []string
355-
for _, mountDevice := range i.DevicesFromMounts() {
366+
for _, mountDevice := range i.requestsFromMounts() {
356367
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
357368
continue
358369
}

internal/config/image/cuda_image_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
487487
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
488488
},
489489
{
490-
description: "cdi devices are ignored",
491-
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
492-
expectedDevices: []string{"GPU0", "GPU1"},
490+
description: "cdi devices are included",
491+
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
492+
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
493493
},
494494
{
495495
description: "imex devices are ignored",

internal/info/auto_test.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
184184
expectedMode: "legacy",
185185
},
186186
{
187-
description: "cdi mount and non-CDI envvar resolves to legacy",
187+
description: "cdi mount and non-CDI envvar resolves to cdi",
188188
mode: "auto",
189189
envmap: map[string]string{
190190
"NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
197197
"tegra": false,
198198
"nvgpu": false,
199199
},
200+
expectedMode: "cdi",
201+
},
202+
{
203+
description: "non-cdi mount and CDI envvar resolves to legacy",
204+
mode: "auto",
205+
envmap: map[string]string{
206+
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
207+
},
208+
mounts: []string{
209+
"/var/run/nvidia-container-devices/0",
210+
},
211+
info: map[string]bool{
212+
"nvml": true,
213+
"tegra": false,
214+
"nvgpu": false,
215+
},
200216
expectedMode: "legacy",
201217
},
202218
}
@@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
232248
image, _ := image.New(
233249
image.WithEnvMap(tc.envmap),
234250
image.WithMounts(mounts),
251+
image.WithAcceptDeviceListAsVolumeMounts(true),
252+
image.WithAcceptEnvvarUnprivileged(true),
235253
)
236254
mode := resolveMode(logger, tc.mode, image, properties)
237255
require.EqualValues(t, tc.expectedMode, mode)

internal/modifier/cdi.go

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -66,57 +66,66 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
6666
}
6767

6868
func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
69+
cdiModifier := &cdiModifier{
70+
logger: logger,
71+
acceptDeviceListAsVolumeMounts: cfg.AcceptDeviceListAsVolumeMounts,
72+
acceptEnvvarUnprivileged: cfg.AcceptEnvvarUnprivileged,
73+
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
74+
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
75+
}
76+
return cdiModifier.getDevicesFromSpec(ociSpec)
77+
}
78+
79+
// TODO: We should rename this type.
80+
type cdiModifier struct {
81+
logger logger.Interface
82+
acceptDeviceListAsVolumeMounts bool
83+
acceptEnvvarUnprivileged bool
84+
annotationPrefixes []string
85+
defaultKind string
86+
}
87+
88+
func (c *cdiModifier) getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
6989
rawSpec, err := ociSpec.Load()
7090
if err != nil {
7191
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
7292
}
7393

74-
annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations)
75-
if err != nil {
76-
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
77-
}
78-
if len(annotationDevices) > 0 {
79-
return annotationDevices, nil
94+
if rawSpec != nil {
95+
annotationDevices, err := getAnnotationDevices(c.annotationPrefixes, rawSpec.Annotations)
96+
if err != nil {
97+
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
98+
}
99+
if len(annotationDevices) > 0 {
100+
return annotationDevices, nil
101+
}
80102
}
81103

82104
container, err := image.NewCUDAImageFromSpec(
83105
rawSpec,
84-
image.WithLogger(logger),
106+
image.WithLogger(c.logger),
107+
image.WithAcceptDeviceListAsVolumeMounts(c.acceptDeviceListAsVolumeMounts),
108+
image.WithAcceptEnvvarUnprivileged(c.acceptEnvvarUnprivileged),
85109
)
86110
if err != nil {
87111
return nil, err
88112
}
89-
if cfg.AcceptDeviceListAsVolumeMounts {
90-
mountDevices := container.CDIDevicesFromMounts()
91-
if len(mountDevices) > 0 {
92-
return mountDevices, nil
93-
}
94-
}
95113

96114
var devices []string
97115
seen := make(map[string]bool)
98-
for _, name := range container.VisibleDevicesFromEnvVar() {
116+
for _, name := range container.VisibleDevices() {
99117
if !parser.IsQualifiedName(name) {
100-
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
118+
name = fmt.Sprintf("%s=%s", c.defaultKind, name)
101119
}
102120
if seen[name] {
103-
logger.Debugf("Ignoring duplicate device %q", name)
121+
c.logger.Debugf("Ignoring duplicate device %q", name)
104122
continue
105123
}
124+
seen[name] = true
106125
devices = append(devices, name)
107126
}
108127

109-
if len(devices) == 0 {
110-
return nil, nil
111-
}
112-
113-
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
114-
return devices, nil
115-
}
116-
117-
logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
118-
119-
return nil, nil
128+
return devices, nil
120129
}
121130

122131
// getAnnotationDevices returns a list of devices specified in the annotations.

internal/modifier/cdi_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ import (
2020
"fmt"
2121
"testing"
2222

23+
"github.com/opencontainers/runtime-spec/specs-go"
24+
testlog "github.com/sirupsen/logrus/hooks/test"
2325
"github.com/stretchr/testify/require"
26+
27+
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2428
)
2529

2630
func TestGetAnnotationDevices(t *testing.T) {
@@ -90,3 +94,69 @@ func TestGetAnnotationDevices(t *testing.T) {
9094
})
9195
}
9296
}
97+
98+
func TestGetDevicesFromSpec(t *testing.T) {
99+
logger, _ := testlog.NewNullLogger()
100+
101+
testCases := []struct {
102+
description string
103+
input cdiModifier
104+
spec *specs.Spec
105+
expectedDevices []string
106+
}{
107+
{
108+
description: "empty spec yields no devices",
109+
},
110+
{
111+
description: "cdi devices from mounts",
112+
input: cdiModifier{
113+
defaultKind: "nvidia.com/gpu",
114+
acceptEnvvarUnprivileged: true,
115+
acceptDeviceListAsVolumeMounts: true,
116+
},
117+
spec: &specs.Spec{
118+
Mounts: []specs.Mount{
119+
{
120+
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
121+
Source: "/dev/null",
122+
},
123+
{
124+
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1",
125+
Source: "/dev/null",
126+
},
127+
},
128+
},
129+
expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
130+
},
131+
{
132+
description: "cdi devices from envvar",
133+
input: cdiModifier{
134+
defaultKind: "nvidia.com/gpu",
135+
acceptEnvvarUnprivileged: true,
136+
acceptDeviceListAsVolumeMounts: true,
137+
},
138+
spec: &specs.Spec{
139+
Process: &specs.Process{
140+
Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"},
141+
},
142+
},
143+
expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"},
144+
},
145+
}
146+
147+
for _, tc := range testCases {
148+
tc.input.logger = logger
149+
150+
spec := &oci.SpecMock{
151+
LoadFunc: func() (*specs.Spec, error) {
152+
return tc.spec, nil
153+
},
154+
}
155+
156+
t.Run(tc.description, func(t *testing.T) {
157+
devices, err := tc.input.getDevicesFromSpec(spec)
158+
require.NoError(t, err)
159+
require.EqualValues(t, tc.expectedDevices, devices)
160+
})
161+
}
162+
}

0 commit comments

Comments
 (0)