Skip to content

Commit 1163d64

Browse files
Distance array now stores distances in float32
Fixes `/opt/conda/envs/lmi/lib/python3.11/site-packages/torch/_tensor.py:1089: RuntimeWarning: overflow encountered in cast`
1 parent ace60af commit 1163d64

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

task1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
45
os.environ['MKL_NUM_THREADS'] = '27'
56
os.environ['OMP_NUM_THREADS'] = '27'
67
os.environ['OMP_DYNAMIC'] = 'FALSE'
@@ -126,13 +127,13 @@ def _visit_buckets(
126127
def search(self, queries: Tensor, k: int, nprobe: int = 100) -> tuple[np.ndarray, np.ndarray]:
127128
predicted_bucket_ids = self._predict(queries, nprobe)
128129
n_queries = queries.shape[0]
129-
D = np.empty((n_queries, k), dtype=np.float16)
130+
D = np.empty((n_queries, k), dtype=np.float32)
130131
I = np.empty((n_queries, k), dtype=np.int32)
131132

132133
torch.set_num_threads(3)
133134
faiss.omp_set_num_threads(3)
134135

135-
with ThreadPoolExecutor(max_workers=9) as executor:
136+
with ThreadPoolExecutor(max_workers=9) as executor:
136137
results = executor.map(
137138
lambda i: self._visit_buckets(k, predicted_bucket_ids[i], queries[i : i + 1], i, nprobe),
138139
range(n_queries),

task2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
45
os.environ['MKL_NUM_THREADS'] = '4'
56
os.environ['OMP_NUM_THREADS'] = '4'
67
os.environ['OMP_DYNAMIC'] = 'FALSE'
@@ -148,7 +149,7 @@ def search(
148149
) -> tuple[np.ndarray, np.ndarray]:
149150
predicted_bucket_ids = self._predict(queries, nprobe)
150151
n_queries = queries.shape[0]
151-
D = np.empty((n_queries, k), dtype=np.float16)
152+
D = np.empty((n_queries, k), dtype=np.float32)
152153
I = np.empty((n_queries, k), dtype=np.int32)
153154

154155
torch.set_num_threads(2)

task3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import os
4+
45
os.environ['MKL_NUM_THREADS'] = '27'
56
os.environ['OMP_NUM_THREADS'] = '27'
67
os.environ['OMP_DYNAMIC'] = 'FALSE'
@@ -135,13 +136,13 @@ def search(
135136
) -> tuple[np.ndarray, np.ndarray]:
136137
predicted_bucket_ids = self._predict(queries, nprobe)
137138
n_queries = queries.shape[0]
138-
D = np.empty((n_queries, k), dtype=np.float16)
139+
D = np.empty((n_queries, k), dtype=np.float32)
139140
I = np.empty((n_queries, k), dtype=np.int32)
140141

141142
torch.set_num_threads(3)
142143
faiss.omp_set_num_threads(3)
143144

144-
with ThreadPoolExecutor(max_workers=9) as executor: # max_workers=9
145+
with ThreadPoolExecutor(max_workers=9) as executor: # max_workers=9
145146
results = executor.map(
146147
lambda i: self._visit_buckets(k, predicted_bucket_ids[i], decomposed_queries[i : i + 1], i, nprobe),
147148
range(n_queries),

0 commit comments

Comments
 (0)