Skip to content

Fix skipme handling #13244

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bo
if not isinstance(config, DictConfig):
config = DictConfig(config)
if config.get("input_cfg") is not None:
return read_dataset_config(config)
# Now, we'll figure out if we should read Lhotse manifest or NeMo manifest.
use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path"))
if use_nemo_manifest:
if config.get("manifest_filepath") is None:
raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path")
cuts, is_tarred = read_nemo_manifest(config)
cuts, is_tarred = read_dataset_config(config)
else:
cuts, is_tarred = read_lhotse_manifest(config)
# Now, we'll figure out if we should read Lhotse manifest or NeMo manifest.
use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path"))
if use_nemo_manifest:
if config.get("manifest_filepath") is None:
raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path")
cuts, is_tarred = read_nemo_manifest(config)
else:
cuts, is_tarred = read_lhotse_manifest(config)

# After reading cuts we filter cutsets to exclude cuts with valid "_skipme" values.
# This filtration is done before mixing cutsets as well. Here it is being done for non-mixed cutsets.
Expand Down Expand Up @@ -351,7 +352,11 @@ def parse_and_combine_datasets(
assert len(weights) == 0 or len(cuts) == len(
weights
), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."

if len(cuts) > 1:
# Before we mix the datasets in the config, we filter cutsets to exclude cuts
# with valid "_skipme" values to mix the data correctly.
cuts = [cut.filter(PlaceholderFilter()) for cut in cuts]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this necessary? Don;t samples with _skipme already get filtered in read_cutset_from_config function with PlaceholderFilter there ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to filter before cuts will get mixed with the specified weights, because if we mix n samples vs n - n_skipme_samples, resulted cutset will be different

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in read_cutset_from_config we do the filtering for non-merged cutsets, but I think I can change the place of the filtering to do one for both cases

cuts = mux(
*cuts,
weights=weights if weights else None,
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/common/data/lhotse/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,4 +394,4 @@ def __call__(self, example) -> bool:
return True

custom = getattr(example, "custom", None)
return custom is None or not custom.pop("_skipme", False)
return custom is None or not custom.get("_skipme", False)
43 changes: 43 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2549,3 +2549,46 @@ def test_dataloader_from_tarred_nemo_manifest_with_skipme(nemo_tarred_manifest_w

assert len(batches) == 8
assert not any(skipme_s)


def test_dataloader_from_data_input_cfg_yaml_path_with_skipme(cutset_shar_path, nemo_tarred_manifest_with_skipme_path):
config = OmegaConf.create(
{
"input_cfg": [
{
"type": "nemo_tarred",
"manifest_filepath": nemo_tarred_manifest_with_skipme_path[0],
"tarred_audio_filepaths": nemo_tarred_manifest_with_skipme_path[1],
"weight": 0.5,
"tags": {
"language": "en",
"modality": "audio",
"dataset_name": "D1",
},
},
{
"type": "lhotse_shar",
"shar_path": cutset_shar_path,
"weight": 0.5,
"tags": {
"language": "en",
"modality": "audio",
"dataset_name": "D2",
},
},
],
"sample_rate": 16000,
"shuffle": True,
"num_workers": 0,
"batch_size": 4,
"seed": 0,
"shard_seed": 0,
"force_finite": True,
}
)

dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())
batches = [batch for batch in dl]
skipme_s = [cut.custom.get('_skipme', 0) for batch in batches for cut in batch]

assert not any(skipme_s)
Loading