# cutset.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
from functools import partial
from itertools import repeat
from pathlib import Path
from typing import KeysView, Mapping, Sequence, Tuple, Union
import omegaconf
from lhotse import CutSet, Features, Recording
from lhotse.array import Array, TemporalArray
from lhotse.cut import Cut, MixedCut, PaddingCut
from omegaconf import DictConfig, ListConfig, OmegaConf
from nemo.collections.common.data.lhotse.nemo_adapters import (
LazyNeMoIterator,
LazyNeMoTarredIterator,
expand_sharded_filepaths,
)
from nemo.collections.common.data.lhotse.sampling import PlaceholderFilter
from nemo.collections.common.data.lhotse.text_adapters import (
LhotseTextAdapter,
LhotseTextPairAdapter,
NeMoMultimodalConversationJsonlAdapter,
NeMoSFTJsonlAdapter,
)
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]:
"""
Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests.
Returns a tuple of ``CutSet`` and a boolean indicating whether the data is tarred (True) or not (False).
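
    Example (an illustrative sketch; the path is hypothetical, and in practice the config is
    usually pre-populated with defaults from ``LhotseDataLoadingConfig``)::

        >>> config = DictConfig({"manifest_filepath": "/data/train_manifest.json"})  # plus schema defaults
        ... cuts, is_tarred = read_cutset_from_config(config)  # is_tarred is False for non-tarred data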
"""
# First, check if the dataset is specified in the new configuration format and use it if possible.
if not isinstance(config, DictConfig):
config = DictConfig(config)
if config.get("input_cfg") is not None:
cuts, is_tarred = read_dataset_config(config)
else:
# Now, we'll figure out if we should read Lhotse manifest or NeMo manifest.
use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path"))
if use_nemo_manifest:
if config.get("manifest_filepath") is None:
raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path")
cuts, is_tarred = read_nemo_manifest(config)
else:
cuts, is_tarred = read_lhotse_manifest(config)
    # After reading cuts, we filter the cutset to exclude cuts with a truthy "_skipme" value.
    # The same filtering is applied before mixing cutsets; here it handles the non-mixed case.
cuts = cuts.filter(PlaceholderFilter())
return cuts, is_tarred
class IncompleteConfigError(RuntimeError):
    """Raised when the dataloading configuration is missing required options."""
KNOWN_DATA_CONFIG_TYPES = {}
def get_known_config_data_types() -> KeysView[str]:
"""
Return the names of all registered data type parsers.
Example:
>>> get_known_config_data_types()
["nemo", "nemo_tarred", "lhotse", ...]
"""
return KNOWN_DATA_CONFIG_TYPES.keys()
def get_parser_fn(data_type_name: str):
"""
Return the parsing function for a given data type name.
    The parsing function reads a dataloading config and returns a tuple
    of a lhotse ``CutSet`` and a boolean indicating whether we should use
    the iterable dataset (True) or map dataset (False) mechanism ("is tarred").
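
    Example (illustrative; ``config`` stands for a dataloading config of the matching type)::

        >>> parser = get_parser_fn("nemo_tarred")
        ... cuts, is_tarred = parser(config)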
"""
return KNOWN_DATA_CONFIG_TYPES[data_type_name]
def data_type_parser(name: Union[str, list[str]]):
"""
Decorator used to register data type parser functions.
    The parsing function reads a dataloading config and returns a tuple
    of a lhotse ``CutSet`` and a boolean indicating whether we should use
    the iterable dataset (True) or map dataset (False) mechanism ("is tarred").
Example:
>>> @data_type_parser("my_new_format")
... def my_new_format(config):
... return CutSet(read_my_format(**config)), True
...
... fn = get_parser_fn("my_new_format")
... cuts, is_tarred = fn({"my_arg_0": ..., "my_arg_1": ..., ...})
"""
def _decorator(fn):
global KNOWN_DATA_CONFIG_TYPES
if isinstance(name, str):
KNOWN_DATA_CONFIG_TYPES[name] = fn
else:
for n in name:
KNOWN_DATA_CONFIG_TYPES[n] = fn
return fn
return _decorator
def read_dataset_config(config) -> tuple[CutSet, bool]:
"""
Input configuration format examples.
Example 1. Combine two datasets with equal weights and attach custom metadata in ``tags`` to each cut::
input_cfg:
- type: nemo_tarred
manifest_filepath: /path/to/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar
weight: 0.5
tags:
lang: en
some_metadata: some_value
- type: nemo_tarred
manifest_filepath: /path/to/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar
weight: 0.5
tags:
lang: pl
some_metadata: some_value
    Example 2. Combine four datasets, grouped into two tasks (ASR and AST).
    There are two levels of weights: per task (outer) and per dataset (inner).
    The final weight is the product of the outer and inner weights::
input_cfg:
- type: group
weight: 0.7
tags:
task: asr
input_cfg:
- type: nemo_tarred
manifest_filepath: /path/to/asr1/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/tarred_audio/asr1/audio__OP_0..512_CL_.tar
weight: 0.6
tags:
lang: en
some_metadata: some_value
- type: nemo_tarred
manifest_filepath: /path/to/asr2/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/asr2/tarred_audio/audio__OP_0..512_CL_.tar
weight: 0.4
tags:
lang: pl
some_metadata: some_value
- type: group
weight: 0.3
tags:
task: ast
input_cfg:
- type: nemo_tarred
manifest_filepath: /path/to/ast1/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/ast1/tarred_audio/audio__OP_0..512_CL_.tar
weight: 0.2
tags:
src_lang: en
tgt_lang: pl
- type: nemo_tarred
manifest_filepath: /path/to/ast2/manifest__OP_0..512_CL_.json
tarred_audio_filepath: /path/to/ast2/tarred_audio/audio__OP_0..512_CL_.tar
weight: 0.8
tags:
src_lang: pl
tgt_lang: en
"""
propagate_attrs = {
"shuffle": config.get("shuffle", False),
"shard_seed": config.get("shard_seed", "trng"),
"text_field": config.get("text_field", "text"),
"lang_field": config.get("lang_field", "lang"),
"metadata_only": config.get("metadata_only", False),
"force_finite": config.get("force_finite", False),
"max_open_streams": config.get("max_open_streams", None),
"token_equivalent_duration": config.get("token_equivalent_duration", None),
"skip_missing_manifest_entries": config.get("skip_missing_manifest_entries", False),
"force_map_dataset": config.get("force_map_dataset", False),
"force_iterable_dataset": config.get("force_iterable_dataset", False),
}
input_cfg = config.input_cfg
if isinstance(input_cfg, (str, Path)):
# Resolve /path/to/input_cfg.yaml into config contents if needed.
input_cfg = OmegaConf.load(input_cfg)
cuts, is_tarred = parse_and_combine_datasets(input_cfg, propagate_attrs=propagate_attrs)
return cuts, is_tarred
def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> tuple[CutSet, bool]:
"""Parse a group configuration, potentially combining multiple datasets."""
assert grp_cfg.type in get_known_config_data_types(), f"Unknown item type in dataset config list: {grp_cfg.type=}"
    # Note: Text data types will return is_tarred=True.
    # We choose to treat text as if it were tarred, which tends to be more
    # efficient as it moves the text file iteration into the dataloading subprocess.
if grp_cfg.type != "group":
parser_fn = get_parser_fn(grp_cfg.type)
cuts, is_tarred = parser_fn(grp_cfg)
else:
cuts, is_tarred = parse_and_combine_datasets(
grp_cfg.input_cfg,
propagate_attrs=propagate_attrs,
)
# Attach extra tags to every utterance dynamically, if provided.
if (extra_tags := grp_cfg.get("tags")) is not None:
cuts = cuts.map(partial(attach_tags, tags=extra_tags), apply_fn=None)
return cuts, is_tarred
@data_type_parser("txt")
def read_txt_paths(config: DictConfig) -> tuple[CutSet, bool]:
"""Read paths to text files and create a CutSet."""
cuts = CutSet(
LhotseTextAdapter(
paths=config.paths,
language=config.language,
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
)
)
if not config.get("force_finite", False):
cuts = cuts.repeat()
return cuts, True
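# Illustrative "txt" entry for input_cfg (an assumed sketch; the paths and language tag
# are hypothetical):
#
#   - type: txt
#     paths: /data/text/en__OP_0..127_CL_.txt
#     language: en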
@data_type_parser("txt_pair")
def read_txt_pair_paths(config: DictConfig) -> tuple[CutSet, bool]:
"""Read paths to source and target text files and create a CutSet."""
cuts = CutSet(
LhotseTextPairAdapter(
source_paths=config.source_paths,
target_paths=config.target_paths,
source_language=config.get("source_language"),
target_language=config.get("target_language"),
questions_path=config.get("questions_path"),
questions_language=config.get("questions_language"),
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
)
)
if not config.get("force_finite", False):
cuts = cuts.repeat()
return cuts, True
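# Illustrative "txt_pair" entry for input_cfg (an assumed sketch; the paths and language tags
# are hypothetical):
#
#   - type: txt_pair
#     source_paths: /data/mt/en__OP_0..127_CL_.txt
#     target_paths: /data/mt/pl__OP_0..127_CL_.txt
#     source_language: en
#     target_language: pl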
@data_type_parser("nemo_sft_jsonl")
def read_nemo_sft_jsonl(config: DictConfig) -> tuple[CutSet, bool]:
"""Read paths to Nemo SFT JSONL files and create a CutSet."""
cuts = CutSet(
NeMoSFTJsonlAdapter(
paths=config.paths,
language=config.get("language"),
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
)
)
if not config.get("force_finite", False):
cuts = cuts.repeat()
return cuts, True
@data_type_parser("multimodal_conversation")
def read_multimodal_conversation_jsonl(config: DictConfig) -> tuple[CutSet, bool]:
"""Read paths to multimodal conversation JSONL files and create a CutSet."""
cuts = CutSet(
NeMoMultimodalConversationJsonlAdapter(
manifest_filepath=config.manifest_filepath,
tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
audio_locator_tag=config.audio_locator_tag,
token_equivalent_duration=config.get("token_equivalent_duration"),
shuffle_shards=config.shuffle,
shard_seed=config.shard_seed,
)
)
if not config.get("force_finite", False):
cuts = cuts.repeat()
return cuts, True
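# Illustrative "multimodal_conversation" entry for input_cfg (an assumed sketch; the paths,
# locator tag, and token duration are hypothetical):
#
#   - type: multimodal_conversation
#     manifest_filepath: /data/convo/manifest__OP_0..127_CL_.jsonl
#     tarred_audio_filepaths: /data/convo/audio__OP_0..127_CL_.tar
#     audio_locator_tag: "[audio]"
#     token_equivalent_duration: 0.08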
def attach_tags(cut, tags: dict):
"""Attach extra tags to a cut dynamically."""
for key, val in tags.items():
setattr(cut, key, val)
return cut
@data_type_parser("group")
def parse_and_combine_datasets(
config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict
) -> tuple[CutSet, bool]:
"""Parse a list of dataset configurations, potentially combining multiple datasets."""
cuts = []
weights = []
tarred_status = []
assert len(config_list) > 0, "Empty group in dataset config list."
for item in config_list:
        # Check if we have any attributes that are propagated downwards to each item in the group.
        # If a key already exists in the item, it takes precedence (we will not overwrite);
        # otherwise we will assign it.
        # We also update propagate_attrs for the nested sub-groups based on what's present in this item.
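        # For example (hypothetical values): with propagate_attrs = {"shuffle": True, ...} and
        # an item {"type": "nemo_tarred", "shuffle": False, ...}, the item keeps shuffle=False,
        # and its nested sub-groups inherit shuffle=False as well.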
next_propagate_attrs = propagate_attrs.copy()
for k, v in propagate_attrs.items():
if k not in item:
item[k] = v
else:
next_propagate_attrs[k] = item[k]
# Load the item (which may also be another group) as a CutSet.
item_cuts, item_is_tarred = parse_group(item, next_propagate_attrs)
cuts.append(item_cuts)
tarred_status.append(item_is_tarred)
if (w := item.get("weight")) is not None:
weights.append(w)
all_same_tarred_status = all(t == tarred_status[0] for t in tarred_status)
if not all_same_tarred_status:
if propagate_attrs["force_map_dataset"] or propagate_attrs["force_iterable_dataset"]:
            logging.warning(
                f"Not all datasets in the group have the same tarred status. Using the provided "
                f"force_map_dataset ({propagate_attrs['force_map_dataset']}) and force_iterable_dataset "
                f"({propagate_attrs['force_iterable_dataset']}) to determine the final tarred status."
            )
else:
raise ValueError(
"Mixing tarred and non-tarred datasets is not supported when neither force_map_dataset nor force_iterable_dataset is True."
)
assert len(weights) == 0 or len(cuts) == len(
weights
), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."
if len(cuts) > 1:
        # Before mixing the datasets from the config, filter each cutset to exclude cuts
        # with a truthy "_skipme" value, so that the data is mixed correctly.
cuts = [cut.filter(PlaceholderFilter()) for cut in cuts]
cuts = mux(
*cuts,
weights=weights if weights else None,
max_open_streams=propagate_attrs["max_open_streams"],
seed=propagate_attrs["shard_seed"],
force_finite=propagate_attrs["force_finite"] or propagate_attrs["metadata_only"],
)
else:
(cuts,) = cuts
return cuts, tarred_status[0]
@data_type_parser(["lhotse", "lhotse_shar"])
def read_lhotse_manifest(config) -> tuple[CutSet, bool]:
"""Read paths to Lhotse manifest files and create a CutSet."""
is_tarred = config.get("shar_path") is not None
if is_tarred:
# Lhotse Shar is the equivalent of NeMo's native "tarred" dataset.
        # The combination of shuffle_shards and repeat causes this to
        # be an infinite manifest that is internally reshuffled on each epoch.
# The parameter ``config.shard_seed`` is used to determine shard shuffling order. Options:
# - "trng" means we'll defer setting the seed until the iteration
# is triggered, and we'll use system TRNG to get a completely random seed for each worker.
# This results in every dataloading worker using full data but in a completely different order.
# - "randomized" means we'll defer setting the seed until the iteration
# is triggered, and we'll use config.seed to get a pseudo-random seed for each worker.
# This results in every dataloading worker using full data but in a completely different order.
# Unlike "trng", this is deterministic, and if you resume training, you should change the seed
# to observe different data examples than in the previous run.
# - integer means we'll set a specific seed in every worker, and data would be duplicated across them.
# This is mostly useful for unit testing or debugging.
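        # Illustrative forms of config.shar_path (all paths hypothetical):
        #   shar_path: /data/shar_dir
        #   shar_path: [/data/shar_a, [/data/shar_b, 2.5]]
        #   shar_path: {"cuts": ["/data/shar/cuts.000000.jsonl.gz", ...],
        #               "recording": ["/data/shar/recording.000000.tar", ...]}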
shard_seed = config.shard_seed
metadata_only = config.metadata_only
force_finite = config.force_finite
if config.get("cuts_path") is not None:
warnings.warn("Note: lhotse.cuts_path will be ignored because lhotse.shar_path was provided.")
if isinstance(config.shar_path, (str, Path)):
logging.info(f"Initializing Lhotse Shar CutSet (tarred) from a single data source: '{config.shar_path}'")
cuts = CutSet.from_shar(
**_resolve_shar_inputs(config.shar_path, metadata_only), shuffle_shards=True, seed=shard_seed
)
if not metadata_only and not force_finite:
cuts = cuts.repeat()
elif isinstance(config.shar_path, Sequence):
# Multiple datasets in Lhotse Shar format: we will dynamically multiplex them
# with probability approximately proportional to their size
logging.info(
"Initializing Lhotse Shar CutSet (tarred) from multiple data sources with a weighted multiplexer. "
"We found the following sources and weights: "
)
cutsets = []
weights = []
for item in config.shar_path:
if isinstance(item, (str, Path)):
path = item
cs = CutSet.from_shar(
**_resolve_shar_inputs(path, metadata_only), shuffle_shards=True, seed=shard_seed
)
weight = len(cs)
else:
assert isinstance(item, Sequence) and len(item) == 2 and isinstance(item[1], (int, float)), (
"Supported inputs types for config.shar_path are: "
"str | list[str] | list[tuple[str, number]] "
"where str is a path and number is a mixing weight (it may exceed 1.0). "
f"We got: '{item}'"
)
path, weight = item
cs = CutSet.from_shar(
**_resolve_shar_inputs(path, metadata_only), shuffle_shards=True, seed=shard_seed
)
logging.info(f"- {path=} {weight=}")
cutsets.append(cs)
weights.append(weight)
cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
max_open_streams=config.max_open_streams,
seed=config.shard_seed,
force_finite=force_finite,
)
elif isinstance(config.shar_path, Mapping):
fields = {k: expand_sharded_filepaths(v) for k, v in config.shar_path.items()}
assert "cuts" in config.shar_path.keys(), (
f"Invalid value for key 'shar_path': a dict was provided, but didn't specify key 'cuts' pointing "
f"to the manifests. We got the following: {config.shar_path=}"
)
if metadata_only:
fields = {"cuts": fields["cuts"]}
cuts = CutSet.from_shar(fields=fields, shuffle_shards=True, seed=shard_seed)
if not metadata_only and not force_finite:
cuts = cuts.repeat()
else:
raise RuntimeError(
f"Unexpected value for key 'shar_path'. We support string, list of strings, "
f"list of tuples[string,float], and dict[string,list[string]], "
f"but got: {type(config.shar_path)=} {config.shar_path=}"
)
else:
# Regular Lhotse manifest points to individual audio files (like native NeMo manifest).
path = config.cuts_path
cuts = CutSet.from_file(path).map(partial(resolve_relative_paths, manifest_path=path))
return cuts, is_tarred
def _resolve_shar_inputs(path: Union[str, Path], only_metadata: bool) -> dict:
    """Build the ``CutSet.from_shar`` kwargs; with metadata only, read just the cuts manifests."""
    if only_metadata:
return dict(fields={"cuts": sorted(Path(path).glob("cuts.*"))})
else:
return dict(in_dir=path)
def resolve_relative_paths(cut: Cut, manifest_path: str) -> Cut:
"""Resolve relative paths in a Cut object to their full paths."""
if isinstance(cut, PaddingCut):
return cut
if isinstance(cut, MixedCut):
for track in cut.tracks:
track.cut = resolve_relative_paths(track.cut, manifest_path)
return cut
def resolve_recording(value):
for audio_source in value.sources:
if audio_source.type == "file":
audio_source.source = get_full_path(audio_source.source, manifest_file=manifest_path)
def resolve_array(value):
        if isinstance(value, TemporalArray):
            # ``resolve_array`` mutates its argument in place and returns None,
            # so its result must not be assigned back to ``value.array``.
            resolve_array(value.array)
else:
if value.storage_type in ("numpy_files", "lilcom_files"):
abspath = Path(
get_full_path(str(Path(value.storage_path) / value.storage_key), manifest_file=manifest_path)
)
value.storage_path = str(abspath.parent)
value.storage_key = str(abspath.name)
elif value.storage_type in (
"kaldiio",
"chunked_lilcom_hdf5",
"lilcom_chunky",
"lilcom_hdf5",
"numpy_hdf5",
):
value.storage_path = get_full_path(value.storage_path, manifest_file=manifest_path)
# ignore others i.e. url, in-memory data, etc.
if cut.has_recording:
resolve_recording(cut.recording)
if cut.has_features:
resolve_array(cut.features)
if cut.custom is not None:
for key, value in cut.custom.items():
if isinstance(value, Recording):
resolve_recording(value)
elif isinstance(value, (Array, TemporalArray, Features)):
resolve_array(value)
return cut
@data_type_parser(["nemo", "nemo_tarred"])
def read_nemo_manifest(config) -> tuple[CutSet, bool]:
"""Read NeMo manifest and return a Lhotse CutSet."""
common_kwargs = {
"text_field": config.text_field,
"lang_field": config.lang_field,
"shuffle_shards": config.shuffle,
"shard_seed": config.shard_seed,
"extra_fields": config.get("extra_fields", None),
}
    # The option below enables a special case: iterating a NeMo manifest as a Lhotse CutSet
    # without performing any I/O. NeMo manifests typically lack the sampling_rate information
    # required by Lhotse, so Lhotse has to look up audio file headers to fill it in on the fly
    # (this only affects non-tarred data; tarred data is read into memory anyway).
    # This is useful for utility scripts that iterate metadata to estimate optimal batching settings
    # and other data statistics.
notar_kwargs = {"metadata_only": config.metadata_only}
metadata_only = config.metadata_only
force_finite = config.force_finite
is_tarred = config.get("tarred_audio_filepaths") is not None
if isinstance(config.manifest_filepath, (str, Path)):
        logging.info(
            f"Initializing Lhotse CutSet from a single NeMo manifest (is_tarred={is_tarred}): "
            f"'{config.manifest_filepath}'"
        )
if is_tarred and not metadata_only:
cuts = CutSet(
LazyNeMoTarredIterator(
config.manifest_filepath,
tar_paths=config.tarred_audio_filepaths,
skip_missing_manifest_entries=config.skip_missing_manifest_entries,
**common_kwargs,
)
)
if not force_finite:
cuts = cuts.repeat()
else:
cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs))
else:
# Format option 1:
# Assume it's [[path1], [path2], ...] (same for tarred_audio_filepaths).
# This is the format for multiple NeMo buckets.
# Note: we set "weights" here to be proportional to the number of utterances in each data source.
# this ensures that we distribute the data from each source uniformly throughout each epoch.
# Setting equal weights would exhaust the shorter data sources closer the towards the beginning
# of an epoch (or over-sample it in the case of infinite CutSet iteration with .repeat()).
# Format option 2:
# Assume it's [[path1, weight1], [path2, weight2], ...] (while tarred_audio_filepaths remain unchanged).
# Note: this option allows to manually set the weights for multiple datasets.
# Format option 3:
# i.e., NeMo concatenated dataset
# Assume it's [path1, path2, ...] (while tarred_audio_filepaths in the same format).
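        # Illustrative values of config.manifest_filepath (paths hypothetical):
        #   option 1: [[/data/bucket1.json], [/data/bucket2.json]]
        #   option 2: [[/data/bucket1.json, 0.3], [/data/bucket2.json, 0.7]]
        #   option 3: [/data/part1.json, /data/part2.json]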
        logging.info(
            f"Initializing Lhotse CutSet from multiple NeMo manifest sources (is_tarred={is_tarred}) "
            f"with a weighted multiplexer. We found the following sources and weights:"
        )
cutsets = []
weights = []
tar_paths = config.tarred_audio_filepaths if is_tarred else repeat((None,))
# Create a stream for each dataset.
for manifest_info, tar_path in zip(config.manifest_filepath, tar_paths):
if is_tarred and isinstance(tar_path, (list, tuple, ListConfig)):
# if it's in option 1 or 2
(tar_path,) = tar_path
manifest_path = manifest_info[0]
else:
# if it's in option 3
manifest_path = manifest_info
# First, convert manifest_path[+tar_path] to an iterator.
if is_tarred and not metadata_only:
nemo_iter = LazyNeMoTarredIterator(
manifest_path=manifest_path,
tar_paths=tar_path,
skip_missing_manifest_entries=config.skip_missing_manifest_entries,
**common_kwargs,
)
else:
nemo_iter = LazyNeMoIterator(manifest_path, **notar_kwargs, **common_kwargs)
# Then, determine the weight or use one provided
if isinstance(manifest_info, str) or len(manifest_info) == 1:
weight = len(nemo_iter)
else:
assert (
isinstance(manifest_info, Sequence)
and len(manifest_info) == 2
and isinstance(manifest_info[1], (int, float))
), (
"Supported inputs types for config.manifest_filepath are: "
"str | list[list[str]] | list[tuple[str, number]] "
"where str is a path and number is a mixing weight (it may exceed 1.0). "
f"We got: '{manifest_info}'"
)
weight = manifest_info[1]
logging.info(f"- {manifest_path=} {weight=}")
# [optional] When we have a limit on the number of open streams,
# split the manifest to individual shards if applicable.
# This helps the multiplexing achieve closer data distribution
# to the one desired in spite of the limit.
if config.max_open_streams is not None:
for subiter in nemo_iter.to_shards():
cutsets.append(CutSet(subiter))
weights.append(weight)
else:
cutsets.append(CutSet(nemo_iter))
weights.append(weight)
# Finally, we multiplex the dataset streams to mix the data.
        # Before that, we filter cutsets to exclude cuts with a truthy "_skipme" value so the data is mixed correctly.
cutsets = [cutset.filter(PlaceholderFilter()) for cutset in cutsets]
cuts = mux(
*cutsets,
weights=weights,
max_open_streams=config.max_open_streams,
seed=config.shard_seed,
force_finite=force_finite or metadata_only,
)
return cuts, is_tarred
def mux(
*cutsets: CutSet,
weights: list[Union[int, float]],
max_open_streams: Union[int, None] = None,
seed: Union[str, int] = "trng",
force_finite: bool = False,
) -> CutSet:
"""
Helper function to call the right multiplexing method flavour in lhotse.
    The result is an infinitely iterable ``CutSet`` unless ``force_finite`` is set.
    Depending on whether ``max_open_streams`` is set, a more appropriate multiplexing strategy is chosen.
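
    Example (an illustrative sketch; ``cuts_a`` and ``cuts_b`` stand for any two CutSets)::

        >>> mixed = mux(cuts_a, cuts_b, weights=[0.7, 0.3], seed=0)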
"""
if max_open_streams is not None:
assert not force_finite, "max_open_streams and metadata_only/force_finite options are not compatible"
cuts = CutSet.infinite_mux(*cutsets, weights=weights, seed=seed, max_open_streams=max_open_streams)
else:
if not force_finite:
cutsets = [cs.repeat() for cs in cutsets]
if len(cutsets) == 1:
# CutSet.mux must take more than one CutSet.
cuts = cutsets[0]
else:
cuts = CutSet.mux(*cutsets, weights=weights, seed=seed)
return cuts
def guess_parse_cutset(inp: Union[str, dict, omegaconf.DictConfig]) -> CutSet:
"""
Utility function that supports opening a CutSet from:
* a string path to YAML input spec (see :func:`read_dataset_config` for details)
* a string path to Lhotse non-tarred JSONL manifest
* a string path to NeMo non-tarred JSON manifest
* a dictionary specifying inputs with keys available in
:class:`nemo.collections.common.data.lhotse.dataloader.LhotseDataLoadingConfig`
It's intended to be used in a generic context where we are not sure which way the user will specify the inputs.
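
    Example (illustrative; all paths are hypothetical)::

        >>> cuts = guess_parse_cutset("/data/train_manifest.json")  # NeMo non-tarred manifest
        >>> cuts = guess_parse_cutset("/data/cuts.jsonl.gz")        # Lhotse non-tarred manifest
        >>> cuts = guess_parse_cutset("/data/input_cfg.yaml")       # YAML input spec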
"""
from nemo.collections.common.data.lhotse.dataloader import make_structured_with_schema_warnings
if isinstance(inp, (dict, omegaconf.DictConfig)):
try:
config = make_structured_with_schema_warnings(OmegaConf.from_dotlist([f"{k}={v}" for k, v in inp.items()]))
cuts, _ = read_cutset_from_config(config)
return cuts
except Exception as e:
raise RuntimeError(
f"Couldn't open CutSet based on dict input {inp} (is it compatible with LhotseDataLoadingConfig?)"
) from e
elif isinstance(inp, str):
if inp.endswith(".yaml"):
# Path to YAML file with the input configuration
config = make_structured_with_schema_warnings(OmegaConf.from_dotlist([f"input_cfg={inp}"]))
elif inp.endswith(".jsonl") or inp.endswith(".jsonl.gz"):
# Path to a Lhotse non-tarred manifest
config = make_structured_with_schema_warnings(OmegaConf.from_dotlist([f"cuts_path={inp}"]))
else:
# Assume anything else is a NeMo non-tarred manifest
config = make_structured_with_schema_warnings(OmegaConf.from_dotlist([f"manifest_filepath={inp}"]))
cuts, _ = read_cutset_from_config(config)
return cuts
else:
raise RuntimeError(f'Unsupported input type: {type(inp)} (expected a dict or a string)')