import torch


def ctc_forced_align(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    blank: int = 0,
    ignore_id: int = -1,
) -> torch.Tensor:
| 11 | + """Align a CTC label sequence to an emission. |
| 12 | +
|
| 13 | + Args: |
| 14 | + log_probs (Tensor): log probability of CTC emission output. |
| 15 | + Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length, |
| 16 | + `C` is the number of characters in alphabet including blank. |
| 17 | + targets (Tensor): Target sequence. Tensor of shape `(B, L)`, |
| 18 | + where `L` is the target length. |
| 19 | + input_lengths (Tensor): |
| 20 | + Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`. |
| 21 | + target_lengths (Tensor): |
| 22 | + Lengths of the targets. 1-D Tensor of shape `(B,)`. |
| 23 | + blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0) |
| 24 | + ignore_id (int, optional): The index of ignore symbol in CTC emission. (Default: -1) |
| 25 | + """ |
    # Replace padding positions with the blank symbol without mutating the caller's tensor.
    targets = targets.masked_fill(targets == ignore_id, blank)

    batch_size, input_time_size, _ = log_probs.size()
    bsz_indices = torch.arange(batch_size, device=input_lengths.device)

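    # Build the blank-interleaved label sequence [blank, y_1, blank, y_2, ..., blank, y_L, blank]
    # (length 2L + 1) that defines the states of the standard CTC trellis.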
    _t_a_r_g_e_t_s_ = torch.cat(
        (
            torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
            torch.full_like(targets[:, :1], blank),
        ),
        dim=-1,
    )
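    # diff_labels[:, s] marks states that may also be reached from state s - 2, i.e. where the
    # labels two positions apart differ (the CTC transition that skips an intermediate blank).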
    diff_labels = torch.cat(
        (
            torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
            _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
        ),
        dim=1,
    )

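    # Viterbi trellis over the extended targets. Two extra -inf columns are prepended so the
    # s - 1 and s - 2 transitions can be expressed as plain slices; at t = 0 only the leading
    # blank and the first label are reachable.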
    neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
    padding_num = 2
    padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
    best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
    best_score[:, padding_num + 0] = log_probs[:, 0, blank]
    best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]

    backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)

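    # Forward pass: at each frame, every state keeps the best of staying put, advancing from
    # the previous state, or skipping a blank (only where diff_labels allows it), and the
    # winning transition is recorded for backtracking.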
    for t in range(1, input_time_size):
        prev = torch.stack(
            (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
        )
        prev_max_value, prev_max_idx = prev.max(dim=0)
        best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
        backpointers[:, t, padding_num:] = prev_max_idx

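    # A complete CTC path ends either on the last label or on the trailing blank; start
    # backtracking from the higher-scoring of the two end states.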
    l1l2 = best_score.gather(
        -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
    )

    path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
    path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)

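    # Walk the backpointers from the last frame to the first to recover the trellis state
    # occupied at every frame.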
    for t in range(input_time_size - 1, 0, -1):
        target_indices = path[:, t]
        prev_max_idx = backpointers[bsz_indices, t, target_indices]
        path[:, t - 1] += target_indices - prev_max_idx

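    # Convert trellis state indices back to token ids; frames past each input length keep a
    # path value of 0 and therefore map to the leading blank.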
    alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
    return alignments
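

# A minimal usage sketch (hypothetical shapes and values, not part of the original commit):
# force-align a two-token target to a six-frame emission and read off the per-frame token ids.
if __name__ == "__main__":
    torch.manual_seed(0)
    log_probs = torch.randn(1, 6, 5).log_softmax(dim=-1)  # (B, T, C)
    targets = torch.tensor([[2, 3]])                       # (B, L)
    input_lengths = torch.tensor([6])
    target_lengths = torch.tensor([2])
    alignment = ctc_forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
    print(alignment)  # (1, 6) tensor of per-frame ids: 0 (blank), 2, or 3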