Qualcomm AI Engine Direct - Enable Lookahead Decoding #11437

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -36,6 +36,8 @@ list(
${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
75 changes: 69 additions & 6 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -11,6 +11,7 @@
import getpass
import json
import logging
import math
import os
import subprocess
import sys
@@ -90,6 +91,12 @@
logging.getLogger().setLevel(logging.INFO)


def next_power_of_two(n):
if n == 0:
return 1
return 2 ** math.ceil(math.log2(n))


def smart_mask_updater(
ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
@@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
use_i64_token=use_i64_token,
)
)
elif args.model_mode == "lookahead":
llama_instance_list.append(
LlamaModel(
kv_config,
# To get better performance, we round ar_len up to the next power of 2.
ar_len=next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
),
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
llama_instance_list.append(
LlamaModel(
prefill_config,
ar_len=args.prefill_ar_len,
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

@@ -630,8 +659,8 @@ def permute(w, heads):
tokenizer=tokenizer,
custom_annotations=custom_annotations,
)
# If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode == "hybrid":
# If hybrid or lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
output_indices = 0
for node in llama_instance.llama_graph_module.graph.nodes:
if node.op == "output":
@@ -673,7 +702,7 @@ def permute(w, heads):
shared_buffer=args.shared_buffer,
)
quant_attrs = llama_instance_list[0].get_quant_attrs()
elif args.model_mode == "hybrid":
elif args.model_mode in ["hybrid", "lookahead"]:
sample_inputs_list = [
llama_instace.inputs for llama_instace in llama_instance_list
]
@@ -763,6 +792,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
eval_mode = 0
elif args.model_mode == "hybrid":
eval_mode = 1
elif args.model_mode == "lookahead":
eval_mode = 2
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

@@ -836,6 +867,9 @@ def post_process():
"--output_path outputs/outputs.txt",
f"--performance_output_path {performance_output_path}",
f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
f"--window {args.window}",
f"--gcap {args.gcap}",
f"--ngram {args.ngram}",
runner_args,
]
)
@@ -975,9 +1009,9 @@ def _build_parser():

parser.add_argument(
"--model_mode",
help="Export and inference kv mode or hybrid mode",
help="Export and run inference in kv mode, hybrid mode, or lookahead decoding mode",
default="kv",
choices=["kv", "hybrid"],
choices=["kv", "hybrid", "lookahead"],
type=str,
)

@@ -990,7 +1024,7 @@

parser.add_argument(
"--prefill_ar_len",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead modes.",
default=32,
type=int,
)
@@ -1011,6 +1045,27 @@
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
)

parser.add_argument(
"--ngram",
help="Represents the size of the n-grams used in the lookahead process.",
default=5,
type=int,
)

parser.add_argument(
"--window",
help="Determines how many future tokens the algorithm attempts to predict in each step.",
default=8,
type=int,
)

parser.add_argument(
"--gcap",
help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computational efficiency and exploring more possibilities.",
default=8,
type=int,
)

parser.add_argument("-v", "--verbose", action="store_true")

return parser
@@ -1027,6 +1082,14 @@ def export_llama(args) -> None:
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
pte_filename = "hybrid_llama_qnn"
elif args.model_mode == "lookahead":
assert (
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
assert args.max_seq_len > next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))"
pte_filename = "lookahead_llama_qnn"
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

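Note on the lookahead AR length: with the defaults added above (ngram=5, window=8, gcap=8), the decode-graph AR length works out to next_power_of_two((8 + 8) * (5 - 1)) = next_power_of_two(64) = 64, so export_llama asserts that max_seq_len is strictly greater than 64 and, as in hybrid mode, at least prefill_ar_len.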
23 changes: 19 additions & 4 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -53,12 +53,24 @@ DEFINE_int32(
DEFINE_int32(
eval_mode,
0,
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
DEFINE_string(
kv_updater,
"How to update kv cache. Choose between SmartMask and ShiftPointer",
"SmartMask");
"SmartMask",
"How to update kv cache. Choose between SmartMask and ShiftPointer");
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
DEFINE_int32(
ngram,
0,
"[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process.");
DEFINE_int32(
window,
0,
"[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step.");
DEFINE_int32(
gcap,
0,
"[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computational efficiency and exploring more possibilities.");

std::vector<std::string> CollectPrompts(int argc, char** argv) {
// Collect all prompts from command line, example usage:
@@ -111,7 +123,10 @@ int main(int argc, char** argv) {
FLAGS_performance_output_path.c_str(),
FLAGS_temperature,
FLAGS_eval_mode,
FLAGS_kv_updater);
FLAGS_kv_updater,
FLAGS_ngram,
FLAGS_window,
FLAGS_gcap);
auto llama_version = runner.get_llama_version();
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
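For reference, a hypothetical runner invocation with the new flags could look like the line below. The binary name and the omitted model/tokenizer/prompt flags are assumptions for illustration; llama.py assembles the real command, forwarding the same window/gcap/ngram values used at export together with eval_mode 2:

./qnn_llama_runner --eval_mode 2 --kv_updater SmartMask --ngram 5 --window 8 --gcap 8 --output_path outputs/outputs.txt ... (model, tokenizer, and prompt flags omitted)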
85 changes: 66 additions & 19 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
@@ -51,7 +51,7 @@ void KVManager::init_attention_mask(
int32_t ar_len,
int32_t n_past) {
ET_CHECK_MSG(
attention_map.size() == ar_len,
attention_map.size() <= ar_len,
"The size of attention_map (%zu) doesn't match with ar_len (%d)",
attention_map.size(),
ar_len);
@@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
? 0
: metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_);
v_cache_[layer][head].buffer = single_layer_v_cache +
head * single_head_size_in + cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer =
single_layer_v_cache + (head + 1) * single_head_size_in;
head * metadata_.head_dim * metadata_.context_len +
cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer = single_layer_v_cache +
head * metadata_.head_dim * metadata_.context_len +
single_head_size_in;
}
}
break;
@@ -311,21 +313,29 @@ bool KVManager::update_cache_tensor(
return updated;
}

void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) {
void KVManager::update_cache(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
ET_CHECK_MSG(
cur_ar_len_ == ar_len,
"Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.",
cur_ar_len_,
ar_len);
for (int layer = 0; layer < metadata_.num_layers; ++layer) {
for (int head = 0; head < metadata_.num_heads; ++head) {
update_key(k_cache_[layer][head], n_past, n_update);
update_value(v_cache_[layer][head], n_past, n_update);
update_key(k_cache_[layer][head], n_past, n_update, selected);
update_value(v_cache_[layer][head], n_past, n_update, selected);
}
}
}

void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
void KVManager::update_key(
KVCache& k_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = k_cache.buffer;
uint8_t* read_ptr = k_cache.output_buffer;
const int32_t copy_size = n_update * sizeof(uint8_t);
@@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
write_ptr += iter_size + past_size;
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
if (selected.empty()) {
for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
}
} else {
std::vector<int32_t> true_indices(n_update);
for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) {
if (selected[i]) {
true_indices[j++] = i;
}
}
for (int i = 0; i < n_iter; ++i) {
auto wp = write_ptr, rp = read_ptr;
for (auto ind : true_indices) {
*wp++ = rp[ind];
}
write_ptr += iter_size;
read_ptr += out_size;
}
}
}

void KVManager::update_value(
KVCache& v_cache,
int32_t n_past,
int32_t n_update) {
// Value cache doesn't need to copy for SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;

int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = v_cache.buffer;
uint8_t* read_ptr = v_cache.output_buffer;
const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
@@ -364,7 +387,31 @@ void KVManager::update_value(
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

std::memcpy(write_ptr, read_ptr, copy_size);
// Update the value cache for lookahead decoding in SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
read_ptr += past_size;
write_ptr = read_ptr;
}

if (selected.empty()) {
// In general, value cache doesn't need to copy for SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;
std::memcpy(write_ptr, read_ptr, copy_size);
} else {
int32_t update_times = n_update;
auto wp = write_ptr, rp = read_ptr;
for (auto sel : selected) {
if (sel) {
std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
wp += metadata_.head_dim;
update_times--;
if (update_times == 0)
break;
}
rp += metadata_.head_dim;
}
}
}

} // namespace example
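The key behavioral change in this file is that update_key() and update_value() now take a selected mask: after lookahead verification, only the accepted positions are copied back into the KV cache, and in SHIFT_POINTER mode the selected value rows are compacted in place at the read position. An empty mask preserves the old contiguous-copy behavior. Below is a minimal, self-contained C++ sketch of the value-cache compaction when selected is non-empty; the toy sizes, buffers, and main() are illustrative assumptions, not the actual KVManager code:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  const int head_dim = 4; // toy head dimension (quantized uint8 elements)
  const int ar_len = 6;   // positions produced by one lookahead step

  // Model output for one head: ar_len rows of head_dim bytes.
  std::vector<uint8_t> output(ar_len * head_dim);
  for (int i = 0; i < ar_len * head_dim; ++i)
    output[i] = static_cast<uint8_t>(i);

  // Suppose verification accepted positions 0, 2 and 3.
  const std::vector<bool> selected = {true, false, true, true, false, false};
  int n_update = 3; // number of accepted tokens

  // Destination cache region (the real code has already offset this by n_past).
  std::vector<uint8_t> cache(ar_len * head_dim, 0);

  uint8_t* wp = cache.data();
  const uint8_t* rp = output.data();
  for (bool sel : selected) {
    if (sel) {
      std::memcpy(wp, rp, head_dim); // keep this accepted position
      wp += head_dim;
      if (--n_update == 0)
        break;
    }
    rp += head_dim; // rejected positions are skipped in the source buffer
  }

  // cache now holds rows 0, 2 and 3 of the output buffer, packed back to back.
  for (int i = 0; i < 3 * head_dim; ++i)
    std::cout << static_cast<int>(cache[i]) << ((i + 1) % head_dim ? ' ' : '\n');
  return 0;
}

update_key() performs the analogous gather per cache row, collecting the selected columns through true_indices before advancing to the next row.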
19 changes: 16 additions & 3 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
@@ -120,8 +125,13 @@ class KVManager {
* @param ar_len Length of input tokens.
* @param n_past Number of past elements in the cache.
* @param n_update Number of elements to be updated.
* @param selected Indicates which positions should be updated
*/
void update_cache(int32_t ar_len, int32_t n_past, int32_t n_update);
void update_cache(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);

const std::vector<std::vector<KVCache>>& get_k_cache_() const {
return k_cache_;
@@ -138,8 +143,16 @@
// Helper functions to rearrange and update key and value caches
void rearrange_key(KVCache& k_cache, int32_t ar_len_dst);
void rearrange_value(KVCache& v_cache, int32_t ar_len_dst);
void update_key(KVCache& k_cache, int32_t n_past, int32_t n_update);
void update_value(KVCache& v_cache, int32_t n_past, int32_t n_update);
void update_key(
KVCache& k_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);
void update_value(
KVCache& v_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected);
KVManagerMode kv_updater_;

// metadata
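To show how the new signature is meant to be consumed, here is a hypothetical call-site sketch. The stub class, variable names, and accepted positions are assumptions for illustration only; the real callers are the token generators (including the new lhd_token_generator), which are not part of this excerpt:

#include <cstdint>
#include <vector>

// Stub mirroring the updated KVManager::update_cache signature, only so the
// sketch below compiles; the real class is declared in kv_manager.h above.
struct KVManagerStub {
  void update_cache(
      int32_t ar_len,
      int32_t n_past,
      int32_t n_update,
      const std::vector<bool>& selected) {
    (void)ar_len;
    (void)n_past;
    (void)n_update;
    (void)selected;
  }
};

int main() {
  KVManagerStub kv_manager;
  const int32_t ar_len = 64;  // assumed lookahead AR length
  const int32_t n_past = 128; // assumed number of cached tokens

  // kv / hybrid decoding: an empty mask keeps the original contiguous copy path.
  kv_manager.update_cache(ar_len, n_past, /*n_update=*/1, {});

  // Lookahead decoding (hypothetical): after verification, mark which of the
  // ar_len positions were accepted; only those entries are written to the cache.
  std::vector<int32_t> accepted = {0, 5, 6, 7}; // assumed accepted positions
  std::vector<bool> selected(ar_len, false);
  for (int32_t pos : accepted)
    selected[pos] = true;
  kv_manager.update_cache(
      ar_len, n_past, static_cast<int32_t>(accepted.size()), selected);
  return 0;
}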