@@ -85,11 +85,21 @@ static Try<set<Gpu>> enumerateGpus(
85
85
if (flags.nvidia_gpu_devices .isSome ()) {
86
86
indices = flags.nvidia_gpu_devices .get ();
87
87
} else {
88
- for (size_t i = 0 ; i < resources.gpus ().getOrElse (0 ); ++i) {
88
+ Try<unsigned int > available = nvml::deviceGetCount ();
89
+ if (available.isError ()) {
90
+ return Error (" Failed to nvml::deviceGetCount: " + available.error ());
91
+ }
92
+
93
+ for (unsigned int i = 0 ; i < available.get (); ++i) {
89
94
indices.push_back (i);
90
95
}
91
96
}
92
97
98
+ Try<unsigned int > caps_major = nvml::systemGetCapsMajor ();
99
+ if (caps_major.isError ()) {
100
+ return Error (" Failed to get nvidia caps major: " + caps_major.error ());
101
+ }
102
+
93
103
set<Gpu> gpus;
94
104
95
105
foreach (unsigned int index, indices) {
@@ -103,17 +113,91 @@ static Try<set<Gpu>> enumerateGpus(
103
113
return Error (" Failed to nvml::deviceGetMinorNumber: " + minor.error ());
104
114
}
105
115
106
- Gpu gpu;
107
- gpu.major = NVIDIA_MAJOR_DEVICE;
108
- gpu.minor = minor.get ();
116
+ Try<bool > ismig = nvml::deviceGetMigMode (handle.get ());
117
+ if (ismig.isError ()) {
118
+ return Error (" Failed to nvml::deviceGetMigMode: " + ismig.error ());
119
+ }
120
+
121
+ if (!ismig.get ()) {
122
+ Gpu gpu;
123
+ gpu.major = NVIDIA_MAJOR_DEVICE;
124
+ gpu.minor = minor.get ();
125
+
126
+ gpus.insert (gpu);
109
127
110
- gpus.insert (gpu);
128
+ continue ;
129
+ }
130
+
131
+ Try<unsigned int > migcount = nvml::deviceGetMigDeviceCount (handle.get ());
132
+ if (migcount.isError ()) {
133
+ return Error (" Failed to nvml::deviceGetMigDeviceCount: " + migcount.error ());
134
+ }
135
+
136
+ for (unsigned int migindex = 0 ; migindex < migcount.get (); migindex++) {
137
+ Try<nvmlDevice_t> mighandle = nvml::deviceGetMigDeviceHandleByIndex (handle.get (), migindex);
138
+ if (mighandle.isError ()) {
139
+ return Error (" Failed to nvml::deviceGetMigDeviceHandleByIndex: " + mighandle.error ());
140
+ }
141
+
142
+ Try<unsigned int > gi_minor = nvml::deviceGetGpuInstanceMinor (mighandle.get ());
143
+ if (gi_minor.isError ()) {
144
+ return Error (" Failed to nvml::deviceGetGpuInstanceMinor: " + gi_minor.error ());
145
+ }
146
+
147
+ Try<unsigned int > ci_minor = nvml::deviceGetComputeInstanceMinor (mighandle.get ());
148
+ if (ci_minor.isError ()) {
149
+ return Error (" Failed to nvml::deviceGetComputeInstanceMinor: " + ci_minor.error ());
150
+ }
151
+
152
+ Gpu gpu;
153
+ gpu.major = NVIDIA_MAJOR_DEVICE;
154
+ gpu.minor = minor.get ();
155
+ gpu.ismig = true ;
156
+ gpu.caps_major = caps_major.get ();
157
+ gpu.gi_minor = gi_minor.get ();
158
+ gpu.ci_minor = ci_minor.get ();
159
+
160
+ gpus.insert (gpu);
161
+ }
111
162
}
112
163
113
164
return gpus;
114
165
}
115
166
116
167
168
+ static Try<unsigned int > countGpuInstancesForDevices (
169
+ const vector<unsigned int >& devices)
170
+ {
171
+ unsigned int count = 0 ;
172
+
173
+ foreach (unsigned int device, devices) {
174
+ Try<nvmlDevice_t> handle = nvml::deviceGetHandleByIndex (device);
175
+ if (handle.isError ()) {
176
+ return Error (" Failed to nvml::deviceGetHandleByIndex: " + handle.error ());
177
+ }
178
+
179
+ Try<bool > ismig = nvml::deviceGetMigMode (handle.get ());
180
+ if (ismig.isError ()) {
181
+ return Error (" Failed to nvml::deviceGetMigMode: " + ismig.error ());
182
+ }
183
+
184
+ if (!ismig.get ()) {
185
+ count++;
186
+ continue ;
187
+ }
188
+
189
+ Try<unsigned int > migcount = nvml::deviceGetMigDeviceCount (handle.get ());
190
+ if (migcount.isError ()) {
191
+ return Error (" Failed to nvml::deviceGetMigDeviceCount: " + migcount.error ());
192
+ }
193
+
194
+ count += migcount.get ();
195
+ }
196
+
197
+ return count;
198
+ }
199
+
200
+
117
201
// To determine the proper number of GPU resources to return, we
118
202
// need to check both --resources and --nvidia_gpu_devices.
119
203
// There are two cases to consider:
@@ -174,11 +258,6 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
174
258
return Error (" Failed to nvml::initialize: " + initialized.error ());
175
259
}
176
260
177
- Try<unsigned int > available = nvml::deviceGetCount ();
178
- if (available.isError ()) {
179
- return Error (" Failed to nvml::deviceGetCount: " + available.error ());
180
- }
181
-
182
261
// The `Resources` wrapper does not allow us to distinguish between
183
262
// a user specifying "gpus:0" in the --resources flag and not
184
263
// specifying "gpus" at all. To help with this we short circuit
@@ -225,9 +304,11 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
225
304
return Error (" '--nvidia_gpu_devices' contains duplicates" );
226
305
}
227
306
228
- if (flags.nvidia_gpu_devices ->size () != resources.gpus ().get ()) {
229
- return Error (" '--resources' and '--nvidia_gpu_devices' specify"
230
- " different numbers of GPU devices" );
307
+ Try<unsigned int > available = countGpuInstancesForDevices (unique);
308
+ if (available.isError ()) {
309
+ return Error (" Failed to count all GPU instances for devices"
310
+ " specified by --nvidia_gpu_devices: "
311
+ + available.error ());
231
312
}
232
313
233
314
if (resources.gpus ().get () > available.get ()) {
@@ -238,6 +319,22 @@ static Try<Resources> enumerateGpuResources(const Flags& flags)
238
319
return resources;
239
320
}
240
321
322
+ Try<unsigned int > available = nvml::deviceGetCount ();
323
+ if (available.isError ()) {
324
+ return Error (" Failed to nvml::deviceGetCount: " + available.error ());
325
+ }
326
+
327
+ vector<unsigned int > indices;
328
+ for (unsigned int i = 0 ; i < available.get (); ++i) {
329
+ indices.push_back (i);
330
+ }
331
+
332
+ available = countGpuInstancesForDevices (indices);
333
+ if (available.isError ()) {
334
+ return Error (" Failed to count all GPU instances: "
335
+ + available.error ());
336
+ }
337
+
241
338
return Resources::parse (
242
339
" gpus" ,
243
340
stringify (available.get ()),
@@ -378,7 +475,15 @@ Future<Nothing> NvidiaGpuAllocator::deallocate(const set<Gpu>& gpus)
378
475
bool operator <(const Gpu& left, const Gpu& right)
379
476
{
380
477
if (left.major == right.major ) {
381
- return left.minor < right.minor ;
478
+ // Either or both aren't MIG, comparing major/minor is enough
479
+ if (!left.ismig || !right.ismig || (left.minor != right.minor )) {
480
+ return left.minor < right.minor ;
481
+ }
482
+
483
+ if (left.gi_minor == right.gi_minor ) {
484
+ return left.ci_minor < right.ci_minor ;
485
+ }
486
+ return left.gi_minor < right.gi_minor ;
382
487
}
383
488
return left.major < right.major ;
384
489
}
@@ -404,7 +509,14 @@ bool operator>=(const Gpu& left, const Gpu& right)
404
509
405
510
bool operator ==(const Gpu& left, const Gpu& right)
406
511
{
407
- return left.major == right.major && left.minor == right.minor ;
512
+ if (left.ismig != right.ismig )
513
+ return false ;
514
+
515
+ if (!left.ismig )
516
+ return left.major == right.major && left.minor == right.minor ;
517
+
518
+ return left.major == right.major && left.minor == right.minor
519
+ && left.gi_minor == right.gi_minor && left.ci_minor == right.ci_minor ;
408
520
}
409
521
410
522
@@ -416,7 +528,10 @@ bool operator!=(const Gpu& left, const Gpu& right)
416
528
417
529
// Writes a `Gpu` as "<major>.<minor>", followed by
// ":<gi_minor>.<ci_minor>" when the entry is a MIG instance.
//
// The separators are single-character literals. A literal with more
// than one character between the quotes has type `int` and would be
// streamed as a number instead of as punctuation.
ostream& operator<<(ostream& stream, const Gpu& gpu)
{
  stream << gpu.major << '.' << gpu.minor;

  if (gpu.ismig) {
    stream << ':' << gpu.gi_minor << '.' << gpu.ci_minor;
  }

  return stream;
}
421
536
422
537
} // namespace slave {
0 commit comments