diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt index 93deb6196e..246a47fceb 100644 --- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt +++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt @@ -36,6 +36,8 @@ list( ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h + ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index 3ee2d3789e..309de56cd8 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -4,13 +4,13 @@ This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models: 1. LLAMA2 Stories 110M 2. LLAMA3.2 1B - 3. LLAMA3.2 3B (WIP) + 3. LLAMA3.2 3B We offer the following modes to execute the model: -KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt. +- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt. -Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. +- Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. - AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode. - Prompt processing with AR-N model:
@@ -19,6 +19,7 @@ Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache
+- Lookahead Mode: Lookahead mode introduces [lookahead decoding](https://arxiv.org/abs/2402.02057) and uses the AR-N model to process the prompt, which enhances token generation speed. Although an LLM cannot decode multiple tokens in a single step, it can generate multiple guess tokens in parallel, and these guess tokens may fit into future parts of the generated sequence. The lookahead decoder generates and verifies these guess tokens, integrating them into the sequence when they are suitable. In some cases, it can obtain more than one token in a single step. The result is lossless.
 ## Instructions
 ### Note
@@ -127,3 +128,14 @@ You can select the KV Cache update mechanism at runtime by setting the `KV_UPDAT
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
+
+You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters:
+- `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
+- `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
+- `--gcap` (Verification candidates): Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.
+
+For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057).
+
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+```
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index 2ff50cced2..33482090b2 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -11,6 +11,7 @@
 import getpass
 import json
 import logging
+import math
 import os
 import subprocess
 import sys
@@ -90,6 +91,12 @@
 logging.getLogger().setLevel(logging.INFO)
+def next_power_of_two(n):
+    if n == 0:
+        return 1
+    return 2 ** math.ceil(math.log2(n))
+
+
 def smart_mask_updater(
     ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
 ):
@@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
             use_i64_token=use_i64_token,
         )
     )
+    elif args.model_mode == "lookahead":
+        llama_instance_list.append(
+            LlamaModel(
+                kv_config,
+                # To get better performance, we round up to the nearest power of 2.
+                ar_len=next_power_of_two(
+                    (args.window + args.gcap) * (args.ngram - 1)
+                ),
+                output_new_cache_only=True,
+                output_cache=True,
+                use_i64_token=use_i64_token,
+            )
+        )
+        llama_instance_list.append(
+            LlamaModel(
+                prefill_config,
+                ar_len=args.prefill_ar_len,
+                output_new_cache_only=True,
+                output_cache=True,
+                use_i64_token=use_i64_token,
+            )
+        )
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
@@ -630,8 +659,8 @@ def permute(w, heads):
                 tokenizer=tokenizer,
                 custom_annotations=custom_annotations,
             )
-            # If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
-            if i == 0 and args.model_mode == "hybrid":
+            # If hybrid or lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
+            if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
                 output_indices = 0
                 for node in llama_instance.llama_graph_module.graph.nodes:
                     if node.op == "output":
@@ -673,7 +702,7 @@ def permute(w, heads):
             shared_buffer=args.shared_buffer,
         )
         quant_attrs = llama_instance_list[0].get_quant_attrs()
-    elif args.model_mode == "hybrid":
+    elif args.model_mode in ["hybrid", "lookahead"]:
         sample_inputs_list = [
             llama_instace.inputs for llama_instace in llama_instance_list
         ]
@@ -759,6 +788,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
         eval_mode = 0
     elif args.model_mode == "hybrid":
         eval_mode = 1
+    elif args.model_mode == "lookahead":
+        eval_mode = 2
     else:
         raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
@@ -832,6 +863,9 @@ def post_process():
                 "--output_path outputs/outputs.txt",
                 f"--performance_output_path {performance_output_path}",
                 f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
+                f"--window {args.window}",
+                f"--gcap {args.gcap}",
+                f"--ngram {args.ngram}",
                 runner_args,
             ]
         )
@@ -971,9 +1005,9 @@ def _build_parser():
     parser.add_argument(
         "--model_mode",
-        help="Export and inference kv mode or hybrid mode",
+        help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
         default="kv",
-        choices=["kv", "hybrid"],
+        choices=["kv", "hybrid", "lookahead"],
         type=str,
     )
@@ -986,7 +1020,7 @@ def _build_parser():
     parser.add_argument(
         "--prefill_ar_len",
-        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
+        help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead modes.",
         default=32,
         type=int,
     )
@@ -1007,6 +1041,27 @@ def _build_parser():
         help="Fallback to cpu embedding operator and type of embedding quantization, ',', e.g., '4,32'.",
     )
+    parser.add_argument(
+        "--ngram",
+        help="Represents the size of the n-grams used in the lookahead process.",
+        default=5,
+        type=int,
+    )
+
+    parser.add_argument(
+        "--window",
+        help="Determines how many future tokens the algorithm attempts to predict in each step.",
+        default=8,
+        type=int,
+    )
+
+    parser.add_argument(
+        "--gcap",
+        help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification.
It balances the trade-off between computation efficiency and exploring more possibilities.", + default=8, + type=int, + ) + parser.add_argument("-v", "--verbose", action="store_true") return parser @@ -1023,6 +1078,14 @@ def export_llama(args) -> None: args.max_seq_len >= args.prefill_ar_len ), "Please ensure max_seq_len is >= prefill_ar_len" pte_filename = "hybrid_llama_qnn" + elif args.model_mode == "lookahead": + assert ( + args.max_seq_len >= args.prefill_ar_len + ), "Please ensure max_seq_len is >= prefill_ar_len" + assert args.max_seq_len > next_power_of_two( + (args.window + args.gcap) * (args.ngram - 1) + ), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))" + pte_filename = "lookahead_llama_qnn" else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 938d298d07..5c10d3eade 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -53,12 +53,24 @@ DEFINE_int32( DEFINE_int32( eval_mode, 0, - "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)"); + "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding"); DEFINE_string( kv_updater, - "How to update kv cache. Choose between SmartMask and ShiftPointer", - "SmartMask"); + "SmartMask", + "How to update kv cache. Choose between SmartMask and ShiftPointer"); DEFINE_int32(num_iters, 1, "total num of iterations to run."); +DEFINE_int32( + ngram, + 0, + "[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process."); +DEFINE_int32( + window, + 0, + "[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step."); +DEFINE_int32( + gcap, + 0, + "[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities."); std::vector CollectPrompts(int argc, char** argv) { // Collect all prompts from command line, example usage: @@ -111,7 +123,10 @@ int main(int argc, char** argv) { FLAGS_performance_output_path.c_str(), FLAGS_temperature, FLAGS_eval_mode, - FLAGS_kv_updater); + FLAGS_kv_updater, + FLAGS_ngram, + FLAGS_window, + FLAGS_gcap); auto llama_version = runner.get_llama_version(); std::vector buf; buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index ca155204de..b563049eb8 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -51,7 +51,7 @@ void KVManager::init_attention_mask( int32_t ar_len, int32_t n_past) { ET_CHECK_MSG( - attention_map.size() == ar_len, + attention_map.size() <= ar_len, "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); @@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { ? 
0 : metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_); v_cache_[layer][head].buffer = single_layer_v_cache + - head * single_head_size_in + cache_gap * metadata_.head_dim; - v_cache_[layer][head].output_buffer = - single_layer_v_cache + (head + 1) * single_head_size_in; + head * metadata_.head_dim * metadata_.context_len + + cache_gap * metadata_.head_dim; + v_cache_[layer][head].output_buffer = single_layer_v_cache + + head * metadata_.head_dim * metadata_.context_len + + single_head_size_in; } } break; @@ -311,7 +313,11 @@ bool KVManager::update_cache_tensor( return updated; } -void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) { +void KVManager::update_cache( + int32_t ar_len, + int32_t n_past, + int32_t n_update, + const std::vector& selected) { ET_CHECK_MSG( cur_ar_len_ == ar_len, "Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.", @@ -319,13 +325,17 @@ void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) { ar_len); for (int layer = 0; layer < metadata_.num_layers; ++layer) { for (int head = 0; head < metadata_.num_heads; ++head) { - update_key(k_cache_[layer][head], n_past, n_update); - update_value(v_cache_[layer][head], n_past, n_update); + update_key(k_cache_[layer][head], n_past, n_update, selected); + update_value(v_cache_[layer][head], n_past, n_update, selected); } } } -void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) { +void KVManager::update_key( + KVCache& k_cache, + int32_t n_past, + int32_t n_update, + const std::vector& selected) { uint8_t* write_ptr = k_cache.buffer; uint8_t* read_ptr = k_cache.output_buffer; const int32_t copy_size = n_update * sizeof(uint8_t); @@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) { write_ptr += iter_size + past_size; if (kv_updater_ == KVManagerMode::SMART_MASK) write_ptr += past_size; - - for (int i = 0; i < n_iter; ++i) { - std::memcpy(write_ptr, read_ptr, copy_size); - write_ptr += iter_size; - read_ptr += out_size; + if (selected.empty()) { + for (int i = 0; i < n_iter; ++i) { + std::memcpy(write_ptr, read_ptr, copy_size); + write_ptr += iter_size; + read_ptr += out_size; + } + } else { + std::vector true_indices(n_update); + for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) { + if (selected[i]) { + true_indices[j++] = i; + } + } + for (int i = 0; i < n_iter; ++i) { + auto wp = write_ptr, rp = read_ptr; + for (auto ind : true_indices) { + *wp++ = rp[ind]; + } + write_ptr += iter_size; + read_ptr += out_size; + } } } void KVManager::update_value( KVCache& v_cache, int32_t n_past, - int32_t n_update) { - // Value cache doesn't need to copy for SHIFT_POINTER mode - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) - return; - + int32_t n_update, + const std::vector& selected) { uint8_t* write_ptr = v_cache.buffer; uint8_t* read_ptr = v_cache.output_buffer; const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t); @@ -364,7 +387,31 @@ void KVManager::update_value( if (kv_updater_ == KVManagerMode::SMART_MASK) write_ptr += past_size; - std::memcpy(write_ptr, read_ptr, copy_size); + // Update the value cache for lookahead decoding in SHIFT_POINTER mode + if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { + read_ptr += past_size; + write_ptr = read_ptr; + } + + if (selected.empty()) { + // In general, value cache doesn't need to copy for SHIFT_POINTER mode + if (kv_updater_ == KVManagerMode::SHIFT_POINTER) + return; + 
std::memcpy(write_ptr, read_ptr, copy_size); + } else { + int32_t update_times = n_update; + auto wp = write_ptr, rp = read_ptr; + for (auto sel : selected) { + if (sel) { + std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t)); + wp += metadata_.head_dim; + update_times--; + if (update_times == 0) + break; + } + rp += metadata_.head_dim; + } + } } } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index 1a3beb35f9..e1a756d121 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -120,8 +120,13 @@ class KVManager { * @param ar_len Length of input tokens. * @param n_past Number of past elements in the cache. * @param n_update Number of elements to be updated. + * @param selected Indicate which position to be updated */ - void update_cache(int32_t ar_len, int32_t n_past, int32_t n_update); + void update_cache( + int32_t ar_len, + int32_t n_past, + int32_t n_update, + const std::vector& selected); const std::vector>& get_k_cache_() const { return k_cache_; @@ -138,8 +143,16 @@ class KVManager { // Helper functions to rearrange and update key and value caches void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); - void update_key(KVCache& k_cache, int32_t n_past, int32_t n_update); - void update_value(KVCache& v_cache, int32_t n_past, int32_t n_update); + void update_key( + KVCache& k_cache, + int32_t n_past, + int32_t n_update, + const std::vector& selected); + void update_value( + KVCache& v_cache, + int32_t n_past, + int32_t n_update, + const std::vector& selected); KVManagerMode kv_updater_; // metadata diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp new file mode 100644 index 0000000000..a20994a7a3 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -0,0 +1,383 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +using executorch::runtime::Result; + +namespace example { + +void LhdTokenGenerator::prepare_io( + std::vector input_tokens, + std::vector input_pos) { + for (int i = 0; i < metadata_.ar_len; i++) { + if (i < input_tokens.size()) { + // Prepare pos data + input_pos_.data[i] = input_pos[i]; + + // Support CPU 4-bit embedding, which requires int64 input. + // However, for QNN embedding, only int32 input is needed. + // Therefore, we need to cast to the correct type to write the data. 
+ if (metadata_.use_int64_token) { + input_toks_.data[i] = input_tokens[i]; + } else { + int32_t* input_toks_ptr = reinterpret_cast(input_toks_.data); + input_toks_ptr[i] = static_cast(input_tokens[i]); + } + } + } +} + +void LhdTokenGenerator::init_attention_mask(int32_t n_past) { + std::vector attention_map; + attention_map.reserve(metadata_.ar_len); + // Initialize attention mask with current position + for (int i = 0; i < metadata_.window; ++i) { + attention_map.push_back(i - 1); + } + for (int i = 1; i < metadata_.ngram - 1; ++i) { + for (int j = 0; j < metadata_.window; ++j) { + attention_map.push_back((i - 1) * metadata_.window + j); + } + } + for (int g = 0; g < metadata_.gcap; g++) { + for (int j = 0; j < metadata_.ngram - 1; j++) { + if (j == 0) + attention_map.push_back(0); + else + attention_map.push_back( + (metadata_.window + g) * (metadata_.ngram - 1) + j - 1); + } + } + + kv_manager_->init_attention_mask( + attention_mask_.data, attention_map, metadata_.ar_len, n_past); +} + +void LhdTokenGenerator::init_lookahead_branch( + const std::vector& tokens) { + for (int i = 0; i < metadata_.ngram - 1; ++i) { + for (int j = 0; j < metadata_.window; ++j) { + // there are different ways to init these tokens + if (0) { + // initialize with a sequence of increasing numbers + lhd_branch_[i][j] = 1000 + j; + } else { + // initialize with the random token from prompt + lhd_branch_[i][j] = tokens[1 + rand() % (tokens.size() - 1)]; + } + } + } + is_lhd_branch_initialized_ = true; +} + +void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { + const int g_cur = ngrams_pool_.cnt[cur_token]; + + v_branch_.resize(g_cur); + for (int g = 0; g < g_cur; g++) { + v_branch_[g].active = true; + v_branch_[g].tokens.resize(metadata_.ngram); + v_branch_[g].i_batch.resize(metadata_.ngram); + v_branch_[g].seq_id = metadata_.window + 1 + g; + v_branch_[g].i_batch[0] = 0; + v_branch_[g].tokens[0] = cur_token; + } + + for (int j = 0; j < metadata_.ngram - 1; j++) { + for (int g = 0; g < g_cur; g++) { + const int idx = cur_token * (metadata_.ngram - 1) * metadata_.gcap + + g * (metadata_.ngram - 1); + const int32_t t = ngrams_pool_.tokens[idx + j]; + v_branch_[g].tokens[j + 1] = t; + v_branch_[g].i_batch[j + 1] = j + 1; + } + } +} + +void LhdTokenGenerator::update_ngrams_pool() { + std::vector ngram(metadata_.ngram - 1); + // n-gram pool generation + for (int f = 0; f < metadata_.window; ++f) { + const int ft = lhd_branch_prev_[f]; // first token of the n-gram + + for (int j = 0; j < metadata_.ngram - 1; ++j) { + ngram[j] = lhd_branch_[j][f]; + } + + // filter-out repeating n-grams + { + bool is_unique = true; + for (int k = 0; k < ngrams_pool_.cnt[ft]; ++k) { + // calculate the related idx by the first n-gram token + const int idx = ft * (metadata_.ngram - 1) * metadata_.gcap + + k * (metadata_.ngram - 1); + + bool is_match = true; + for (int j = 0; j < metadata_.ngram - 1; ++j) { + if (ngrams_pool_.tokens[idx + j] != ngram[j]) { + is_match = false; + break; + } + } + + // if n-gram match all, discard one of them + if (is_match) { + is_unique = false; + break; + } + } + if (!is_unique) { + continue; + } + } + + const int head = ngrams_pool_.head[ft]; + const int idx = ft * (metadata_.ngram - 1) * metadata_.gcap + + head * (metadata_.ngram - 1); + + for (int i = 0; i < metadata_.ngram - 1; i++) { + // update the n-gram pool with new n-gram + ngrams_pool_.tokens[idx + i] = ngram[i]; + } + + ngrams_pool_.cnt[ft] = + std::min(metadata_.gcap, (int32_t)ngrams_pool_.cnt[ft] + 1); + 
ngrams_pool_.head[ft] = (head + 1) % metadata_.gcap; + ngrams_pool_.n_total++; + } +} + +void LhdTokenGenerator::update_lookahead_branch( + const executorch::aten::Tensor& logits_tensor) { + for (int i = 0; i < metadata_.window; i++) { + lhd_branch_prev_[i] = lhd_branch_[0][i]; + } + + for (int j = 0; j < metadata_.ngram - 2; j++) { + lhd_branch_[j] = lhd_branch_[j + 1]; + } + + // sample from the last level + for (int i = 0; i < metadata_.window; i++) { + size_t sample_idx = (metadata_.ngram - 2) * metadata_.window + i; + lhd_branch_[metadata_.ngram - 2][i] = + decoder_runner_->logits_to_token(logits_tensor, sample_idx); + } +} + +Result LhdTokenGenerator::generate( + std::vector tokens, + int64_t start_pos, + int32_t seq_len, + std::function token_callback) { + ET_CHECK_MSG( + !tokens.empty(), "Token generation loop shouldn't take empty tokens"); + // position in the sequence + int64_t pos = start_pos; + int64_t prev_pos; + // number of match tokens + int32_t n_accept{0}; + std::vector result_tokens; + uint64_t cur_token = tokens.back(); + uint64_t prev_token; + result_tokens.push_back(cur_token); + + // Manage the inputs of lookahead decoding + std::vector input_pos; + std::vector input_tokens; + input_tokens.reserve(metadata_.ar_len); + input_pos.reserve(metadata_.ar_len); + + // Rearrange KV cache first and initialize the input and output of KV cache + kv_manager_->rearrange_cache(metadata_.ar_len); + + // Initialize attention mask with pos + init_attention_mask(pos); + + // Initialize Lookahead branch at first generation + if (!is_lhd_branch_initialized_) { + ET_LOG(Info, "Initialize Lookahead branch"); + init_lookahead_branch(tokens); + } + + // Initialize the output of the module + ET_CHECK_MSG( + decoder_runner_->set_outputs(method_name_, output_tensors_) == + executorch::runtime::Error::Ok, + "Failed to set output tensor for module %s", + method_name_.c_str()); + + // Generate tokens + while (pos < seq_len - 1) { + std::vector selected(metadata_.ar_len, false); + + input_tokens.clear(); + input_pos.clear(); + + // fill the first token of the first level + input_tokens.push_back(cur_token); + input_pos.push_back(pos); + + // fill the remaining WINDOW - 1 tokens for the first level + for (int i = 1; i < metadata_.window; ++i) { + input_tokens.push_back(lhd_branch_[0][i]); + input_pos.push_back(pos + i); + } + + // fill the rest of the levels + for (int i = 1; i < metadata_.ngram - 1; ++i) { + for (int j = 0; j < metadata_.window; ++j) { + input_tokens.push_back(lhd_branch_[i][j]); + input_pos.push_back(pos + i + j); + } + } + // Verification Branch Init + init_verification_branch(cur_token); + + for (int g = 0; g < v_branch_.size(); g++) { + for (int j = 0; j < metadata_.ngram - 1; j++) { + input_tokens.push_back(v_branch_[g].tokens[j + 1]); + input_pos.push_back(pos + j + 1); + } + } + + prepare_io(input_tokens, input_pos); + // Only update data pointer of the cache to the tensor for SHIFT_POINTER + // mode + bool updated = kv_manager_->update_cache_tensor( + k_cache_in_, + k_cache_out_, + v_cache_in_, + v_cache_out_, + metadata_.ar_len, + pos); + // Only update the output of module for SHIFT_POINTER mode + if (updated) { + // Update the output of the module + ET_CHECK_MSG( + decoder_runner_->set_outputs(method_name_, output_tensors_) == + executorch::runtime::Error::Ok, + "Failed to set output tensor for module %s", + method_name_.c_str()); + } + + // Run inference + auto logits_res = decoder_runner_->step(method_name_, inputs_); + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); 
+ executorch::aten::Tensor& logits_tensor = logits_res.get(); + prev_pos = pos; + + // verification branch seq-id + size_t seq_id_best = 0; + // max hit pos + size_t i_batch_best = 0; + + // Lookahead decoding and verification + for (int v = 0; v < metadata_.ngram; ++v) { + // Verification + int i_batch = 0; + if (v > 0) { + for (int g = 0; g < v_branch_.size(); g++) { + // record the best matched seq and pos + if (v_branch_[g].active) { + i_batch = v_branch_[g].i_batch[v]; + i_batch_best = i_batch; + seq_id_best = v_branch_[g].seq_id; + ++n_accept; + break; + } + } + if (i_batch == 0) { + break; + } + } + + size_t sample_idx; + if (seq_id_best == 0) + sample_idx = 0; + else + sample_idx = metadata_.window * (metadata_.ngram - 1) + + (seq_id_best - (metadata_.window + 1)) * (metadata_.ngram - 1) + + i_batch - 1; + + // vector selected set + selected[sample_idx] = true; + + prev_token = cur_token; + // sampler from logits all + stats_->on_sampling_begin(); + cur_token = decoder_runner_->logits_to_token(logits_tensor, sample_idx); + stats_->on_sampling_end(); + result_tokens.push_back(cur_token); + pos++; + + // print the token as string, decode it with the Tokenizer object + token_callback( + ET_UNWRAP_TOKENIZER(tokenizer_->decode(prev_token, cur_token))); + + // data-dependent terminating condition: we have n_eos_ number of EOS + if (eos_ids_->count(cur_token) > 0) { + printf("\n"); + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + + // if verify pass, check the next sample token until verifying failed + for (int g = 0; g < v_branch_.size(); g++) { + // update the n-gram active status + if (v_branch_[g].active) { + if (v == metadata_.ngram - 1) { + v_branch_[g].active = false; + } else { + if (cur_token != v_branch_[g].tokens[v + 1]) { + v_branch_[g].active = false; + } + } + } + } + + // only update n-grams pools and lookahead branch when v=0 + if (v == 0) { + // update lookahead branch + update_lookahead_branch(logits_tensor); + // update n-grams pool + update_ngrams_pool(); + } + } // end of verify loop + + if (pos > metadata_.context_len - metadata_.ar_len) { + printf("\n"); + ET_LOG(Info, "\nReached to the maximum sequence length"); + break; + } + // Update KV Cache with the output results + int32_t n_update = pos - prev_pos; + kv_manager_->update_cache(metadata_.ar_len, prev_pos, n_update, selected); + + // Update attention mask with current position + kv_manager_->update_attention_mask( + attention_mask_.data, metadata_.ar_len, prev_pos, n_update); + + // data-dependent terminating condition: we have n_eos_ number of EOS + if (eos_ids_->count(cur_token) > 0) { + printf("\n"); + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + } + ET_LOG( + Info, + "Lookahead Decoding: n_generated = %ld / n_accept = %d", + pos - start_pos, + n_accept); + + return pos - start_pos; +} +} // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h new file mode 100644 index 0000000000..cf500d7e43 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once +#include + +namespace example { +/** + * @class LhdTokenGenerator + * @brief Class for generating the token using decoder and key-value manager + * with lookahead decoding. + */ +class LhdTokenGenerator : public TokenGenerator { + public: + struct Metadata { + int32_t context_len; + int64_t num_heads; + int64_t num_layers; + int32_t ar_len; + int32_t vocab_size; + bool use_int64_token; + int32_t ngram; + int32_t window; + int32_t gcap; + }; + LhdTokenGenerator( + tokenizers::Tokenizer* tokenizer, + DecoderRunner* decoder_runner, + KVManager* kv_manager, + const std::string& forward_name, + std::unique_ptr>&& eos_ids, + Metadata metadata, + executorch::llm::Stats* stats) + : TokenGenerator( + tokenizer, + decoder_runner, + kv_manager, + forward_name, + std::move(eos_ids), + TokenGenerator::Metadata{ + metadata.context_len, + metadata.num_heads, + metadata.num_layers, + metadata.ar_len, + metadata.vocab_size, + metadata.use_int64_token}, + stats), + metadata_(metadata), + ngrams_pool_(metadata.vocab_size, metadata.ngram, metadata.gcap), + lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), + lhd_branch_prev_(metadata.window) { + ET_LOG( + Info, + "Use Lookahead decoding: ngram=%d, window=%d, gcap=%d", + metadata.ngram, + metadata.window, + metadata.gcap); + } + + ~LhdTokenGenerator() = default; + + /** +    * @brief Generate tokens with lookahead decoding. +    * @param tokens Vector of input tokens. +    * @param start_pos Starting position for generation. +    * @param seq_len Length of the sequence to generate. +    * @param token_callback Callback function for generated tokens. +    * @return The number of tokens generated. +    */ + executorch::runtime::Result generate( + std::vector tokens, + int64_t start_pos, + int32_t seq_len, + std::function token_callback) override; + + private: + /** + * @brief Fill in I/O buffers with prompt token and position. + * @param cur_token Current token. + * @param start_pos Starting position. 
+ */ + void prepare_io( + std::vector input_tokens, + std::vector input_pos); + void init_attention_mask(int32_t n_past); + void init_lookahead_branch(const std::vector& tokens); + void init_verification_branch(uint64_t cur_token); + void update_lookahead_branch(const executorch::aten::Tensor& logits_tensor); + void update_ngrams_pool(); + struct NgramData { + bool active = false; + int32_t seq_id = -1; + + // match pos + std::vector i_batch; + std::vector tokens; + }; + + // n-gram pool + struct NgramContainer { + NgramContainer(int n_vocab, int n, int g) { + cnt.resize(n_vocab); + head.resize(n_vocab); + tokens.resize(n_vocab * g * (n - 1)); + } + + int n_total = 0; + + std::vector cnt; + std::vector head; + + // [n_vocab][G][N - 1] + // for each token of the vocab, keep a ring-buffer of capacity G of n-grams + // of size N - 1 + std::vector tokens; + }; + + Metadata metadata_; + + // lookahead branch + bool is_lhd_branch_initialized_{false}; + // [N - 1][W] + std::vector> lhd_branch_; + // [W] + std::vector lhd_branch_prev_; + + // verification branch + std::vector v_branch_; + + // n-gram pools + NgramContainer ngrams_pool_; +}; +} // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 37dce8f06c..4a1a62c8e1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -256,7 +256,7 @@ Result PromptProcessor::prefill( n_update = 1 + ((num_prompt_tokens - 1) % metadata_.ar_len); } // Update KV Cache with the output results - kv_manager_->update_cache(metadata_.ar_len, pos, n_update); + kv_manager_->update_cache(metadata_.ar_len, pos, n_update, {}); // Update attention mask with current position kv_manager_->update_attention_mask( attention_mask_.data, metadata_.ar_len, pos, n_update); diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index bdc2019352..7a054d8e2a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -57,13 +58,20 @@ Runner::Runner( const std::string& performance_output_path, const float temperature, const int eval_mode, - const std::string& kv_updater) + const std::string& kv_updater, + const int ngram, + const int window, + const int gcap) : tokenizer_path_(tokenizer_path), performance_output_path_(performance_output_path), temperature_(temperature), - eval_mode_(static_cast(eval_mode)) { + eval_mode_(static_cast(eval_mode)), + ngram_(ngram), + window_(window), + gcap_(gcap) { module_ = std::make_unique( model_path, Module::LoadMode::MmapUseMlockIgnoreErrors); + stats_.reset(); if (kv_updater == "SmartMask") { kv_updater_ = KVManagerMode::SMART_MASK; } else if (kv_updater == "ShiftPointer") { @@ -96,6 +104,7 @@ Error Runner::load() { method_names.emplace_back(token_generator_method_name); break; case EvalMode::kHybrid: + case EvalMode::kLookaheadDecoding: prompt_processor_method_name = "prefill_forward"; token_generator_method_name = "kv_forward"; method_names.emplace_back(prompt_processor_method_name); @@ -162,7 +171,9 @@ Error Runner::load() { context_len_ = atten_mask_meta_token->sizes()[2]; if (eval_mode_ == EvalMode::kKVCached) { prompt_processor_ar_len = token_generator_ar_len; - } else if (eval_mode_ == EvalMode::kHybrid) { + } else if ( + 
eval_mode_ == EvalMode::kHybrid || + eval_mode_ == EvalMode::kLookaheadDecoding) { auto atten_mask_meta_prompt = module_->method_meta(prompt_processor_method_name) ->input_tensor_meta(1); @@ -196,21 +207,40 @@ Error Runner::load() { prompt_processor_ar_len, vocab_size, use_int64_token}); - token_generator_ = std::make_unique( - tokenizer_.get(), - decoder_runner_.get(), - kv_manager_.get(), - token_generator_method_name, - std::move(eos_ids), - TokenGenerator::Metadata{ - context_len_, - num_heads, - num_layers, - token_generator_ar_len, - vocab_size, - use_int64_token, - }, - &stats_); + if (eval_mode_ == EvalMode::kLookaheadDecoding) { + token_generator_ = std::make_unique( + tokenizer_.get(), + decoder_runner_.get(), + kv_manager_.get(), + token_generator_method_name, + std::move(eos_ids), + LhdTokenGenerator::Metadata{ + context_len_, + num_heads, + num_layers, + token_generator_ar_len, + vocab_size, + use_int64_token, + ngram_, + window_, + gcap_}, + &stats_); + } else { + token_generator_ = std::make_unique( + tokenizer_.get(), + decoder_runner_.get(), + kv_manager_.get(), + token_generator_method_name, + std::move(eos_ids), + TokenGenerator::Metadata{ + context_len_, + num_heads, + num_layers, + token_generator_ar_len, + vocab_size, + use_int64_token}, + &stats_); + } buffer_manager_ = std::make_unique(); if (kv_updater_ == KVManagerMode::SMART_MASK) { diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 708f91157a..c318da5020 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -38,7 +38,10 @@ class Runner { const std::string& performance_output_path, const float temperature = 0.8f, const int eval_mode = EvalMode::kKVCached, - const std::string& kv_updater = "SmartMask"); + const std::string& kv_updater = "SmartMask", + const int ngram = 0, + const int window = 0, + const int gcap = 0); bool is_loaded() const; executorch::runtime::Error load(); @@ -57,12 +60,16 @@ class Runner { enum EvalMode { kKVCached = 0, kHybrid, + kLookaheadDecoding, kUnsupported, }; std::unique_ptr module_; int32_t context_len_{0}; + int ngram_{0}; + int window_{0}; + int gcap_{0}; int64_t cur_pos_{0}; std::string tokenizer_path_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 8d890637b1..da20517925 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -231,7 +231,7 @@ Result TokenGenerator::generate( stats_->on_sampling_end(); // Update KV Cache with the output results - kv_manager_->update_cache(metadata_.ar_len, pos, metadata_.ar_len); + kv_manager_->update_cache(metadata_.ar_len, pos, metadata_.ar_len, {}); // Update attention mask with current position kv_manager_->update_attention_mask( attention_mask_.data, metadata_.ar_len, pos, metadata_.ar_len); diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index a5d6965795..d2dd4afd19 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -38,13 +38,15 @@ class TokenGenerator { std::unique_ptr>&& eos_ids, Metadata metadata, executorch::llm::Stats* stats); + + virtual ~TokenGenerator() = default; /** * @brief Initialize I/O tensor and allocate I/O data 
buffer. * @param buffer_manager Pointer to IMemAlloc instance which depends on * kv_updater. * @param method_meta Method metadata. */ - void init_io( + virtual void init_io( IMemAlloc* buffer_manager, executorch::runtime::Result method_meta); @@ -56,7 +58,7 @@ class TokenGenerator {    * @param token_callback Callback function for generated tokens.    * @return The number of tokens generated.    */ - executorch::runtime::Result generate( + virtual executorch::runtime::Result generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -66,23 +68,13 @@ class TokenGenerator { logits_.size; } - private: - /** - * @brief Fill in I/O buffers with prompt token and position. - * @param cur_token Current token. - * @param start_pos Starting position. - */ - void prepare_io(uint64_t cur_token, int64_t start_pos); - + protected: tokenizers::Tokenizer* tokenizer_; DecoderRunner* decoder_runner_; KVManager* kv_manager_; std::string method_name_; std::unique_ptr> eos_ids_; - // metadata - Metadata metadata_; - // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; @@ -105,5 +97,16 @@ class TokenGenerator { // stats executorch::llm::Stats* stats_; + + private: + /** + * @brief Fill in I/O buffers with prompt token and position. + * @param cur_token Current token. + * @param start_pos Starting position. + */ + void prepare_io(uint64_t cur_token, int64_t start_pos); + + // metadata + Metadata metadata_; }; } // namespace example
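
As a quick illustration of how the new `--ngram`, `--window`, and `--gcap` flags determine the lookahead AR length: each decoding step packs `(ngram - 1) * window` lookahead-branch tokens plus up to `gcap * (ngram - 1)` verification-branch tokens, and `llama.py` rounds that total up to the next power of two. The sketch below is not part of the patch; it mirrors `next_power_of_two()` in `llama.py` and the input packing in `LhdTokenGenerator::generate()`, and the `lookahead_ar_len` helper is a hypothetical name used only for illustration.

```python
# Minimal sketch (not part of the patch): how --ngram/--window/--gcap map to the
# lookahead AR length, mirroring next_power_of_two() in llama.py and the input
# packing in LhdTokenGenerator::generate().
import math


def next_power_of_two(n):
    if n == 0:
        return 1
    return 2 ** math.ceil(math.log2(n))


def lookahead_ar_len(ngram, window, gcap):
    # One lookahead step packs:
    #   (ngram - 1) * window tokens for the lookahead branch (window tokens per level), plus
    #   up to gcap * (ngram - 1) tokens for the verification branches (ngram - 1 per candidate).
    tokens_per_step = (window + gcap) * (ngram - 1)
    # llama.py rounds this up to the nearest power of two for better performance.
    return next_power_of_two(tokens_per_step)


if __name__ == "__main__":
    print(lookahead_ar_len(5, 8, 8))  # llama.py defaults: (8 + 8) * 4 = 64 -> 64
    print(lookahead_ar_len(3, 2, 2))  # README example: (2 + 2) * 2 = 8 -> 8
```

With the defaults (`--ngram 5 --window 8 --gcap 8`) the lookahead step spans 64 tokens, so the export-time assert in `export_llama()` requires `--max_seq_len` to be greater than 64.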