Skip to content

Make CDI device requests consistent with other methods #1132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 43 additions & 32 deletions internal/config/image/cuda_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@ type CUDA struct {
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
// The process environment is read (if present) to construc the CUDA Image.
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
if spec == nil {
return New(opts...)
}

var env []string
if spec != nil && spec.Process != nil {
if spec.Process != nil {
env = spec.Process.Env
}

Expand Down Expand Up @@ -212,19 +216,12 @@ func parseMajorMinorVersion(version string) (string, error) {
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
var hasCDIdevice bool
for _, device := range i.VisibleDevicesFromEnvVar() {
for _, device := range i.VisibleDevices() {
if !parser.IsQualifiedName(device) {
return false
}
hasCDIdevice = true
}

for _, device := range i.DevicesFromMounts() {
if !strings.HasPrefix(device, "cdi/") {
return false
}
hasCDIdevice = true
}
return hasCDIdevice
}

Expand Down Expand Up @@ -276,20 +273,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
func (i CUDA) visibleDevicesFromMounts() []string {
var devices []string
for _, device := range i.DevicesFromMounts() {
for _, device := range i.requestsFromMounts() {
switch {
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
name, err := cdiDeviceMountRequest(device).qualifiedName()
if err != nil {
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %w", device, err)
Copy link
Preview

Copilot AI Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The %w verb is only supported by fmt.Errorf. In logging calls, use %v or %s to ensure the error is formatted correctly.

Suggested change
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %w", device, err)
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)

Copilot uses AI. Check for mistakes.

continue
}
devices = append(devices, name)
default:
devices = append(devices, device)
}
devices = append(devices, device)

}
return devices
}

// DevicesFromMounts returns a list of device specified as mounts.
func (i CUDA) DevicesFromMounts() []string {
// requestsFromMounts returns a list of device specified as mounts.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this function since we're returning all requests including:

  • device requests (legacy and CDI)
  • imex channel requests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: DevicesFromMounts has been renamed and turned private to the config pkg

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It was not used outside this package.

func (i CUDA) requestsFromMounts() []string {
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
seen := make(map[string]bool)
var devices []string
Expand Down Expand Up @@ -321,23 +325,30 @@ func (i CUDA) DevicesFromMounts() []string {
return devices
}

// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
func (i CUDA) CDIDevicesFromMounts() []string {
var devices []string
for _, mountDevice := range i.DevicesFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
continue
}
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
if len(parts) != 3 {
continue
}
vendor := parts[0]
class := parts[1]
device := parts[2]
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
// a cdiDeviceMountRequest represents a CDI device requests as a mount.
// Here the host path /dev/null is mounted to a particular path in the container.
// The container path has the form:
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
// or
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
type cdiDeviceMountRequest string

// qualifiedName returns the fully-qualified name of the CDI device.
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return devices

requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
if parser.IsQualifiedName(requestedDevice) {
return requestedDevice, nil
}

parts := strings.SplitN(requestedDevice, "/", 3)
if len(parts) != 3 {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
}

// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
Expand All @@ -352,7 +363,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromMounts() []string {
var channels []string
for _, mountDevice := range i.DevicesFromMounts() {
for _, mountDevice := range i.requestsFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
continue
}
Expand Down
6 changes: 3 additions & 3 deletions internal/config/image/cuda_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
},
{
description: "cdi devices are ignored",
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "GPU1"},
description: "cdi devices are included",
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
},
{
description: "imex devices are ignored",
Expand Down
20 changes: 19 additions & 1 deletion internal/info/auto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
expectedMode: "legacy",
},
{
description: "cdi mount and non-CDI envvar resolves to legacy",
description: "cdi mount and non-CDI envvar resolves to cdi",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "0",
Expand All @@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
"tegra": false,
"nvgpu": false,
},
expectedMode: "cdi",
},
{
description: "non-cdi mount and CDI envvar resolves to legacy",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
},
mounts: []string{
"/var/run/nvidia-container-devices/0",
},
info: map[string]bool{
"nvml": true,
"tegra": false,
"nvgpu": false,
},
expectedMode: "legacy",
},
}
Expand Down Expand Up @@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
image, _ := image.New(
image.WithEnvMap(tc.envmap),
image.WithMounts(mounts),
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
)
mode := resolveMode(logger, tc.mode, image, properties)
require.EqualValues(t, tc.expectedMode, mode)
Expand Down
63 changes: 36 additions & 27 deletions internal/modifier/cdi.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,57 +66,66 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe
}

func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
cdiModifier := &cdiModifier{
logger: logger,
acceptDeviceListAsVolumeMounts: cfg.AcceptDeviceListAsVolumeMounts,
acceptEnvvarUnprivileged: cfg.AcceptEnvvarUnprivileged,
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
}
return cdiModifier.getDevicesFromSpec(ociSpec)
}

// TODO: We should rename this type.
type cdiModifier struct {
Comment on lines +69 to +80
Copy link
Preview

Copilot AI Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The cdiModifier type has a TODO to rename it. Consider selecting a more descriptive name or removing the TODO if this is the intended name.

Suggested change
cdiModifier := &cdiModifier{
logger: logger,
acceptDeviceListAsVolumeMounts: cfg.AcceptDeviceListAsVolumeMounts,
acceptEnvvarUnprivileged: cfg.AcceptEnvvarUnprivileged,
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
}
return cdiModifier.getDevicesFromSpec(ociSpec)
}
// TODO: We should rename this type.
type cdiModifier struct {
deviceHandler := &CDIDeviceHandler{
logger: logger,
acceptDeviceListAsVolumeMounts: cfg.AcceptDeviceListAsVolumeMounts,
acceptEnvvarUnprivileged: cfg.AcceptEnvvarUnprivileged,
annotationPrefixes: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes,
defaultKind: cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind,
}
return deviceHandler.getDevicesFromSpec(ociSpec)
}
// This type handles CDI-related device configurations and annotations.
type CDIDeviceHandler struct {

Copilot uses AI. Check for mistakes.

logger logger.Interface
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
annotationPrefixes []string
defaultKind string
}

func (c *cdiModifier) getDevicesFromSpec(ociSpec oci.Spec) ([]string, error) {
rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}

annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations)
if err != nil {
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
}
if len(annotationDevices) > 0 {
return annotationDevices, nil
if rawSpec != nil {
annotationDevices, err := getAnnotationDevices(c.annotationPrefixes, rawSpec.Annotations)
if err != nil {
return nil, fmt.Errorf("failed to parse container annotations: %v", err)
}
if len(annotationDevices) > 0 {
return annotationDevices, nil
}
}

container, err := image.NewCUDAImageFromSpec(
rawSpec,
image.WithLogger(logger),
image.WithLogger(c.logger),
image.WithAcceptDeviceListAsVolumeMounts(c.acceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(c.acceptEnvvarUnprivileged),
)
if err != nil {
return nil, err
}
if cfg.AcceptDeviceListAsVolumeMounts {
mountDevices := container.CDIDevicesFromMounts()
if len(mountDevices) > 0 {
return mountDevices, nil
}
}

var devices []string
seen := make(map[string]bool)
for _, name := range container.VisibleDevicesFromEnvVar() {
for _, name := range container.VisibleDevices() {
if !parser.IsQualifiedName(name) {
name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
name = fmt.Sprintf("%s=%s", c.defaultKind, name)
}
if seen[name] {
logger.Debugf("Ignoring duplicate device %q", name)
c.logger.Debugf("Ignoring duplicate device %q", name)
continue
}
seen[name] = true
devices = append(devices, name)
}

if len(devices) == 0 {
return nil, nil
}

if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
return devices, nil
}

logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)

return nil, nil
return devices, nil
}

// getAnnotationDevices returns a list of devices specified in the annotations.
Expand Down
70 changes: 70 additions & 0 deletions internal/modifier/cdi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ import (
"fmt"
"testing"

"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"

"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)

func TestGetAnnotationDevices(t *testing.T) {
Expand Down Expand Up @@ -90,3 +94,69 @@ func TestGetAnnotationDevices(t *testing.T) {
})
}
}

func TestGetDevicesFromSpec(t *testing.T) {
logger, _ := testlog.NewNullLogger()

testCases := []struct {
description string
input cdiModifier
spec *specs.Spec
expectedDevices []string
}{
{
description: "empty spec yields no devices",
},
{
description: "cdi devices from mounts",
input: cdiModifier{
defaultKind: "nvidia.com/gpu",
acceptEnvvarUnprivileged: true,
acceptDeviceListAsVolumeMounts: true,
},
spec: &specs.Spec{
Mounts: []specs.Mount{
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
Source: "/dev/null",
},
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1",
Source: "/dev/null",
},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
},
{
description: "cdi devices from envvar",
input: cdiModifier{
defaultKind: "nvidia.com/gpu",
acceptEnvvarUnprivileged: true,
acceptDeviceListAsVolumeMounts: true,
},
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"},
},
}

for _, tc := range testCases {
tc.input.logger = logger

spec := &oci.SpecMock{
LoadFunc: func() (*specs.Spec, error) {
return tc.spec, nil
},
}

t.Run(tc.description, func(t *testing.T) {
devices, err := tc.input.getDevicesFromSpec(spec)
require.NoError(t, err)
require.EqualValues(t, tc.expectedDevices, devices)
})
}
}