diff --git a/scripts/hn_mine.py b/scripts/hn_mine.py index b1d4718..be73625 100644 --- a/scripts/hn_mine.py +++ b/scripts/hn_mine.py @@ -159,6 +159,12 @@ def find_knn_neg( p_vecs = model.encode(corpus) print(f'inferencing embedding for queries (number={len(queries)})--------------') q_vecs = model.encode_queries(queries) + + # check if the embeddings are in dictionary format: M3Embedder + if isinstance(p_vecs, dict): + p_vecs = p_vecs["dense_vecs"] + if isinstance(q_vecs, dict): + q_vecs = q_vecs["dense_vecs"] print('create index and search------------------') index = create_index(p_vecs, use_gpu=use_gpu)