Skip to content

Make CDI device requests consistent with other methods #1132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion internal/config/image/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ func New(opt ...Option) (CUDA, error) {
if b.logger == nil {
b.logger = logger.New()
}

if b.env == nil {
b.env = make(map[string]string)
}
Expand Down Expand Up @@ -81,6 +80,20 @@ func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
}
}

func WithAnnotations(annotations map[string]string) Option {
return func(b *builder) error {
b.annotations = annotations
return nil
}
}

func WithAnnotationsPrefixes(annotationsPrefixes []string) Option {
return func(b *builder) error {
b.annotationsPrefixes = annotationsPrefixes
return nil
}
}

// WithDisableRequire sets the disable require option.
func WithDisableRequire(disableRequire bool) Option {
return func(b *builder) error {
Expand Down
113 changes: 81 additions & 32 deletions internal/config/image/cuda_image.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ const (
type CUDA struct {
logger logger.Interface

annotations map[string]string
env map[string]string
isPrivileged bool
mounts []specs.Mount

annotationsPrefixes []string
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
preferredVisibleDeviceEnvVars []string
Expand All @@ -54,12 +56,17 @@ type CUDA struct {
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
// The process environment is read (if present) to construc the CUDA Image.
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) {
if spec == nil {
return New(opts...)
}

var env []string
if spec != nil && spec.Process != nil {
if spec.Process != nil {
env = spec.Process.Env
}

specOpts := []Option{
WithAnnotations(spec.Annotations),
WithEnv(env),
WithMounts(spec.Mounts),
WithPrivileged(IsPrivileged((*OCISpec)(spec))),
Expand Down Expand Up @@ -95,6 +102,10 @@ func (i CUDA) IsLegacy() bool {
return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0
}

func (i CUDA) IsPrivileged() bool {
return i.isPrivileged
}

// GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment
// variables.
func (i CUDA) GetRequirements() ([]string, error) {
Expand Down Expand Up @@ -212,19 +223,12 @@ func parseMajorMinorVersion(version string) (string, error) {
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
var hasCDIdevice bool
for _, device := range i.VisibleDevicesFromEnvVar() {
for _, device := range i.VisibleDevices() {
if !parser.IsQualifiedName(device) {
return false
}
hasCDIdevice = true
}

for _, device := range i.DevicesFromMounts() {
if !strings.HasPrefix(device, "cdi/") {
return false
}
hasCDIdevice = true
}
return hasCDIdevice
}

Expand All @@ -234,6 +238,12 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
// In cases where environment variable requests required privileged containers,
// such devices requests are ignored.
func (i CUDA) VisibleDevices() []string {
// If annotation device requests are present, these are preferred.
annotationDeviceRequests := i.cdiDeviceRequestsFromAnnotations()
if len(annotationDeviceRequests) > 0 {
return annotationDeviceRequests
}

// If enabled, try and get the device list from volume mounts first
if i.acceptDeviceListAsVolumeMounts {
volumeMountDeviceRequests := i.visibleDevicesFromMounts()
Expand All @@ -260,6 +270,31 @@ func (i CUDA) VisibleDevices() []string {
return nil
}

// cdiDeviceRequestsFromAnnotations returns a list of devices specified in the
// annotations.
// Keys starting with the specified prefixes are considered and expected to
// contain a comma-separated list of fully-qualified CDI devices names.
// The format of the requested devices is not checked and the list is not
// deduplicated.
func (i CUDA) cdiDeviceRequestsFromAnnotations() []string {
if len(i.annotationsPrefixes) == 0 || len(i.annotations) == 0 {
return nil
}

var devices []string
for key, value := range i.annotations {
for _, prefix := range i.annotationsPrefixes {
if strings.HasPrefix(key, prefix) {
devices = append(devices, strings.Split(value, ",")...)
// There is no need to check additional prefixes since we
// typically deduplicate devices in any case.
break
}
}
}
return devices
}

// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
// are used to determine the visible devices. If this is not the case, the
Expand All @@ -276,20 +311,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string {
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
func (i CUDA) visibleDevicesFromMounts() []string {
var devices []string
for _, device := range i.DevicesFromMounts() {
for _, device := range i.requestsFromMounts() {
switch {
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixImex):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
name, err := cdiDeviceMountRequest(device).qualifiedName()
if err != nil {
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)
continue
}
devices = append(devices, name)
default:
devices = append(devices, device)
}
devices = append(devices, device)

}
return devices
}

// DevicesFromMounts returns a list of device specified as mounts.
func (i CUDA) DevicesFromMounts() []string {
// requestsFromMounts returns a list of device specified as mounts.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this function since we're returning all requests including:

  • device requests (legacy and CDI)
  • imex channel requests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: DevicesFromMounts has been renamed and turned private to the config pkg

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It was not used outside this package.

func (i CUDA) requestsFromMounts() []string {
root := filepath.Clean(DeviceListAsVolumeMountsRoot)
seen := make(map[string]bool)
var devices []string
Expand Down Expand Up @@ -321,23 +363,30 @@ func (i CUDA) DevicesFromMounts() []string {
return devices
}

// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
func (i CUDA) CDIDevicesFromMounts() []string {
var devices []string
for _, mountDevice := range i.DevicesFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
continue
}
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
if len(parts) != 3 {
continue
}
vendor := parts[0]
class := parts[1]
device := parts[2]
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
// a cdiDeviceMountRequest represents a CDI device requests as a mount.
// Here the host path /dev/null is mounted to a particular path in the container.
// The container path has the form:
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device>
// or
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device>
type cdiDeviceMountRequest string

// qualifiedName returns the fully-qualified name of the CDI device.
func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return devices

requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
if parser.IsQualifiedName(requestedDevice) {
return requestedDevice, nil
}

parts := strings.SplitN(requestedDevice, "/", 3)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this trying to parse vendor, class, device from

/cdi/<vendor>/<class>/<device>

?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, volumeMountDevicePrefixCDI = "cdi/".

and the argument 3 probably means "at most two splits", i.e. max three tokens.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a mounted CDI device will have the path /var/run/nvidia-container-devices/cdi/vendor.com/class/device or /var/run/nvidia-container-devices/cdi/vendor.com/class=device. This code handles the former case where we split vendor.com/class/device into AT MOST three parts.

From the SplitN docs:

// SplitN slices s into substrings separated by sep and returns a slice of
// the substrings between those separators.
//
// The count determines the number of substrings to return:
// - n > 0: at most n substrings; the last substring will be the unsplit remainder;
// - n == 0: the result is nil (zero substrings);
// - n < 0: all substrings.

if len(parts) != 3 {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
}

// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
Expand All @@ -352,7 +401,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromMounts() []string {
var channels []string
for _, mountDevice := range i.DevicesFromMounts() {
for _, mountDevice := range i.requestsFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
continue
}
Expand Down
73 changes: 70 additions & 3 deletions internal/config/image/cuda_image_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
},
{
description: "cdi devices are ignored",
mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "GPU1"},
description: "cdi devices are included",
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"},
},
{
description: "imex devices are ignored",
Expand Down Expand Up @@ -649,6 +649,73 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
}
}

func TestCDIDeviceRequestsFromAnnotations(t *testing.T) {
testCases := []struct {
description string
prefixes []string
annotations map[string]string
expectedDevices []string
}{
{
description: "no annotations",
},
{
description: "no matching annotations",
prefixes: []string{"not-prefix/"},
annotations: map[string]string{
"prefix/foo": "example.com/device=bar",
},
},
{
description: "single matching annotation",
prefixes: []string{"prefix/"},
annotations: map[string]string{
"prefix/foo": "example.com/device=bar",
},
expectedDevices: []string{"example.com/device=bar"},
},
{
description: "multiple matching annotations",
prefixes: []string{"prefix/", "another-prefix/"},
annotations: map[string]string{
"prefix/foo": "example.com/device=bar",
"another-prefix/bar": "example.com/device=baz",
},
expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"},
},
{
description: "multiple matching annotations with duplicate devices",
prefixes: []string{"prefix/", "another-prefix/"},
annotations: map[string]string{
"prefix/foo": "example.com/device=bar",
"another-prefix/bar": "example.com/device=bar",
},
expectedDevices: []string{"example.com/device=bar", "example.com/device=bar"},
},
{
description: "invalid devices are returned as is",
prefixes: []string{"prefix/"},
annotations: map[string]string{
"prefix/foo": "example.com/device",
},
expectedDevices: []string{"example.com/device"},
},
}

for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
image, err := New(
WithAnnotationsPrefixes(tc.prefixes),
WithAnnotations(tc.annotations),
)
require.NoError(t, err)

devices := image.cdiDeviceRequestsFromAnnotations()
require.ElementsMatch(t, tc.expectedDevices, devices)
})
}
}

func makeTestMounts(paths ...string) []specs.Mount {
var mounts []specs.Mount
for _, path := range paths {
Expand Down
20 changes: 19 additions & 1 deletion internal/info/auto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
expectedMode: "legacy",
},
{
description: "cdi mount and non-CDI envvar resolves to legacy",
description: "cdi mount and non-CDI envvar resolves to cdi",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "0",
Expand All @@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) {
"tegra": false,
"nvgpu": false,
},
expectedMode: "cdi",
},
{
description: "non-cdi mount and CDI envvar resolves to legacy",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
},
mounts: []string{
"/var/run/nvidia-container-devices/0",
},
info: map[string]bool{
"nvml": true,
"tegra": false,
"nvgpu": false,
},
expectedMode: "legacy",
},
}
Expand Down Expand Up @@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) {
image, _ := image.New(
image.WithEnvMap(tc.envmap),
image.WithMounts(mounts),
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
)
mode := resolveMode(logger, tc.mode, image, properties)
require.EqualValues(t, tc.expectedMode, mode)
Expand Down
Loading