diff --git a/internal/config/image/builder.go b/internal/config/image/builder.go index 7acd3ba0b..a2e29e776 100644 --- a/internal/config/image/builder.go +++ b/internal/config/image/builder.go @@ -50,7 +50,6 @@ func New(opt ...Option) (CUDA, error) { if b.logger == nil { b.logger = logger.New() } - if b.env == nil { b.env = make(map[string]string) } @@ -81,6 +80,20 @@ func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option { } } +func WithAnnotations(annotations map[string]string) Option { + return func(b *builder) error { + b.annotations = annotations + return nil + } +} + +func WithAnnotationsPrefixes(annotationsPrefixes []string) Option { + return func(b *builder) error { + b.annotationsPrefixes = annotationsPrefixes + return nil + } +} + // WithDisableRequire sets the disable require option. func WithDisableRequire(disableRequire bool) Option { return func(b *builder) error { diff --git a/internal/config/image/cuda_image.go b/internal/config/image/cuda_image.go index b16e3fedf..c295d105d 100644 --- a/internal/config/image/cuda_image.go +++ b/internal/config/image/cuda_image.go @@ -42,10 +42,12 @@ const ( type CUDA struct { logger logger.Interface + annotations map[string]string env map[string]string isPrivileged bool mounts []specs.Mount + annotationsPrefixes []string acceptDeviceListAsVolumeMounts bool acceptEnvvarUnprivileged bool preferredVisibleDeviceEnvVars []string @@ -54,12 +56,17 @@ type CUDA struct { // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. // The process environment is read (if present) to construc the CUDA Image. func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) { + if spec == nil { + return New(opts...) 
+ } + var env []string - if spec != nil && spec.Process != nil { + if spec.Process != nil { env = spec.Process.Env } specOpts := []Option{ + WithAnnotations(spec.Annotations), WithEnv(env), WithMounts(spec.Mounts), WithPrivileged(IsPrivileged((*OCISpec)(spec))), @@ -95,6 +102,10 @@ func (i CUDA) IsLegacy() bool { return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 } +func (i CUDA) IsPrivileged() bool { + return i.isPrivileged +} + // GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment // variables. func (i CUDA) GetRequirements() ([]string, error) { @@ -212,19 +223,12 @@ func parseMajorMinorVersion(version string) (string, error) { // OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices/ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { var hasCDIdevice bool - for _, device := range i.VisibleDevicesFromEnvVar() { + for _, device := range i.VisibleDevices() { if !parser.IsQualifiedName(device) { return false } hasCDIdevice = true } - - for _, device := range i.DevicesFromMounts() { - if !strings.HasPrefix(device, "cdi/") { - return false - } - hasCDIdevice = true - } return hasCDIdevice } @@ -234,6 +238,12 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { // In cases where environment variable requests required privileged containers, // such devices requests are ignored. func (i CUDA) VisibleDevices() []string { + // If annotation device requests are present, these are preferred. + annotationDeviceRequests := i.cdiDeviceRequestsFromAnnotations() + if len(annotationDeviceRequests) > 0 { + return annotationDeviceRequests + } + // If enabled, try and get the device list from volume mounts first if i.acceptDeviceListAsVolumeMounts { volumeMountDeviceRequests := i.visibleDevicesFromMounts() @@ -260,6 +270,31 @@ func (i CUDA) VisibleDevices() []string { return nil } +// cdiDeviceRequestsFromAnnotations returns a list of devices specified in the +// annotations. 
+// Keys starting with the specified prefixes are considered and expected to +// contain a comma-separated list of fully-qualified CDI devices names. +// The format of the requested devices is not checked and the list is not +// deduplicated. +func (i CUDA) cdiDeviceRequestsFromAnnotations() []string { + if len(i.annotationsPrefixes) == 0 || len(i.annotations) == 0 { + return nil + } + + var devices []string + for key, value := range i.annotations { + for _, prefix := range i.annotationsPrefixes { + if strings.HasPrefix(key, prefix) { + devices = append(devices, strings.Split(value, ",")...) + // There is no need to check additional prefixes since we + // typically deduplicate devices in any case. + break + } + } + } + return devices +} + // VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables. // If any of the preferredVisibleDeviceEnvVars are present in the image, they // are used to determine the visible devices. If this is not the case, the @@ -276,20 +311,27 @@ func (i CUDA) VisibleDevicesFromEnvVar() []string { // visibleDevicesFromMounts returns the set of visible devices requested as mounts. func (i CUDA) visibleDevicesFromMounts() []string { var devices []string - for _, device := range i.DevicesFromMounts() { + for _, device := range i.requestsFromMounts() { switch { - case strings.HasPrefix(device, volumeMountDevicePrefixCDI): - continue case strings.HasPrefix(device, volumeMountDevicePrefixImex): continue + case strings.HasPrefix(device, volumeMountDevicePrefixCDI): + name, err := cdiDeviceMountRequest(device).qualifiedName() + if err != nil { + i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err) + continue + } + devices = append(devices, name) + default: + devices = append(devices, device) } - devices = append(devices, device) + } return devices } -// DevicesFromMounts returns a list of device specified as mounts. 
-func (i CUDA) DevicesFromMounts() []string { +// requestsFromMounts returns a list of device specified as mounts. +func (i CUDA) requestsFromMounts() []string { root := filepath.Clean(DeviceListAsVolumeMountsRoot) seen := make(map[string]bool) var devices []string @@ -321,23 +363,30 @@ func (i CUDA) DevicesFromMounts() []string { return devices } -// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. -func (i CUDA) CDIDevicesFromMounts() []string { - var devices []string - for _, mountDevice := range i.DevicesFromMounts() { - if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) { - continue - } - parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3) - if len(parts) != 3 { - continue - } - vendor := parts[0] - class := parts[1] - device := parts[2] - devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device)) +// cdiDeviceMountRequest represents a CDI device request as a mount. +// Here the host path /dev/null is mounted to a particular path in the container. +// The container path has the form: +// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device> +// or +// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device> +type cdiDeviceMountRequest string + +// qualifiedName returns the fully-qualified name of the CDI device. +func (m cdiDeviceMountRequest) qualifiedName() (string, error) { + if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) { + return "", fmt.Errorf("invalid mount CDI device request: %s", m) } - return devices + + requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI) + if parser.IsQualifiedName(requestedDevice) { + return requestedDevice, nil + } + + parts := strings.SplitN(requestedDevice, "/", 3) + if len(parts) != 3 { + return "", fmt.Errorf("invalid mount CDI device request: %s", m) + } + return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil } // ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. 
@@ -352,7 +401,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string { // ImexChannelsFromMounts returns the list of IMEX channels requested for the image. func (i CUDA) ImexChannelsFromMounts() []string { var channels []string - for _, mountDevice := range i.DevicesFromMounts() { + for _, mountDevice := range i.requestsFromMounts() { if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) { continue } diff --git a/internal/config/image/cuda_image_test.go b/internal/config/image/cuda_image_test.go index 044466d38..7302a6959 100644 --- a/internal/config/image/cuda_image_test.go +++ b/internal/config/image/cuda_image_test.go @@ -487,9 +487,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) { expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"}, }, { - description: "cdi devices are ignored", - mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"), - expectedDevices: []string{"GPU0", "GPU1"}, + description: "cdi devices are included", + mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"), + expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"}, }, { description: "imex devices are ignored", @@ -649,6 +649,73 @@ func TestImexChannelsFromEnvVar(t *testing.T) { } } +func TestCDIDeviceRequestsFromAnnotations(t *testing.T) { + testCases := []struct { + description string + prefixes []string + annotations map[string]string + expectedDevices []string + }{ + { + description: "no annotations", + }, + { + description: "no matching annotations", + prefixes: []string{"not-prefix/"}, + annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + }, + }, + { + description: "single matching annotation", + prefixes: []string{"prefix/"}, + annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + }, + expectedDevices: []string{"example.com/device=bar"}, + }, + { + description: "multiple matching annotations", + prefixes: []string{"prefix/", "another-prefix/"}, + annotations: map[string]string{ + 
"prefix/foo": "example.com/device=bar", + "another-prefix/bar": "example.com/device=baz", + }, + expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"}, + }, + { + description: "multiple matching annotations with duplicate devices", + prefixes: []string{"prefix/", "another-prefix/"}, + annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + "another-prefix/bar": "example.com/device=bar", + }, + expectedDevices: []string{"example.com/device=bar", "example.com/device=bar"}, + }, + { + description: "invalid devices are returned as is", + prefixes: []string{"prefix/"}, + annotations: map[string]string{ + "prefix/foo": "example.com/device", + }, + expectedDevices: []string{"example.com/device"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + image, err := New( + WithAnnotationsPrefixes(tc.prefixes), + WithAnnotations(tc.annotations), + ) + require.NoError(t, err) + + devices := image.cdiDeviceRequestsFromAnnotations() + require.ElementsMatch(t, tc.expectedDevices, devices) + }) + } +} + func makeTestMounts(paths ...string) []specs.Mount { var mounts []specs.Mount for _, path := range paths { diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index ad1475265..c2ab93d72 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "legacy", }, { - description: "cdi mount and non-CDI envvar resolves to legacy", + description: "cdi mount and non-CDI envvar resolves to cdi", mode: "auto", envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "0", @@ -197,6 +197,22 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, + expectedMode: "cdi", + }, + { + description: "non-cdi mount and CDI envvar resolves to legacy", + mode: "auto", + envmap: map[string]string{ + "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0", + }, + mounts: []string{ + 
"/var/run/nvidia-container-devices/0", + }, + info: map[string]bool{ + "nvml": true, + "tegra": false, + "nvgpu": false, + }, expectedMode: "legacy", }, } @@ -232,6 +248,8 @@ func TestResolveAutoMode(t *testing.T) { image, _ := image.New( image.WithEnvMap(tc.envmap), image.WithMounts(mounts), + image.WithAcceptDeviceListAsVolumeMounts(true), + image.WithAcceptEnvvarUnprivileged(true), ) mode := resolveMode(logger, tc.mode, image, properties) require.EqualValues(t, tc.expectedMode, mode) diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 6c7286af6..f8559d3f0 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -18,7 +18,6 @@ package modifier import ( "fmt" - "strings" "tags.cncf.io/container-device-interface/pkg/parser" @@ -34,11 +33,13 @@ import ( // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // used to select the devices to include. 
-func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - devices, err := getDevicesFromSpec(logger, ociSpec, cfg) - if err != nil { - return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err) - } +func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { + deviceRequestor := newCDIDeviceRequestor( + logger, + image, + cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, + ) + devices := deviceRequestor.DeviceRequests() if len(devices) == 0 { logger.Debugf("No devices requested; no modification required.") return nil, nil @@ -65,90 +66,38 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe ) } -func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } +type deviceRequestor interface { + DeviceRequests() []string +} - annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations) - if err != nil { - return nil, fmt.Errorf("failed to parse container annotations: %v", err) - } - if len(annotationDevices) > 0 { - return annotationDevices, nil - } +type cdiDeviceRequestor struct { + image image.CUDA + logger logger.Interface + defaultKind string +} - container, err := image.NewCUDAImageFromSpec( - rawSpec, - image.WithLogger(logger), - ) - if err != nil { - return nil, err - } - if cfg.AcceptDeviceListAsVolumeMounts { - mountDevices := container.CDIDevicesFromMounts() - if len(mountDevices) > 0 { - return mountDevices, nil - } +func newCDIDeviceRequestor(logger logger.Interface, image image.CUDA, defaultKind string) deviceRequestor { + c := &cdiDeviceRequestor{ + logger: logger, + image: image, + defaultKind: defaultKind, } + return withUniqueDevices(c) +} +func (c 
*cdiDeviceRequestor) DeviceRequests() []string { + if c == nil { + return nil + } var devices []string - seen := make(map[string]bool) - for _, name := range container.VisibleDevicesFromEnvVar() { + for _, name := range c.image.VisibleDevices() { if !parser.IsQualifiedName(name) { - name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) - } - if seen[name] { - logger.Debugf("Ignoring duplicate device %q", name) - continue + name = fmt.Sprintf("%s=%s", c.defaultKind, name) } devices = append(devices, name) } - if len(devices) == 0 { - return nil, nil - } - - if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) { - return devices, nil - } - - logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices) - - return nil, nil -} - -// getAnnotationDevices returns a list of devices specified in the annotations. -// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of -// fully-qualified CDI devices names. If any device name is not fully-quality an error is returned. -// The list of returned devices is deduplicated. 
-func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]string, error) { - devicesByKey := make(map[string][]string) - for key, value := range annotations { - for _, prefix := range prefixes { - if strings.HasPrefix(key, prefix) { - devicesByKey[key] = strings.Split(value, ",") - } - } - } - - seen := make(map[string]bool) - var annotationDevices []string - for key, devices := range devicesByKey { - for _, device := range devices { - if !parser.IsQualifiedName(device) { - return nil, fmt.Errorf("invalid device name %q in annotation %q", device, key) - } - if seen[device] { - continue - } - annotationDevices = append(annotationDevices, device) - seen[device] = true - } - } - - return annotationDevices, nil + return devices } // filterAutomaticDevices searches for "automatic" device names in the input slice. @@ -172,7 +121,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de if err != nil { return nil, fmt.Errorf("failed to generate CDI spec: %w", err) } - cdiModifier, err := cdi.New( + cdiDeviceRequestor, err := cdi.New( cdi.WithLogger(logger), cdi.WithSpec(spec.Raw()), ) @@ -180,7 +129,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de return nil, fmt.Errorf("failed to construct CDI modifier: %w", err) } - return cdiModifier, nil + return cdiDeviceRequestor, nil } func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) { @@ -218,3 +167,27 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic spec.WithClass("gpu"), ) } + +type deduplicatedDeviceRequestor struct { + deviceRequestor +} + +func withUniqueDevices(deviceRequestor deviceRequestor) deviceRequestor { + return &deduplicatedDeviceRequestor{deviceRequestor: deviceRequestor} +} + +func (d *deduplicatedDeviceRequestor) DeviceRequests() []string { + if d == nil { + return nil + } + seen := make(map[string]bool) + var devices []string 
+ for _, device := range d.deviceRequestor.DeviceRequests() { + if seen[device] { + continue + } + seen[device] = true + devices = append(devices, device) + } + return devices +} diff --git a/internal/modifier/cdi_test.go b/internal/modifier/cdi_test.go index 88ff697ab..881e8b2c2 100644 --- a/internal/modifier/cdi_test.go +++ b/internal/modifier/cdi_test.go @@ -17,76 +17,144 @@ package modifier import ( - "fmt" "testing" + "github.com/opencontainers/runtime-spec/specs-go" + testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) -func TestGetAnnotationDevices(t *testing.T) { +func TestDeviceRequests(t *testing.T) { + logger, _ := testlog.NewNullLogger() + testCases := []struct { description string + input cdiDeviceRequestor + spec *specs.Spec prefixes []string - annotations map[string]string expectedDevices []string - expectedError error }{ { - description: "no annotations", + description: "empty spec yields no devices", + }, + { + description: "cdi devices from mounts", + input: cdiDeviceRequestor{ + defaultKind: "nvidia.com/gpu", + }, + spec: &specs.Spec{ + Mounts: []specs.Mount{ + { + Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", + Source: "/dev/null", + }, + { + Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1", + Source: "/dev/null", + }, + }, + }, + expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"}, + }, + { + description: "cdi devices from envvar", + input: cdiDeviceRequestor{ + defaultKind: "nvidia.com/gpu", + }, + spec: &specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"}, + }, + }, + expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"}, }, { description: "no matching annotations", prefixes: []string{"not-prefix/"}, - annotations: map[string]string{ - "prefix/foo": "example.com/device=bar", + spec: &specs.Spec{ + 
Annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + }, }, }, { description: "single matching annotation", prefixes: []string{"prefix/"}, - annotations: map[string]string{ - "prefix/foo": "example.com/device=bar", + spec: &specs.Spec{ + Annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + }, }, expectedDevices: []string{"example.com/device=bar"}, }, { description: "multiple matching annotations", prefixes: []string{"prefix/", "another-prefix/"}, - annotations: map[string]string{ - "prefix/foo": "example.com/device=bar", - "another-prefix/bar": "example.com/device=baz", + spec: &specs.Spec{ + Annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + "another-prefix/bar": "example.com/device=baz", + }, }, expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"}, }, { description: "multiple matching annotations with duplicate devices", prefixes: []string{"prefix/", "another-prefix/"}, - annotations: map[string]string{ - "prefix/foo": "example.com/device=bar", - "another-prefix/bar": "example.com/device=bar", + spec: &specs.Spec{ + Annotations: map[string]string{ + "prefix/foo": "example.com/device=bar", + "another-prefix/bar": "example.com/device=bar", + }, }, - expectedDevices: []string{"example.com/device=bar"}, + expectedDevices: []string{"example.com/device=bar", "example.com/device=bar"}, }, { - description: "invalid devices", - prefixes: []string{"prefix/"}, - annotations: map[string]string{ - "prefix/foo": "example.com/device", + description: "devices in annotations are expanded", + input: cdiDeviceRequestor{ + defaultKind: "nvidia.com/gpu", + }, + prefixes: []string{"prefix/"}, + spec: &specs.Spec{ + Annotations: map[string]string{ + "prefix/foo": "device", + }, + }, + expectedDevices: []string{"nvidia.com/gpu=device"}, + }, + { + description: "invalid devices in annotations are treated as strings", + input: cdiDeviceRequestor{ + defaultKind: "nvidia.com/gpu", + }, + prefixes: 
[]string{"prefix/"}, + spec: &specs.Spec{ + Annotations: map[string]string{ + "prefix/foo": "example.com/device", + }, }, - expectedError: fmt.Errorf("invalid device %q", "example.com/device"), + expectedDevices: []string{"nvidia.com/gpu=example.com/device"}, }, } for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - devices, err := getAnnotationDevices(tc.prefixes, tc.annotations) - if tc.expectedError != nil { - require.Error(t, err) - return - } + tc.input.logger = logger + image, err := image.NewCUDAImageFromSpec( + tc.spec, + image.WithAcceptDeviceListAsVolumeMounts(true), + image.WithAcceptEnvvarUnprivileged(true), + image.WithAnnotationsPrefixes(tc.prefixes), + ) + require.NoError(t, err) + tc.input.image = image + + t.Run(tc.description, func(t *testing.T) { + devices := tc.input.DeviceRequests() require.NoError(t, err) - require.ElementsMatch(t, tc.expectedDevices, devices) + require.EqualValues(t, tc.expectedDevices, devices) }) } } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index a7b454a94..a4e992c19 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -65,29 +65,17 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv // newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config. 
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { - rawSpec, err := ociSpec.Load() - if err != nil { - return nil, fmt.Errorf("failed to load OCI spec: %v", err) - } - - image, err := image.NewCUDAImageFromSpec( - rawSpec, - image.WithLogger(logger), - ) + mode, image, err := initRuntimeModeAndImage(logger, cfg, ociSpec) if err != nil { return nil, err } - hookCreator := discover.NewHookCreator(discover.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path)) - - mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) - // We update the mode here so that we can continue passing just the config to other functions. - cfg.NVIDIAContainerRuntimeConfig.Mode = mode - modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image) + modeModifier, err := newModeModifier(logger, mode, cfg, *image) if err != nil { return nil, err } + hookCreator := discover.NewHookCreator(discover.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path)) var modifiers modifier.List for _, modifierType := range supportedModifierTypes(mode) { switch modifierType { @@ -96,13 +84,13 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp case "nvidia-hook-remover": modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) case "graphics": - graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver, hookCreator) + graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, *image, driver, hookCreator) if err != nil { return nil, err } modifiers = append(modifiers, graphicsModifier) case "feature-gated": - featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver, hookCreator) + featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, *image, driver, hookCreator) if err != nil { return nil, err } @@ -113,19 +101,58 @@ func newSpecModifier(logger logger.Interface, cfg 
*config.Config, ociSpec oci.Sp return modifiers, nil } -func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) { +func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { switch mode { case "legacy": return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case "csv": return modifier.NewCSVModifier(logger, cfg, image) case "cdi": - return modifier.NewCDIModifier(logger, cfg, ociSpec) + return modifier.NewCDIModifier(logger, cfg, image) } return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) } +// initRuntimeModeAndImage constructs an image from the specified OCI runtime +// specification and runtime config. +// The image is also used to determine the runtime mode to apply. +// If a non-CDI mode is detected we ensure that the image does not process +// annotation devices. +func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (string, *image.CUDA, error) { + rawSpec, err := ociSpec.Load() + if err != nil { + return "", nil, fmt.Errorf("failed to load OCI spec: %v", err) + } + + image, err := image.NewCUDAImageFromSpec( + rawSpec, + image.WithLogger(logger), + image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts), + image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged), + image.WithAnnotationsPrefixes(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes), + ) + if err != nil { + return "", nil, err + } + + mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) + // We update the mode here so that we can continue passing just the config to other functions. 
+ cfg.NVIDIAContainerRuntimeConfig.Mode = mode + + if mode == "cdi" || len(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes) == 0 { + return mode, &image, nil + } + + // For non-cdi modes we explicitly set the annotation prefixes to nil and + // call this function again to force a reconstruction of the image. + // Note that since the mode is now explicitly set, we will effectively skip + // the mode resolution. + cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes = nil + + return initRuntimeModeAndImage(logger, cfg, ociSpec) +} + // supportedModifierTypes returns the modifiers supported for a specific runtime mode. func supportedModifierTypes(mode string) []string { switch mode {