@inproceedings{nishida-etal-2025-long,
title = "Long-Tail Crisis in Nearest Neighbor Language Models",
author = "Nishida, Yuto and
Morishita, Makoto and
Deguchi, Hiroyuki and
Kamigaito, Hidetaka and
Watanabe, Taro",
editor = "Chiruzzo, Luis and
Ritter, Alan and
Wang, Lu",
booktitle = "Findings of the Association for Computational Linguistics: NAACL 2025",
month = apr,
year = "2025",
address = "Albuquerque, New Mexico",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2025.findings-naacl.331/",
pages = "5965--5978",
ISBN = "979-8-89176-195-7",
    abstract = "The $k$-nearest-neighbor language model ($k$NN-LM), one of the retrieval-augmented language models, improves the perplexity for a given text by directly accessing a large datastore built from any text data during inference. A widely held hypothesis for the success of $k$NN-LM is that its explicit memory, i.e., the datastore, enhances predictions for long-tail phenomena. However, prior works have primarily shown its ability to retrieve long-tail contexts, while the model's performance in estimating the probabilities of long-tail target tokens during inference remains underexplored. In this paper, we investigate the behavior of $k$NN-LM on low-frequency tokens, examining prediction probability, retrieval accuracy, and token distribution in the datastore. Our experimental results reveal that $k$NN-LM does not improve prediction performance for low-frequency tokens but mainly benefits high-frequency tokens regardless of long-tail contexts in the datastore."
}
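For orientation, kNN-LM interpolates the base LM's next-token distribution with a distribution induced by the k nearest datastore entries; the --lmbda, --k, and --knn_temp flags in the commands below correspond to the interpolation weight, the number of neighbors, and the softmax temperature. A minimal sketch of the interpolation (illustrative, not this repository's implementation):

import numpy as np

def knnlm_interpolate(p_lm, neighbor_dists, neighbor_values, vocab_size,
                      lmbda=0.25, knn_temp=1.0):
    """Interpolate the LM distribution with the kNN distribution.

    p_lm: (vocab_size,) base LM next-token probabilities.
    neighbor_dists: (k,) distances to the retrieved datastore keys.
    neighbor_values: (k,) token ids stored as the datastore values.
    """
    # Convert distances to neighbor weights with a temperature softmax.
    logits = -neighbor_dists / knn_temp
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()

    # Accumulate neighbor weights onto their value tokens to get p_kNN.
    p_knn = np.zeros(vocab_size)
    np.add.at(p_knn, neighbor_values, weights)

    # p = lambda * p_kNN + (1 - lambda) * p_LM
    return lmbda * p_knn + (1.0 - lmbda) * p_lm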
git clone https://github.com/naist-nlp/knnlm-longtail-analysis.git
cd knnlm-longtail-analysis/
pip install -r requirements.txt
cd data/
cat resplit_wikitext.tar.zst.shard* > resplit_wikitext.tar.zst
sha256sum -c resplit_wikitext.tar.zst.sha256
# expected output: resplit_wikitext.tar.zst: OK
tar --zstd -xf resplit_wikitext.tar.zst
cd ../
MODEL=gpt2-xl
eval_subset=validation
data_dir=data/resplit_wikitext
train_path=${data_dir}/train.txt
eval_subset_path=${data_dir}/${eval_subset}.txt
output_dir=checkpoints/dstores/${MODEL}
dstore_dir=dstores/${MODEL}
stats_dir=stats/${MODEL}
# save datastore
python -u run_clm.py \
    --model_name_or_path ${MODEL} \
    --train_file ${train_path} \
    --validation_file ${eval_subset_path} \
    --do_eval --eval_subset train \
    --output_dir ${output_dir} \
    --dstore_dir ${dstore_dir} \
    --save_knnlm_dstore
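Conceptually, each datastore entry pairs a context with the token that follows it: the key is the LM's final-layer hidden state for the context, and the value is the next-token id. A minimal sketch of that construction with a Hugging Face causal LM (illustrative; the file names, the fp16 cast, and the small model are assumptions, not the repository's exact format):

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # small model for brevity
tokenizer = AutoTokenizer.from_pretrained("gpt2")

keys, values = [], []
with torch.no_grad():
    for text in ["an example sentence", "another example"]:
        ids = tokenizer(text, return_tensors="pt").input_ids
        out = model(ids, output_hidden_states=True)
        hidden = out.hidden_states[-1][0]   # (seq_len, dim)
        # Key: hidden state at position t; value: the token at position t+1.
        keys.append(hidden[:-1].numpy())
        values.append(ids[0, 1:].numpy())

np.save("dstore_keys.npy", np.concatenate(keys).astype(np.float16))
np.save("dstore_vals.npy", np.concatenate(values))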
# build FAISS index
# infer the datastore size from the first number in the saved key file name
dstore_size=$(ls ${dstore_dir}/dstore_*_keys.npy | grep -oP '(?<=_)([0-9]+)(?=_)' | head -1)
python -u run_clm.py \
    --model_name_or_path ${MODEL} \
    --train_file ${train_path} \
    --validation_file ${eval_subset_path} \
    --output_dir ${output_dir} \
    --dstore_dir ${dstore_dir} \
    --dstore_size ${dstore_size} \
    --build_index
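--build_index constructs a FAISS index over the saved keys so neighbors can be retrieved quickly at evaluation time. For intuition, a minimal FAISS sketch (an exact index here; the index type, file names, and parameters are illustrative assumptions, not the repository's settings):

import faiss
import numpy as np

keys = np.load("dstore_keys.npy").astype(np.float32)  # FAISS expects float32

index = faiss.IndexFlatL2(keys.shape[1])  # exact L2 search; large datastores
index.add(keys)                           # typically use an approximate index
faiss.write_index(index, "knn.index")

# Retrieve the 16 nearest keys (distances and datastore ids) for a query.
dists, ids = index.search(keys[:1], 16)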
# evaluate kNN-LM for each (k, temperature) setting, saving the kNN probabilities
for num_neighbors in 16 1024; do
    for knn_temp in 1 10; do
        knnlm_stats_file=${stats_dir}/knnlm_stats_k${num_neighbors}_temp${knn_temp}.pt
        python -u run_clm.py \
            --model_name_or_path ${MODEL} \
            --train_file ${train_path} \
            --validation_file ${eval_subset_path} \
            --output_dir ${output_dir} \
            --dstore_size ${dstore_size} \
            --dstore_dir ${dstore_dir} \
            --knn \
            --knn_gpu \
            --lmbda 0.25 \
            --k ${num_neighbors} \
            --knn_temp ${knn_temp} \
            --knn_prob_save_path ${knnlm_stats_file} \
            --do_eval
    done
done
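The two knn_temp values probe how sharply distances are converted into neighbor weights: with the temperature softmax from the sketch above, a higher temperature spreads probability mass across more neighbors. A quick illustration with hypothetical distances:

import numpy as np

dists = np.array([10.0, 12.0, 20.0])  # hypothetical neighbor distances
for temp in (1.0, 10.0):
    weights = np.exp(-dists / temp)
    print(temp, np.round(weights / weights.sum(), 3))
# 1.0  [0.881 0.119 0.   ]  -> mass concentrates on the nearest key
# 10.0 [0.457 0.374 0.168]  -> weights flatten across neighbors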
# post-process the saved probabilities for each (k, temperature) setting
for num_neighbors in 16 1024; do
    for knn_temp in 1 10; do
        python preprocess_probability.py \
            -m ${MODEL} -k ${num_neighbors} -t ${knn_temp}
    done
done
# plot the prediction probabilities
python plot_probability.py -m ${MODEL} -t prob
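The intermediate statistics are ordinary PyTorch .pt files, so they can also be inspected directly; the structure of the saved object is defined by run_clm.py, so treat this as exploratory:

import torch

stats = torch.load("stats/gpt2-xl/knnlm_stats_k16_temp1.pt")
print(type(stats))  # inspect the object's structure before relying on it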