Skip to content

Use just-in-time CDI spec generation by default in the NVIDIA Container Runtime #910

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions cmd/nvidia-container-runtime-hook/container_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
}
}

func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
hookConfig.Lock()
defer hookConfig.Unlock()

if hookConfig.containerConfig != nil {
return hookConfig.containerConfig
}

var h HookState
d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil {
Expand Down Expand Up @@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
log.Panicln(err)
}

return containerConfig{
cc := containerConfig{
Pid: h.Pid,
Rootfs: s.Root.Path,
Image: i,
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
}
hookConfig.containerConfig = &cc

return hookConfig.containerConfig
}
2 changes: 1 addition & 1 deletion cmd/nvidia-container-runtime-hook/container_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ func TestGetNvidiaConfig(t *testing.T) {
hookCfg := tc.hookConfig
if hookCfg == nil {
defaultConfig, _ := config.GetDefault()
hookCfg = &hookConfig{defaultConfig}
hookCfg = &hookConfig{Config: defaultConfig}
}
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
}
Expand Down
28 changes: 25 additions & 3 deletions cmd/nvidia-container-runtime-hook/hook_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"path"
"reflect"
"strings"
"sync"

"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
)

const (
Expand All @@ -20,7 +22,9 @@ const (
// hookConfig wraps the toolkit config.
// This allows for functions to be defined on the local type.
type hookConfig struct {
sync.Mutex
*config.Config
containerConfig *containerConfig
}

// loadConfig loads the required paths for the hook config.
Expand Down Expand Up @@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) {
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
config := &hookConfig{cfg}
config := &hookConfig{Config: cfg}

allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
Expand All @@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) {

// getConfigOption returns the toml config option associated with the
// specified struct field.
func (c hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c)
func (c *hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(&c)
f, ok := t.FieldByName(fieldName)
if !ok {
return fieldName
Expand Down Expand Up @@ -127,3 +131,21 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string {
}
return []string{flag}
}

func (c *hookConfig) assertModeIsLegacy() error {
if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection {
return nil
}

mr := info.NewRuntimeModeResolver(
info.WithLogger(&logInterceptor{}),
info.WithImage(&c.containerConfig.Image),
info.WithDefaultMode(info.LegacyRuntimeMode),
)

mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode)
if mode == "legacy" {
return nil
}
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
}
4 changes: 2 additions & 2 deletions cmd/nvidia-container-runtime-hook/hook_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
}
}

var cfg hookConfig
var cfg *hookConfig
getHookConfig := func() {
c, _ := getHookConfig()
cfg = *c
cfg = c
}

if tc.expectedPanic {
Expand Down
6 changes: 3 additions & 3 deletions cmd/nvidia-container-runtime-hook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string {
}

// getRootfsPath returns an absolute path. We don't need to resolve symlinks for now.
func getRootfsPath(config containerConfig) string {
func getRootfsPath(config *containerConfig) string {
rootfs, err := filepath.Abs(config.Rootfs)
if err != nil {
log.Panicln(err)
Expand All @@ -82,8 +82,8 @@ func doPrestart() {
return
}

if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
if err := hook.assertModeIsLegacy(); err != nil {
log.Panicf("%v", err)
}

rootfs := getRootfsPath(container)
Expand Down
25 changes: 4 additions & 21 deletions cmd/nvidia-container-runtime/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,10 @@ func TestGoodInput(t *testing.T) {
err = cmdCreate.Run()
require.NoError(t, err, "runtime should not return an error")

// Check config.json for NVIDIA prestart hook
// Check config.json to ensure that the NVIDIA prestart was not inserted.
spec, err = cfg.getRuntimeSpec()
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
}

// NVIDIA prestart hook already present in config file
Expand Down Expand Up @@ -168,11 +167,10 @@ func TestDuplicateHook(t *testing.T) {
output, err := cmdCreate.CombinedOutput()
require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output))

// Check config.json for NVIDIA prestart hook
// Check config.json to ensure that the NVIDIA prestart hook was removed.
spec, err = cfg.getRuntimeSpec()
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
}

// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
Expand Down Expand Up @@ -240,18 +238,3 @@ func (c testConfig) generateNewRuntimeSpec() error {
}
return nil
}

// Return number of valid NVIDIA prestart hooks in runtime spec
func nvidiaHookCount(hooks *specs.Hooks) int {
if hooks == nil {
return 0
}

count := 0
for _, hook := range hooks.Prestart {
if strings.Contains(hook.Path, nvidiaHook) {
count++
}
}
return count
}
106 changes: 93 additions & 13 deletions internal/info/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,114 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// A RuntimeMode is used to select a specific mode of operation for the NVIDIA Container Runtime.
type RuntimeMode string

const (
// In LegacyRuntimeMode the nvidia-container-runtime injects the
// nvidia-container-runtime-hook as a prestart hook into the incoming
// container config. This hook invokes the nvidia-container-cli to perform
// the required modifications to the container.
LegacyRuntimeMode = RuntimeMode("legacy")
// In CSVRuntimeMode the nvidia-container-runtime processes a set of CSV
// files to determine which container modification are required. The
// contents of these CSV files are used to generate an in-memory CDI
// specification which is used to modify the container config.
CSVRuntimeMode = RuntimeMode("csv")
// In CDIRuntimeMode the nvidia-container-runtime applies the modifications
// to the container config required for the requested CDI devices in the
// same way that other CDI clients would.
CDIRuntimeMode = RuntimeMode("cdi")
// In JitCDIRuntimeMode the nvidia-container-runtime generates in-memory CDI
// specifications for requested NVIDIA devices.
JitCDIRuntimeMode = RuntimeMode("jit-cdi")
)

type RuntimeModeResolver interface {
ResolveRuntimeMode(string) RuntimeMode
}

type modeResolver struct {
logger logger.Interface
// TODO: This only needs to consider the requested devices.
image *image.CUDA
propertyExtractor info.PropertyExtractor
defaultMode RuntimeMode
}

type Option func(*modeResolver)

func WithDefaultMode(defaultMode RuntimeMode) Option {
return func(mr *modeResolver) {
mr.defaultMode = defaultMode
}
}

func WithLogger(logger logger.Interface) Option {
return func(mr *modeResolver) {
mr.logger = logger
}
}

func WithImage(image *image.CUDA) Option {
return func(mr *modeResolver) {
mr.image = image
}
}

func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option {
return func(mr *modeResolver) {
mr.propertyExtractor = propertyExtractor
}
}

func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver {
r := &modeResolver{
defaultMode: JitCDIRuntimeMode,
}
for _, opt := range opts {
opt(r)
}
if r.logger == nil {
r.logger = &logger.NullLogger{}
}

return r
}

// ResolveAutoMode determines the correct mode for the platform if set to "auto"
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
return resolveMode(logger, mode, image, nil)
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode RuntimeMode) {
r := modeResolver{
logger: logger,
image: &image,
propertyExtractor: nil,
}
return r.ResolveRuntimeMode(mode)
}

func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) {
func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) {
if mode != "auto" {
logger.Infof("Using requested mode '%s'", mode)
return mode
m.logger.Infof("Using requested mode '%s'", mode)
return RuntimeMode(mode)
}
defer func() {
logger.Infof("Auto-detected mode as '%v'", rmode)
m.logger.Infof("Auto-detected mode as '%v'", rmode)
}()

if image.OnlyFullyQualifiedCDIDevices() {
return "cdi"
if m.image.OnlyFullyQualifiedCDIDevices() {
return CDIRuntimeMode
}

nvinfo := info.New(
info.WithLogger(logger),
info.WithPropertyExtractor(propertyExtractor),
info.WithLogger(m.logger),
info.WithPropertyExtractor(m.propertyExtractor),
)

switch nvinfo.ResolvePlatform() {
case info.PlatformNVML, info.PlatformWSL:
return "legacy"
return m.defaultMode
case info.PlatformTegra:
return "csv"
return CSVRuntimeMode
}
return "legacy"
return m.defaultMode
}
Loading